Scala recursive functions

One of the fundamental tools in the functional programmer's toolbox is recursion, which can be defined briefly as calling a function from inside itself. Recursive functions exist in many languages, even outside the domain of FP; C, to name one of the oldest, allows recursive calls. Recursion is often a cleaner and more elegant way to express a problem than a loop. The main caveat is to never forget the termination condition, otherwise the calls never end. Let’s see how the factorial of a number can be calculated in Scala with a naive recursive call:

scala> def f(x: Int): Int = if (x == 1) 1 else x * f(x - 1)
f: (x: Int)Int

scala> f(2)
res0: Int = 2

scala> f(3)
res1: Int = 6

scala> f(4)
res2: Int = 24

Here the function just checks whether its input is 1. If so, it returns the constant 1. Otherwise it returns the very definition of the factorial, which says that f(x) is equal to x * f(x - 1). (Note that the check only covers 1: calling f() with 0 or a negative number never reaches the termination condition.) Part of the elegance of Scala is this ability to connect math and programming. Note that f() declares its return type as Int. That’s not surprising on its own, but it deserves noting, because Scala can usually infer the return type of a function, and Scala programmers are not inclined to type one keystroke more than required. So why the return type here? Because recursive functions are required to declare their return type explicitly: the compiler cannot infer the type of the body while that body refers to the very function whose type is being inferred.
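To make the point about return types concrete, here is a minimal sketch. The names double and factorial are made up for this illustration. In Scala 2 the commented-out definition is rejected with an error along the lines of "recursive method factorial needs result type"; the exact wording may differ between compiler versions.

```scala
// Non-recursive: the return type (Int) is inferred, so it can be omitted
def double(x: Int) = x * 2

// Recursive: the return type must be written out
def factorial(x: Int): Int = if (x == 1) 1 else x * factorial(x - 1)

// Without ": Int" the definition above does not compile:
// def factorial(x: Int) = if (x == 1) 1 else x * factorial(x - 1)
```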

Now, why did I call my implementation of f() naive? It suffers from a fundamental problem that is crucial to understand when using recursive functions: it can overflow the stack! If we take a closer look at how the function works by simulating it, we’ll note that it does not reduce directly to a product of numbers like 4 * 3 * 2 * 1; instead it has to push another context onto the stack for every call:

f(4) => if (4 == 1) 1 else 4 * f(4 - 1)

Now 4 is different from 1, so the call evaluates to 4 * f(4 - 1). But the value of the right-hand side of the product can be determined only by calling f(4 - 1) first and multiplying the result by 4 afterwards. The context of f(4) must be kept on the stack and a new one must be created for f(4 - 1). Then the situation repeats: f(3) is 3 * f(3 - 1), which can be evaluated only after calling f(3 - 1). So f(4) creates 4 contexts on the stack, one for each call. What happens when f() is called with 50000? Quite predictably, it fails with a java.lang.StackOverflowError. How can we solve this?
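The growth of the stack can be observed with a small sketch. The helpers depth and overflows are hypothetical names introduced here: depth recurses exactly like the naive f() but only counts the calls, and overflows reports whether the stack was exhausted. The depth at which the overflow happens depends on the JVM stack size, so the example uses a depth far larger than 50000 to make the failure reliable on default settings; that is an assumption about the environment.

```scala
// Hypothetical helper: recurses like the naive f(), but only counts the calls.
// The addition happens after the recursive call returns, so this is not a tail call.
def depth(x: Int): Int = if (x == 1) 1 else 1 + depth(x - 1)

// Returns true if evaluating depth(n) exhausts the stack
def overflows(n: Int): Boolean =
  try { depth(n); false }
  catch { case _: StackOverflowError => true }

overflows(4)        // false: only 4 contexts are needed
overflows(10000000) // expected to be true with a default-sized JVM stack
```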

The compiler here is our best friend. It is able to detect when no additional context needs to be pushed onto the stack, that is, when the recursive call is a tail call, and in that case it compiles the recursion into a classical loop. To allow this bit of magic, we have to resolve all the values before the function calls itself. This means that x * f(x - 1) must become something different, which packs the multiplication by x inside the call to f() as an additional argument. To do this we use an accumulator parameter, slightly changing the signature of f() to:

scala> def f(x: Int, acc: Int): Int = if (x == 1) acc else f(x - 1, acc * x)

This new version of f() does not leave the product unresolved before calling itself: the product is itself an argument of the call. The caller passes 1 as the initial value of the accumulator. This way we can compute f(50000, 1) without incurring a stack overflow. Well, to be honest, Int must be replaced by BigInt if we want to get back a number different from 0, because the product overflows the 32-bit Int long before then:

scala> def f(x: BigInt, acc: BigInt): BigInt = if (x == 1) acc else f(x - 1, acc * x)

scala> f(50000, 1)
res59: BigInt = 3347320509597144836915476094071486477912773223810454807730100321990168022144365641...
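To see how the accumulator carries the partial product from call to call, the following sketch records every (x, acc) pair that the accumulator version goes through. The function trace and its steps parameter are hypothetical, added only for this illustration; the recursion itself is the same as in f(x, acc).

```scala
import scala.annotation.tailrec

// Hypothetical helper: same recursion as f(x, acc), but it also
// collects each (x, acc) pair seen along the way
@tailrec
def trace(x: Int, acc: Int, steps: Vector[(Int, Int)] = Vector.empty): Vector[(Int, Int)] =
  if (x == 1) steps :+ ((x, acc))
  else trace(x - 1, acc * x, steps :+ ((x, acc)))

trace(4, 1) // Vector((4,1), (3,4), (2,12), (1,24))
```

Each step already holds the finished partial product, so nothing is left waiting on the stack for a later multiplication.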

This is the right way! If only we could hide that 1 from the call, we would avoid exposing an internal implementation detail. So we rename the accumulator version as a helper, fIter(), and wrap it in another function:

scala> def fIter(x: BigInt, acc: BigInt): BigInt = if (x == 1) acc else fIter(x - 1, acc * x)

scala> def f(x: BigInt): BigInt = fIter(x, 1)

But fIter() still hangs around as a directly callable function, which is still suboptimal. Let’s hide it inside f():

import scala.annotation.tailrec

def f(x: BigInt): BigInt = {
  // The helper is local to f(), so callers cannot reach it
  @tailrec
  def fIter(x: BigInt, acc: BigInt): BigInt =
    if (x == 1) acc else fIter(x - 1, acc * x)

  fIter(x, 1)
}

Tail recursion gets its name from the fact that the recursive call happens as the last action of the calling function. Remember: the last action, just as the tail is the last part of an animal. The @tailrec annotation makes the compiler verify that the function really is tail recursive; if it is not, compilation aborts with an error instead of silently producing a function that can overflow the stack. Now we’re done!
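For comparison, here is a hand-written loop that does roughly what the tail-recursive helper is turned into. This is a sketch of the idea, not the compiler's actual output, and the name factorialLoop is made up. One deliberate difference from the code above: the loop condition is x > 1 rather than x != 1, so that 0 and negative inputs terminate (returning 1) instead of looping forever.

```scala
// Hand-written equivalent of the tail-recursive helper:
// the two parameters become two mutable variables updated in a loop
def factorialLoop(n: BigInt): BigInt = {
  var x = n
  var acc = BigInt(1)
  while (x > 1) {   // assumption: x > 1 instead of x != 1, see the note above
    acc = acc * x   // same as passing acc * x to the next call
    x = x - 1       // same as passing x - 1 to the next call
  }
  acc
}

factorialLoop(5) // 120
```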