Scala recursive functions

One of the fundamental tools in the functional programmer's toolbox is recursion, which can be briefly defined as a function calling itself. Recursive functions exist in many languages, even outside the domain of FP: to name one of the oldest, the C language allows recursive calls. Recursion is often a cleaner and more elegant way to express a problem than a loop. The main caveat: never forget a termination condition, or the function will keep calling itself until it runs out of stack. Let's see how the factorial of a number can be calculated in Scala with a naive recursive call:

scala> def f(x: Int): Int = if (x == 1) 1 else x * f(x - 1)
f: (x: Int)Int

scala> f(2)
res0: Int = 2

scala> f(3)
res1: Int = 6

scala> f(4)
res2: Int = 24
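The base case above only matches x == 1, so a call such as f(0) or f(-3) never reaches it and keeps recursing until the stack overflows. Here is a minimal sketch of a safer variant. It assumes that negative input should be rejected and that 0! = 1; the names `SafeFactorial` and `safeFact` are hypothetical:

```scala
object SafeFactorial {
  // Same naive recursion as f(), but the base case covers both 0 and 1,
  // and negative input is rejected up front instead of recursing forever.
  def safeFact(x: Int): Int = {
    require(x >= 0, s"factorial is undefined for negative input: $x")
    if (x <= 1) 1 else x * safeFact(x - 1)
  }
}
```

Calling `SafeFactorial.safeFact(-1)` throws an IllegalArgumentException immediately. The naive f() would recurse until it ran out of stack.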

Here the function checks whether its input is 1. If it is, the constant value 1 is returned. Otherwise the function returns the very definition of the factorial, f(x) = x * f(x - 1). Much of Scala's elegance lies exactly in how closely the code can follow the mathematics. Note that f() declares its return type as Int. That is not surprising on its own, but it deserves a mention: Scala can usually infer the return type of a function, and Scala programmers are not keen on typing a single keystroke more than required. So why the explicit return type here? Because the compiler cannot infer the result type of a recursive function (it would need that very type to check the recursive call), so recursive functions must declare it explicitly.

Now, why did I call my implementation of f() naive? Because it suffers from a fundamental problem that is crucial to understand when using recursive functions: overflowing the stack! If we take a closer look at how the function works by tracing a call by hand, we'll see that it does not simply reduce to a product of numbers like 4 * 3 * 2 * 1. Instead, every call has to push another frame onto the stack:

f(4) => if (4 == 1) 1 else 4 * f(4 - 1)

Now, 4 is different from 1, so the call evaluates to 4 * f(4 - 1). But the right operand of the product can be determined only by calling f(4 - 1), and only then can it be multiplied by 4. The frame for f(4) must therefore stay on the stack while a new one is created for f(4 - 1). Then the situation repeats: f(3) is 3 * f(3 - 1), which can be evaluated only after calling f(3 - 1). Evaluating f(4) thus creates 4 frames on the stack, one per call. What happens when f() is called with 50000? On a JVM with its default thread stack size, it will most likely throw a java.lang.StackOverflowError. How can we solve this?
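To make the "one frame per call" claim concrete, here is a small sketch that instruments the naive factorial with a counter of calls that are still waiting for their recursive call to return. The counter is a hypothetical addition for illustration only, and the names `StackDepth` and `measure` are made up:

```scala
object StackDepth {
  // Number of calls currently "open" (waiting for a recursive call to return).
  private var depth = 0
  // Highest value depth reached during the last measurement.
  var maxDepth = 0

  def f(x: Int): Int = {
    depth += 1
    maxDepth = math.max(maxDepth, depth)
    // The multiplication by x can only happen after f(x - 1) returns,
    // so this call's frame must stay on the stack until then.
    val result = if (x == 1) 1 else x * f(x - 1)
    depth -= 1
    result
  }

  // Runs f(x) and reports how many calls were open at the same time.
  def measure(x: Int): Int = {
    depth = 0
    maxDepth = 0
    f(x)
    maxDepth
  }
}
```

`StackDepth.measure(4)` reports 4 nested calls, and in general the depth grows linearly with the input. That growth is exactly what exhausts the stack for large arguments.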

The compiler is our best friend here. It can detect when a recursive call is the very last action of a function, so that no additional frame needs to be kept on the stack. In that case it turns the recursive call into a tail call, which is compiled into a classic loop. To enable this bit of magic, all values must be computed before the function calls itself. This means that x * f(x - 1) must become something different: the partial product has to be passed into the call to f(), together with the other argument. To do this we use an accumulator, slightly changing the signature of f() to:

scala> def f(x: Int, acc: Int): Int = if (x == 1) acc else f(x - 1, acc * x)

This new version of f() does not leave the product unresolved before calling itself: the product is itself an argument of f(). This way we can compute f(50000, 1) without a stack overflow. To be honest, though, Int must be replaced by BigInt if we want to get back a number other than 0. Int multiplication silently wraps around modulo 2^32, and since 50000! contains far more than 32 factors of 2, the wrapped result is exactly 0:
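The wrap-around is easy to observe on much smaller inputs. 34! is the first factorial that contains at least 32 factors of 2 (17 + 8 + 4 + 2 + 1 = 32), so an Int accumulator should already return 0 for 34. The following sketch places both accumulator versions side by side; the object and method names are hypothetical:

```scala
import scala.annotation.tailrec

object Overflow {
  // Tail-recursive factorial with an Int accumulator, as in the text.
  // Int arithmetic silently wraps around modulo 2^32 on overflow.
  @tailrec
  def factInt(x: Int, acc: Int): Int =
    if (x == 1) acc else factInt(x - 1, acc * x)

  // Same function with BigInt, which grows as needed and never wraps.
  @tailrec
  def factBig(x: BigInt, acc: BigInt): BigInt =
    if (x == 1) acc else factBig(x - 1, acc * x)
}
```

With Int, 12! = 479001600 is the largest factorial that fits, and 13! already wraps. From 34 onwards, `Overflow.factInt(n, 1)` stays at 0, while `Overflow.factBig` keeps returning the exact value.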

scala> def f(x: BigInt, acc: BigInt): BigInt = if (x == 1) acc else f(x - 1, acc * x)

scala> f(50000, 1)
res59: BigInt = 3347320509597144836915476094071486477912773223810454807730100321990168022144365641
69738123107191693087984804381902082998936163847430666937426305728453637840383257562821233599872682
44078235972356040853854441373383753568565536371168327405166076155165921406156075461294201790567479
66549862924222002254155351071815980161547645181061667497021799653747497254113933819163882350063030
76442568748572713946510819098749096434862685892298078700310310089628611545539799116129406523273969
71497211031261142860733793509687837355811830609551728906603833592532851635961730885279811957399495
29945030635444247849264102899006955963488352990055767655092917547592078804480762256241516513045904
63180685174067663600123295564540657242251754734281831210291957155937874236411171945138385930380064
1313297631250...

This is the right way! If only we could hide that initial 1 from the call, we would avoid exposing an internal implementation detail. So we put the iterative computation in a helper function and wrap it in another one:

scala> def fIter(x: BigInt, acc: BigInt): BigInt = if (x == 1) acc else fIter(x - 1, acc * x)

scala> def f(x: BigInt): BigInt = fIter(x, 1)

But fIter() still hangs around as a directly callable function, which is suboptimal. Let's hide it inside f():

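import scala.annotation.tailrec

def f(x: BigInt): BigInt = {
  // Local helper: not visible outside f(), so callers never see the accumulator.
  @tailrec
  def fIter(n: BigInt, acc: BigInt): BigInt =
    if (n == 1) acc else fIter(n - 1, acc * n)

  // Start with an accumulator of 1, the identity for multiplication.
  fIter(x, 1)
}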
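Here is a small usage sketch of the final version, wrapped in a hypothetical object `Factorial` so that it compiles outside the REPL. One deliberate change, which is an assumption and not the original code: the base case is n <= 1. With the == 1 test, the helper never reaches its base case for 0 or negative input, and because it runs as a loop it simply hangs instead of overflowing the stack. With n <= 1, f(0) returns 1.

```scala
import scala.annotation.tailrec

object Factorial {
  def f(x: BigInt): BigInt = {
    // Local, tail-recursive helper; @tailrec makes the compiler verify
    // that the recursive call below is in tail position.
    @tailrec
    def fIter(n: BigInt, acc: BigInt): BigInt =
      if (n <= 1) acc else fIter(n - 1, acc * n) // base case widened to n <= 1

    fIter(x, 1)
  }
}
```

`Factorial.f(50000)` completes without a StackOverflowError, because the helper runs as a loop and uses constant stack space regardless of the input.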

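Tail recursion gets its name from the fact that the recursive call is the last action of the calling function, just as the tail is the last part of an animal. In the naive version, the last action is the multiplication by x, which can happen only after the recursive call returns, so that call is not in tail position. The @tailrec annotation asks the compiler to verify that the function really is tail recursive: if the compiler cannot turn the recursive call into a loop, compilation aborts with an error instead of silently producing stack-consuming code. (scalac performs this optimization only on methods that cannot be overridden, such as local functions like fIter(), private or final methods, and members of an object.) Now we're done!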
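Because nothing is left to do after the tail call, the compiler can overwrite the parameters and jump back to the start of the function instead of pushing a new frame. The following hand-written sketch is roughly equivalent to the @tailrec helper. It illustrates the idea and is not the compiler's literal output; the names `LoopVersion` and `fLoop` are hypothetical:

```scala
object LoopVersion {
  // Hand-written equivalent of fIter(): the parameters become mutable
  // variables, and the recursive call becomes another trip around the loop.
  def fLoop(x: BigInt): BigInt = {
    var n = x
    var acc = BigInt(1)
    // Mirrors the original base case (n == 1): keep looping until it holds.
    while (n != BigInt(1)) {
      acc = acc * n // like passing acc * n as the new accumulator
      n = n - 1     // like passing n - 1 as the new argument
    }
    acc
  }
}
```

The update order matters. In fIter(n - 1, acc * n), both arguments are computed from the old value of n, which is why acc is updated before n is decremented.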