Introduction
In this post we will discuss the concept of fold in functional programming: fold functions are used to reduce a data structure containing multiple values to a single value. Two concepts are associated with the idea of fold:
- recursion: via recursion, fold traverses the different elements of the data structure.
- summarisation: the data structure with all its elements is reduced to a single value
In the following example, an instance of the class Rectangle can be reduced to a single value representing its area. However, it is not ‘foldable’, in the sense that no recursion is involved.
case class Rectangle(side1: Int, side2: Int)

def area(rectangle: Rectangle): Int = rectangle.side1 * rectangle.side2

println(area(Rectangle(2, 3))) // 6
A typical use case of fold is to calculate the sum of the elements of a list: there is both recursion and a summary value. Here are three different implementations that illustrate how fold is similar to a loop or to a recursive function:
// iterating over the elements of the list with a loop
def sumLoop(xs: List[Int]): Int = {
  var acc = 0
  for (x <- xs) {
    acc += x
  }
  acc
}

// iterating over the elements of the list using recursion
def sumRecursive(xs: List[Int]): Int = xs match {
  case Nil       => 0
  case x :: tail => x + sumRecursive(tail)
}

// using the Scala API (foldLeft, unlike reduce, also works on an empty list)
def sumFold(xs: List[Int]): Int = xs.foldLeft(0)((acc, x) => acc + x)
In fact, the Scala 2.12 standard library implements foldLeft in TraversableOnce as a loop:
def foldLeft[B](z: B)(op: (B, A) => B): B = {
  var result = z
  this foreach (x => result = op(result, x))
  result
}
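The same left-to-right accumulation can also be written without a mutable variable. Below is a minimal sketch, assuming a List input, of a tail-recursive left fold. `myFoldLeft` and the wrapper object `Folds` are hypothetical names and are not part of the Scala API.

```scala
import scala.annotation.tailrec

// Hypothetical helper object, not part of the Scala API
object Folds {
  // A left fold over a List written with tail recursion instead of a var
  @tailrec
  def myFoldLeft[A, B](xs: List[A], acc: B)(op: (B, A) => B): B = xs match {
    case Nil          => acc                                 // no elements left: the accumulator is the result
    case head :: tail => myFoldLeft(tail, op(acc, head))(op) // combine the head, then continue with the tail
  }
}

println(Folds.myFoldLeft(List(1, 2, 3, 4), 0)(_ + _)) // 10
println(Folds.myFoldLeft(List(1, 2, 3, 4), 0)(_ - _)) // -10
```

The recursive call is in tail position, so the compiler turns it into a loop. That makes it close in spirit to the var-based version.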
Fold functions
Intuitively, based on the previous examples, fold is equivalent to a loop that applies a function ‘f’ to an accumulator and each element of a collection in turn (in our example, that function is the addition of numbers).
The Scala API offers a multitude of fold functions: fold, foldRight, foldLeft, reduce, reduceRight, reduceLeft, etc.
Left vs Right
In order to understand the purpose of all these variants, let’s start by doing a simple experiment. What is the result of this subtraction?
1 - 2 - 3 - 4
If your answer was -8, then we have been taught the same kind of Maths: the kind in which, when there are no parentheses, subtraction is done from left to right. This saves us from having to write
( ( (1 - 2) - 3) - 4)
So why is it necessary to specify the order in which subtractions are done? Well, because subtraction is not an associative operation and therefore the order in which it is done matters.
from left to right
List(1,2,3,4).reduceLeft((acc,x) => acc - x) = ( ( (1 - 2) - 3) - 4) = -8
from right to left
List(1,2,3,4).reduceRight((x, acc) => x - acc) = (1 - (2 - (3 - 4) ) ) = -2
no specific direction
((1 - 2) - (3 - 4)) = 0
Therefore
- left variants apply functions from left to right
- right variants apply functions from right to left
- non-directional variants apply functions without guaranteeing any specific order
Directional variants should be used when the function is not associative.
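The directional results above can be checked with a short script. The values `xs`, `left`, `right` and `grouped` are our own names. The last lines show that the direction no longer matters for an associative operation such as addition.

```scala
val xs = List(1, 2, 3, 4)

// Left to right: ((1 - 2) - 3) - 4
val left = xs.reduceLeft((acc, x) => acc - x)
// Right to left: 1 - (2 - (3 - 4))
val right = xs.reduceRight((x, acc) => x - acc)
// One possible "non-directional" grouping, written out by hand
val grouped = (xs(0) - xs(1)) - (xs(2) - xs(3))

println(left)    // -8
println(right)   // -2
println(grouped) // 0

// Addition is associative, so both directions give the same sum
println(xs.reduceLeft(_ + _) == xs.reduceRight(_ + _)) // true
```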
Reduce vs Fold
The main difference between the fold and reduce families is that fold functions take an extra parameter, an initial value, which is treated as if it were added to the collection. Because of this, fold functions can handle empty collections. Reduce functions cannot, and they throw an exception when the collection is empty.
from left to right
List(1,2,3,4).foldLeft(0)((acc,x) => acc - x) = ( ( ( (0 - 1) - 2) - 3) - 4) = -10
from right to left
List(1,2,3,4).foldRight(0)((x, acc) => x - acc) = (1 - (2 - (3 - (4 - 0) ) ) ) = -2
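The following sketch shows the difference on an empty list (`xs` and `empty` are our own names). `Try` is used only to observe the exception thrown by `reduceLeft` without stopping the script. `reduceLeftOption` is the standard library's non-throwing alternative.

```scala
import scala.util.Try

val xs = List(1, 2, 3, 4)
val empty = List.empty[Int]

println(xs.foldLeft(0)((acc, x) => acc - x))  // -10
println(xs.foldRight(0)((x, acc) => x - acc)) // -2

// With an empty list, a fold simply returns the initial value...
println(empty.foldLeft(0)(_ + _)) // 0

// ...whereas reduce has no element to start from and throws
// (an UnsupportedOperationException for List)
println(Try(empty.reduceLeft(_ + _)).isFailure) // true

// Non-throwing alternative: the result is wrapped in an Option
println(empty.reduceLeftOption(_ + _)) // None
```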
The previous statement about “fold functions taking an extra parameter that is added to the collection” needs a clarification. That interpretation only holds when the function produces a value of the same type as the elements of the collection. Otherwise, the extra parameter is the initial value of the accumulator, which is needed before the function can be applied at all.
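Here are two small illustrations of the second case, where the initial value is an accumulator rather than an extra element (`digits` and `evens` are our own names). `reduceLeft` could not express the first one, because its result type must be a supertype of the element type.

```scala
val xs = List(1, 2, 3, 4)

// The accumulator (String) has a different type from the elements (Int);
// the initial value "" is what makes the first step possible: "" + 1 == "1"
val digits: String = xs.foldLeft("")((acc, x) => acc + x)
println(digits) // 1234

// Same type as the elements, but the initial value is a counter, not an element
val evens: Int = xs.foldLeft(0)((count, x) => if (x % 2 == 0) count + 1 else count)
println(evens) // 2
```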
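Now it is clear that the previously defined functions, ‘sumLoop’ and ‘sumRecursive’, are folds rather than reduces: both start from the initial value 0 and return it for an empty list. ‘sumLoop’ combines the elements from left to right, like ‘foldLeft(0)’. ‘sumRecursive’ nests the additions to the right, like ‘foldRight(0)’.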
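Replacing the addition with a subtraction makes the evaluation order of the two hand-written functions visible, because subtraction is not associative. `subLoop` and `subRecursive` are hypothetical variants of `sumLoop` and `sumRecursive` and are not part of the original code.

```scala
// Hypothetical variant of sumLoop: processes the elements from left to right
def subLoop(xs: List[Int]): Int = {
  var acc = 0
  for (x <- xs) { acc -= x } // acc = acc - x
  acc
}

// Hypothetical variant of sumRecursive: nests the operations to the right
def subRecursive(xs: List[Int]): Int = xs match {
  case Nil       => 0
  case x :: tail => x - subRecursive(tail) // x - (result for the rest of the list)
}

val xs = List(1, 2, 3, 4)
println(subLoop(xs))      // -10, i.e. ((((0 - 1) - 2) - 3) - 4)
println(subRecursive(xs)) // -2,  i.e. 1 - (2 - (3 - (4 - 0)))
```

Both functions also return 0 for an empty list, which a reduce could not do.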
Climbing the abstraction ladder
Once we are familiar with the basics, we can start building more elaborate abstractions. Inspired by Cats, which has a type class called Foldable, we can create our own fold abstraction. The following trait represents things that can be folded from left to right.
trait FoldLeft[F[_]] {
  def foldLeft[A, B](xs: F[A], b: B, f: (B, A) => B): B
}
Folding Lists
Here’s the companion object with an implementation of FoldLeft for Lists:
object FoldLeft {
  def apply[F[_]: FoldLeft]: FoldLeft[F] = implicitly[FoldLeft[F]]

  implicit val foldLeftList: FoldLeft[List] = new FoldLeft[List] {
    override def foldLeft[A, B](xs: List[A], b: B, f: (B, A) => B): B = xs.foldLeft(b)(f)
  }
}
Finally, we define an API to sum the elements of a foldable structure:
object FoldableAPI {
  def sum[F[_]: FoldLeft](xs: F[Int]): Int = {
    FoldLeft[F].foldLeft[Int, Int](xs, 0, (acc, x) => acc + x)
  }
}
And voila, here’s the result:
object Main extends App {
  import FoldableAPI._

  println(sum(List(1, 2, 3, 4))) // 10
}
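Since `sum` only asks for a `FoldLeft` instance, it works for any type constructor that has one. As an illustration, the self-contained sketch below repeats the definitions above and adds a hypothetical instance for `Option`, which is not part of the original code. The `import FoldableAPI._` line brings `sum` into scope.

```scala
trait FoldLeft[F[_]] {
  def foldLeft[A, B](xs: F[A], b: B, f: (B, A) => B): B
}

object FoldLeft {
  def apply[F[_]: FoldLeft]: FoldLeft[F] = implicitly[FoldLeft[F]]

  implicit val foldLeftList: FoldLeft[List] = new FoldLeft[List] {
    override def foldLeft[A, B](xs: List[A], b: B, f: (B, A) => B): B = xs.foldLeft(b)(f)
  }

  // Hypothetical extra instance: an Option holds zero or one element
  implicit val foldLeftOption: FoldLeft[Option] = new FoldLeft[Option] {
    override def foldLeft[A, B](xs: Option[A], b: B, f: (B, A) => B): B = xs match {
      case Some(a) => f(b, a) // one element: apply f once
      case None    => b       // no elements: the initial value is the result
    }
  }
}

object FoldableAPI {
  def sum[F[_]: FoldLeft](xs: F[Int]): Int =
    FoldLeft[F].foldLeft[Int, Int](xs, 0, (acc, x) => acc + x)
}

import FoldableAPI._

println(sum(List(1, 2, 3, 4)))  // 10
println(sum(Option(5)))         // 5
println(sum(Option.empty[Int])) // 0
```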
Folding Trees
What if instead of List we want to use a different structure like a Tree?
sealed abstract class Tree[A]
final case class Leaf[A](value: A) extends Tree[A]
final case class Branch[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

object Branch {
  // smart constructor: upcast Branch to Tree, so that the result is typed as
  // Tree[A] and an implicit FoldLeft[Tree] instance can be found for it
  def branch[A](value: A, left: Tree[A], right: Tree[A]): Tree[A] = Branch(value, left, right)
}
We need to provide an instance of FoldLeft for Tree:
object FoldLeft {
  def apply[F[_]: FoldLeft]: FoldLeft[F] = implicitly[FoldLeft[F]]

  implicit val foldLeftList: FoldLeft[List] = new FoldLeft[List] {
    override def foldLeft[A, B](xs: List[A], b: B, f: (B, A) => B): B = xs.foldLeft(b)(f)
  }

  implicit val foldLeftTree: FoldLeft[Tree] = new FoldLeft[Tree] {
    override def foldLeft[A, B](xs: Tree[A], b: B, f: (B, A) => B): B = xs match {
      case Leaf(v)         => f(b, v)
      // visit the node, then the left subtree, then the right subtree
      case Branch(v, l, r) => foldLeft(r, foldLeft(l, f(b, v), f), f)
    }
  }
}
object Main extends App {
  import FoldableAPI._
  import Branch.branch

  println(sum(branch(1, Branch(1, Leaf(2), Leaf(3)), Leaf(8)))) // 15
}
Because addition is both associative and commutative, the order in which the elements are visited does not affect the sum, so this example hides a few things. Let’s extend our API to concatenate the String elements of foldable data structures:
object FoldableAPI {
  def sum[F[_]: FoldLeft](xs: F[Int]): Int = {
    FoldLeft[F].foldLeft[Int, Int](xs, 0, (acc, x) => acc + x)
  }

  def concatenate[F[_]: FoldLeft](xs: F[String]): String = {
    FoldLeft[F].foldLeft[String, String](xs, "", (acc, x) => acc + x)
  }
}
Since string concatenation is not commutative, we now get some interesting results. According to our definition of “foldLeftTree”, the algorithm to fold a Tree is:
- take the value of the current node
- append the result of folding the left subtree to its right
- append the result of folding the right subtree to the right of the previous result
object Main extends App {
  import FoldableAPI._
  import Branch.branch

  println(concatenate(branch("1", Branch("1", Leaf("2"), Leaf("3")), Leaf("8")))) // "11238"
}
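To see the visiting order directly, we can fold into a List instead of a String. The sketch below copies the foldLeftTree logic into a standalone function of the same name. The standalone form and the names `tree` and `visited` are our own choices. Each visited value is appended to the accumulator.

```scala
sealed abstract class Tree[A]
final case class Leaf[A](value: A) extends Tree[A]
final case class Branch[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

// Same traversal as foldLeftTree: node value, then left subtree, then right subtree
def foldLeftTree[A, B](xs: Tree[A], b: B, f: (B, A) => B): B = xs match {
  case Leaf(v)         => f(b, v)
  case Branch(v, l, r) => foldLeftTree(r, foldLeftTree(l, f(b, v), f), f)
}

val tree: Tree[String] = Branch("1", Branch("1", Leaf("2"), Leaf("3")), Leaf("8"))

// Appending every visited value to a List exposes the visiting order
val visited = foldLeftTree[String, List[String]](tree, Nil, (acc, x) => acc :+ x)
println(visited) // List(1, 1, 2, 3, 8)
```

The order node, left subtree, right subtree is a pre-order traversal, which is why concatenation yields "11238".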
Although correct, I would have expected something like:
- take the value of the current node
- prepend the result of folding the left subtree to its left
- append the result of folding the right subtree to the right of the previous result
resulting in “21318”.
One way to obtain this result is to use Fold instead of FoldLeft. With Fold, the accumulator and the elements share the same type, so we are free to fold each subtree independently and then combine the partial results in a ‘non-directional’ way. That particular arrangement would not type-check in FoldLeft.foldLeftTree: combining two partial results requires a function of type (B, B) => B, whereas foldLeft only provides (B, A) => B.
trait Fold[F[_]] {
  def fold[A](xs: F[A], b: A, f: (A, A) => A): A
}

object Fold {
  def apply[F[_]: Fold]: Fold[F] = implicitly[Fold[F]]

  implicit val foldTree: Fold[Tree] = new Fold[Tree] {
    override def fold[A](xs: Tree[A], b: A, f: (A, A) => A): A = xs match {
      case Leaf(v)         => f(b, v)
      // fold each subtree independently, then combine: left, node, right
      case Branch(v, l, r) => f(f(fold(l, b, f), v), fold(r, b, f))
    }
  }
}

object FoldableAPI {
  def sum[F[_]: FoldLeft](xs: F[Int]): Int = {
    FoldLeft[F].foldLeft[Int, Int](xs, 0, (acc, x) => acc + x)
  }

  def concatenate[F[_]: FoldLeft](xs: F[String]): String = {
    FoldLeft[F].foldLeft[String, String](xs, "", (acc, x) => acc + x)
  }

  def concatenate2[F[_]: Fold](xs: F[String]): String = {
    Fold[F].fold[String](xs, "", (acc, x) => acc + x)
  }
}

object Main extends App {
  import FoldableAPI._
  import Branch.branch

  println(concatenate(branch("1", Branch("1", Leaf("2"), Leaf("3")), Leaf("8"))))  // "11238"
  println(concatenate2(branch("1", Branch("1", Leaf("2"), Leaf("3")), Leaf("8")))) // "21318"
}
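The in-order result "21318" can also be produced with foldLeft's (B, A) => B signature, provided the accumulator is threaded through the left subtree before the node value. What Fold adds is the ability to fold the two subtrees independently and combine the partial results. Here is a standalone sketch for comparison. `foldLeftInOrder` and `tree` are hypothetical names and are not part of the original code.

```scala
sealed abstract class Tree[A]
final case class Leaf[A](value: A) extends Tree[A]
final case class Branch[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

// Hypothetical in-order left fold: the accumulator goes through the left
// subtree first, then the node value, then the right subtree.
// The combining function still has foldLeft's shape (B, A) => B.
def foldLeftInOrder[A, B](xs: Tree[A], b: B, f: (B, A) => B): B = xs match {
  case Leaf(v)         => f(b, v)
  case Branch(v, l, r) => foldLeftInOrder(r, f(foldLeftInOrder(l, b, f), v), f)
}

val tree: Tree[String] = Branch("1", Branch("1", Leaf("2"), Leaf("3")), Leaf("8"))
println(foldLeftInOrder[String, String](tree, "", _ + _)) // 21318
```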
With Cats’ Foldable, the result would be the same as with our own FoldLeft, because Foldable derives ‘fold’ from the ‘directional’ foldLeft and foldRight implementations that we provide:
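object CatsFoldable extends App {
  import cats._
  import cats.implicits._
  import Branch.branch

  implicit val foldableTree: Foldable[Tree] = new Foldable[Tree] {
    // node, then left subtree, then right subtree (same order as foldLeftTree)
    override def foldLeft[A, B](fa: Tree[A], b: B)(f: (B, A) => B): B = fa match {
      case Leaf(v)         => f(b, v)
      case Branch(v, l, r) => foldLeft(r, foldLeft(l, f(b, v))(f))(f)
    }

    // same order, folded lazily from the right
    override def foldRight[A, B](fa: Tree[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] = fa match {
      case Leaf(v)         => f(v, lb)
      case Branch(v, l, r) => f(v, Eval.defer(foldRight(l, Eval.defer(foldRight(r, lb)(f)))(f)))
    }
  }

  println(Foldable[Tree].fold(branch("1", Branch("1", Leaf("2"), Leaf("3")), Leaf("8")))) // "11238"
}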
The code used in this article can be found in FoldExamples.sc and Fold.scala.