Tail call optimization in Erlang
    fact(1) -> 1;
    fact(N) -> N * fact(N - 1).
You’ve all seen the classic recursive factorial definition. The problem is that it’s not really usable for large inputs. 50000 factorial, anyone? Each recursive call needs its own stack frame, so memory usage grows with N and blows out very quickly. Let’s look at a classic Erlang structure, a message processing loop:
    loop() ->
        receive
            hello ->
                io:format("hello"),
                loop()
        end.
That looks mighty recursive also – one would be inclined to think that saying hello a couple of thousand times would quickly chew through memory. Happily, this is not the case! The reason is tail call optimization.
As you can see, the above hello program is really just a loop. Note that when we call loop(), there’s no reason to keep the stack frame for the current call, because there is no more processing to be done. It just needs to pass on the return value. The Erlang compiler recognises this, and so can optimize the above code by doing just that – throwing away the current frame (or transforming it into a loop, whichever you prefer).
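To make the "it's really just a loop" point concrete, here is a minimal sketch of the same kind of process carrying some state between iterations. The module name hello_counter and the {count, From} and stop messages are made up for illustration and are not part of the example above. Every recursive call is the last thing its clause does, so the process can handle any number of messages without its stack growing.

```erlang
-module(hello_counter).
-export([start/0, loop/1]).

%% Spawn the loop with an initial count of zero.
%% (hello_counter, {count, From} and stop are hypothetical names.)
start() ->
    spawn(?MODULE, loop, [0]).

loop(Count) ->
    receive
        hello ->
            io:format("hello number ~p~n", [Count + 1]),
            loop(Count + 1);        % tail call: nothing left to do afterwards
        {count, From} ->
            From ! {count, Count},  % reply with the number of hellos so far
            loop(Count);            % tail call again
        stop ->
            ok                      % no recursive call, so the process ends
    end.
```

From the shell, Pid = hello_counter:start() followed by Pid ! hello a few times and then Pid ! {count, self()} should deliver a {count, N} message back to the shell process.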
With the factorial example, the optimization cannot be done, because each call needs to wait for the return value of fact(N - 1) to multiply it by N – extra processing that depends on the call’s stack.
Tail call optimization can only be done when the recursive call is the last operation in the function.
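One way to see the difference for yourself is to measure the process stack at the bottom of the recursion. The sketch below is an illustration only: the module stack_demo and its functions are hypothetical names, and it assumes the standard erlang:process_info/2 call with the stack_size item, which reports the stack size in words. The exact numbers will vary between Erlang/OTP versions, but the non-tail version should report a stack that grows with N while the tail version stays roughly flat.

```erlang
-module(stack_demo).
-export([body/1, tail/1]).

%% Stack size of the calling process, in words.
stack_words() ->
    {stack_size, Words} = erlang:process_info(self(), stack_size),
    Words.

%% Not a tail call: max/2 still has to run after the recursive call
%% returns, so every level keeps a frame on the stack.
body(0) -> stack_words();
body(N) -> max(body(N - 1), 0).

%% Tail call: the recursive call is the last thing the clause does,
%% so the current frame can be dropped before making it.
tail(0) -> stack_words();
tail(N) -> tail(N - 1).
```

Comparing stack_demo:body(10000) with stack_demo:tail(10000) in the shell should show the first value larger by at least ten thousand words or so.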
With this knowledge, we can rewrite our factorial function to include an accumulator parameter, allowing us to take advantage of the optimization.
    fact(N) -> fact(N, 1).

    fact(1, T) -> T;
    fact(N, T) -> fact(N - 1, T * N).
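One thing to watch with the accumulator version as written: fact(0) never reaches the fact(1, T) clause, so it keeps counting down forever, and because the call is now optimized it won't even run out of stack. A possible variation, sketched here under the assumption that you want 0! to be 1 and anything else rejected, adds a guard and uses 0 as the base case. The module name fact_acc and the variable name Acc are illustrative choices.

```erlang
-module(fact_acc).
-export([fact/1]).

%% Public entry point: only accept non-negative integers, so that
%% bad input fails with function_clause instead of looping forever.
fact(N) when is_integer(N), N >= 0 -> fact(N, 1).

%% Acc carries the product so far; the recursive call is in tail position.
fact(0, Acc) -> Acc;
fact(N, Acc) -> fact(N - 1, Acc * N).
```

With this version, fact_acc:fact(5) evaluates to 120, fact_acc:fact(0) to 1, and fact_acc:fact(50000) completes without exhausting the stack, although the resulting big integer takes a while to print.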
Or, since we recognise that this is really just a loop, you could always write it that way yourself, here as a fold over the numbers 1 to N.
    fact(N) ->
        lists:foldl(fun(X, T) -> X * T end, 1, lists:seq(1, N)).
I haven’t used Erlang enough to make a call as to which is nicer. Probably the first one. I’m a Ruby guy at heart, so for old time’s sake here’s a version you can use in Ruby, which I think is quite pretty (be warned: Ruby doesn’t do tail call optimization by default).
    def fact(n)
      # The block parameter is named x so it does not shadow the method argument n.
      (1..n).inject(1) {|t, x| t * x}
    end