Scala notes

Monads, for comprehension, flatMaps



The for comprehension is a syntax shortcut to combine flatMap and map in a way that's easy to read and reason about.

Let's simplify things a bit and call any class that provides both of these methods a monad. We'll write M[A] to mean a monad with inner type A.


Some commonly seen monads

  • List[String] where
    • M[_]: List[_]
    • A: String
  • Option[Int] where
    • M[_]: Option[_]
    • A: Int
  • Future[String => Boolean] where
    • M[_]: Future[_]
    • A: String => Boolean
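As a small sketch of the three shapes listed above, the values below are annotated with their M[_] and A parts. The object name MonadShapes and the sample values are made up for illustration; Future.successful is used only because it needs no ExecutionContext.

```scala
import scala.concurrent.Future

object MonadShapes {
  // M[_] = List, A = String
  val names: List[String] = List("neo", "smith", "trinity")

  // M[_] = Option, A = Int
  val maybeAge: Option[Int] = Some(42)

  // M[_] = Future, A = String => Boolean
  // Future.successful wraps an already computed value,
  // so no ExecutionContext is required to build it.
  val futurePredicate: Future[String => Boolean] =
    Future.successful((s: String) => s.nonEmpty)
}
```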

map and flatMap

Defined in a generic monad M[A]

 /* Applies a transformation to the monad "content" while maintaining
  * the monad "external shape",
  * i.e. a List remains a List and an Option remains an Option,
  * but the inner type changes.
  */
  def map[B](f: A => B): M[B]

 /* Applies a transformation to the monad "content" by composing
  * this monad with an operation resulting in another monad instance
  * of the same type.
  */
  def flatMap[B](f: A => M[B]): M[B]


  val list = List("neo", "smith", "trinity")

  //converts each character of the string to its corresponding code
  val f: String => List[Int] = s => s.map(_.toInt).toList

  list map f
  >> List(List(110, 101, 111), List(115, 109, 105, 116, 104), List(116, 114, 105, 110, 105, 116, 121))

  list flatMap f
  >> List(110, 101, 111, 115, 109, 105, 116, 104, 116, 114, 105, 110, 105, 116, 121)

for expression

1. Each line in the expression that uses the <- symbol is translated to a flatMap call, except for the last one, which is translated to a concluding map call. In each case the "bound symbol" on the left-hand side becomes the parameter of the function passed as argument (what we previously called f: A => M[B] in the flatMap case):

// The following ...
for {
  bound <- list
  out <- f(bound)
} yield out

// ... is translated by the Scala compiler as ...
list.flatMap { bound =>
  f(bound).map { out =>
    out
  }
}

// ... which can be simplified as ...
list.flatMap { bound =>
  f(bound)
}

// ... which is just another way of writing:
list flatMap f
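To check the equivalence rather than take it on faith, the sketch below evaluates all three forms on the sample list used earlier. The object name TwoGenerators is made up, and f is assumed to map each character to its code, as the outputs shown earlier suggest.

```scala
object TwoGenerators {
  val list = List("neo", "smith", "trinity")

  // assumed definition: each character converted to its numeric code
  val f: String => List[Int] = s => s.map(_.toInt).toList

  // the for comprehension with two generators
  val viaFor: List[Int] = for {
    bound <- list
    out   <- f(bound)
  } yield out

  // the mechanical translation: flatMap for the first generator, map for the last
  val desugared: List[Int] = list.flatMap { bound =>
    f(bound).map { out => out }
  }

  // mapping with the identity function is a no-op, so it can be dropped
  val simplified: List[Int] = list.flatMap(f)
}
```

All three values hold the same flat List[Int].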

2. A for-expression with only one <- is converted to a single map call, with the yield expression as the body of the function passed as argument:

// The following ...
for {
  bound <- list
} yield f(bound)

// ... is translated by the Scala compiler as ...
list.map { bound =>
  f(bound)
}

// ... which is just another way of writing:
list map f
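A quick sketch of the single-generator case, again with a made-up object name (OneGenerator) and the assumed character-code f. Since only map is involved, nothing is flattened and the result is a List[List[Int]].

```scala
object OneGenerator {
  val list = List("neo", "smith", "trinity")

  // assumed definition: each character converted to its numeric code
  val f: String => List[Int] = s => s.map(_.toInt).toList

  // one generator: translated to a single map call
  val viaFor: List[List[Int]] = for {
    bound <- list
  } yield f(bound)

  // the equivalent direct call
  val viaMap: List[List[Int]] = list.map(f)
}
```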

Now to the point

As you can see, the map operation preserves the "shape" of the original monad, and the same holds for the yield expression: a List remains a List, with its content transformed by the operation in the yield.

On the other hand, each binding line in the for is just a composition of successive monads, which must be "flattened" in order to maintain a single "external shape".

Suppose for a moment that each internal binding were translated to a map call while the right-hand side remained the same A => M[B] function: you would end up with an M[M[B]] for each line in the comprehension.
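The nesting problem can be sketched with Option. Everything named here (MapOnlyTranslation, firstChar, the sample value) is hypothetical and only serves to contrast a map-only translation with the real one.

```scala
object MapOnlyTranslation {
  val maybeName: Option[String] = Some("neo")

  // hypothetical A => M[B] function: the first character, if there is one
  def firstChar(s: String): Option[Char] = s.headOption

  // hypothetical translation where the inner binding also becomes a map:
  // the shapes pile up into Option[Option[Char]]
  val nested: Option[Option[Char]] =
    maybeName.map { name => firstChar(name).map { c => c } }

  // actual translation: the inner binding becomes a flatMap,
  // so a single Option shape is kept
  val flat: Option[Char] =
    maybeName.flatMap { name => firstChar(name).map { c => c } }
}
```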

The intent of the whole for syntax is to easily "flatten" the concatenation of successive monadic operations (i.e. operations that "lift" a value in a "monadic shape": A => M[B]), with the addition of a final map operation that possibly performs a concluding transformation.

I hope this explains the logic behind the choice of translation, which is applied in a mechanical way, that is: n nested flatMap calls concluded by a single map call.
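As an illustration of that mechanical rule, here is a sketch with three generators over Option values. The names (ThreeGenerators, a, b, c) are invented for the example: the first two generators become nested flatMap calls and the last one becomes the concluding map.

```scala
object ThreeGenerators {
  val a: Option[Int] = Some(1)
  val b: Option[Int] = Some(2)
  val c: Option[Int] = Some(3)

  // three generators in a for comprehension
  val viaFor: Option[Int] = for {
    x <- a
    y <- b
    z <- c
  } yield x + y + z

  // hand-written translation: two nested flatMap calls, then one map
  val desugared: Option[Int] =
    a.flatMap { x =>
      b.flatMap { y =>
        c.map { z =>
          x + y + z
        }
      }
    }
}
```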

A contrived illustrative example

Meant to show the expressiveness of the for syntax:
case class Customer(value: Int)
case class Consultant(portfolio: List[Customer])
case class Branch(consultants: List[Consultant])
case class Company(branches: List[Branch])

def getCompanyValue(company: Company): Int = {

  val valueList = for {
    branch     <- company.branches
    consultant <- branch.consultants
    customer   <- consultant.portfolio
  } yield (customer.value)

  valueList reduce (_ + _)
}

As already said, the shape of the monad is maintained through the comprehension, so we start with a List in company.branches, and must end with a List.

The inner type, instead, changes and is determined by the yield expression, which here is customer.value: Int

valueList should be a List[Int]
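For comparison, this is roughly what the same function looks like when the flatMap and map calls are written by hand. It is a sketch, not the compiler's literal output: the object name CompanyExample, the function name getCompanyValueDesugared and the sample data are made up, and sum is swapped in for reduce because reduce throws on an empty list.

```scala
object CompanyExample {
  case class Customer(value: Int)
  case class Consultant(portfolio: List[Customer])
  case class Branch(consultants: List[Consultant])
  case class Company(branches: List[Branch])

  // hand-written equivalent of the for comprehension:
  // two nested flatMap calls concluded by a single map
  def getCompanyValueDesugared(company: Company): Int = {
    val valueList: List[Int] =
      company.branches.flatMap { branch =>
        branch.consultants.flatMap { consultant =>
          consultant.portfolio.map { customer =>
            customer.value
          }
        }
      }

    // sum instead of reduce: returns 0 when there are no customers
    valueList.sum
  }

  // made-up sample data
  val company = Company(List(
    Branch(List(
      Consultant(List(Customer(10), Customer(20))),
      Consultant(List(Customer(5)))
    )),
    Branch(List(
      Consultant(List(Customer(7)))
    ))
  ))

  val total: Int = getCompanyValueDesugared(company)
}
```

The nesting grows with every level of the hierarchy, while the for version stays a flat list of bindings.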