Monads!

19 June 2020

I’m trying to be named least interesting blogger on the internet (I already have excellent supporting documentation in the form of my nearly-empty Apache access logs), so I thought I’d write something about monads.


So, you’ve decided to give functional programming a try. You’re using Haskell, or maybe Scala, OCaml or Kotlin-Arrow, but you’ve pinky promised not to stray from the path of pure FP.

Functional programming has a long history, and the first functional languages’ (that is to say, Lisp’s) defining characteristic was probably supporting closures. Closures are first-class functions which may encapsulate some local state. The simplest example is Lisp’s let-over-lambda pattern:

(define (make-counter)
  (let ((counter 0))
    (lambda (x)
      (set! counter (+ counter x))
      counter)))
      
(define my-counter (make-counter))

(my-counter 5)  ; => 5
(my-counter 3)  ; => 8
(my-counter 0)  ; => 8

As cool as this is, modern functional programming has a different focus from Lisp. Now that closures are mainstream (and made much less cool by immutability), the essential element of functional programming is purity (and arguably static typing, don’t kill me Rich Hickey).

Purity

Purity is the lack of side effects. A good concept that covers most of what is meant by purity is referential transparency, the idea that a variable should always be substitutable by its definition without changing the semantics of the program. For instance, here’s some referentially transparent code in Scala:

val x = 3 + 7
println(x)

We could rewrite this as

println(3 + 7)

without changing the program. But Scala is not generally referentially transparent! For instance, look at this code:

def sendEmail(): Boolean = ... // send an email, 
                               // return true on success

val success = sendEmail()
log.info(s"Email success: $success")
if (!success) throw new RuntimeException()

If we try to replace success with its definition as above, we get:

log.info(s"Email success: ${sendEmail()}")
if (!sendEmail()) throw new RuntimeException()

which now sends an email twice; the code is doing something different if we replace success with its definition, so it’s not referentially transparent. OK, nothing groundbreaking so far.

But if we insist on everything being referentially transparent, as Haskell does, we actually deprive ourselves of quite a few things:

  • IO breaks referential transparency because of side-effects.
  • Mutable state breaks referential transparency for any variable definition that reads that state: the definition captures the state at the time of definition, so any state change between the variable’s definition and its use will preclude referential transparency (see the short example after this list).
  • Exceptions break referential transparency too: consider

    val x = throw new RuntimeException()
    try {
      x
    } catch {
      case e: Exception => println("hello")
    }
    

    This program crashes, but if we substitute x with its definition as we’ve done above, it just prints “hello”.

  • Most other forms of control flow can break referential transparency. If the language supports them as expressions (e.g. in Kotlin), early returns, continue and break, all break it for reasons similar to exceptions. So do continuations.
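
For instance, here’s the mutable-state case in miniature (a minimal Scala sketch):

var counter = 0
val snapshot = counter   // snapshot is 0
counter += 1
println(snapshot)        // prints 0

// Substituting snapshot with its definition gives a different program:
// println(counter) would print 1 here.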

None of these things are a huge problem per se. For instance, take mutable state. Many functions can just be rewritten to avoid mutation as much as possible. When it is not possible, a common transformation is to create functions that take and return state objects, in addition to their real result. For instance, in Scala, one could transform

import scala.collection.mutable.ArrayBuffer

class BankAccount {
  var transactions: ArrayBuffer[(String, Int)] = ArrayBuffer()
  var balance: Int = 0
  
  def addTransaction(desc: String, amount: Int): Int = {
    transactions += ((desc, amount))
    balance += amount
    balance
  }
  def addMonthlyFees(): Int = {
    addTransaction("Monthly fees", -10)
  }
  def reset(): Unit = {
    transactions = ArrayBuffer()
    balance = 0
  }
}

into

case class BAState(
  transactions: Vector[(String, Int)] = Vector(),
  balance: Int = 0
)

object BankAccount {
  def addTransaction(st: BAState, 
                     desc: String,
                     amount: Int): (BAState, Int) = {
    val newTransactions = st.transactions :+ ((desc, amount))
    val newBalance = st.balance + amount
    (BAState(transactions=newTransactions, balance=newBalance),
     newBalance)
  }
  def addMonthlyFees(st: BAState): (BAState, Int) = {
    addTransaction(st, "Monthly fees", -10)
  }
  def reset(st: BAState): BAState = {
    BAState(transactions=Vector(), balance=0)
  }
}

Each of these new functions explicitly takes in a piece of immutable state, and returns an updated version of that state (which has to be a new state object, since our state objects are immutable), as well as its actual result. The resulting function is easier to test, to inspect, and generally to reason about.

State boilerplate

You can follow along with the code in this section on Scastie, an online Scala editor.

If you have a bunch of state, and functions that operate on immutable state as above, you’ll probably end up writing more boilerplate than if you stick to an imperative / mutable / not referentially transparent style. Note how our immutable addTransaction above had more work to do to manage the explicit passing of the state both in and out.

We also see a lot of boilerplate when calling:

val initialState = BAState()
val (withDeposit, balanceAfterFirstDeposit) = 
  addTransaction(initialState, "initial deposit", 100)
val (withFees, _) = addMonthlyFees(withDeposit)
val (withExpenses, balanceAfterLastDeposit) =
  addTransaction(withFees, "paycheck", 50)
return balanceAfterLastDeposit - balanceAfterFirstDeposit

In addition to being tedious, this boilerplate increases the risk of bugs by accidentally giving one of the functions the wrong version of the state.

Of course, when there’s boilerplate, functional programmers will try to create abstractions. Note how our functions that operate on BAState all take a BAState and return a pair of a new BAState and a result of an arbitrary type. We can create a new type, BF (for BAState function), defined as

type BF[T] = BAState => (BAState, T)

that is to say, a BF[T] is a function which takes a BAState, and returns a new one and a result of type T.

Let’s revisit our BankAccount object with this definition:

object BankAccount {
  def addTransaction(desc: String, amount: Int): BF[Int] =
  { st: BAState =>
    val newTransactions = st.transactions :+ ((desc, amount))
    val newBalance = st.balance + amount
    (BAState(transactions=newTransactions, balance=newBalance),
     newBalance)
  }
  val addMonthlyFees: BF[Int] = 
    addTransaction("Monthly fees", -10)
  val reset: BF[Unit] = { st: BAState =>
    (BAState(transactions=Vector(), balance=0), ())
  }
}

Side note: Currying

Currying is a way to convert a function that takes a number of arguments into a sequence of functions that each take one argument. If I have a function multiply(a: Int, b: Int), I can rewrite it as a function that takes a first number, and returns a function that takes a second number and returns the product. For instance:

def multiply(a: Int) = { (b: Int) => a * b }

val multiplyBy16: Int => Int = multiply(16)
val thirtyTwo = multiplyBy16(2)

// Or directly
val fortyTwo = multiply(6)(7)

We can generalize this by grouping some number of arguments at each step:

def addThreeCurryTwo(x: Int, y: Int) = { z: Int => x + y + z }
val ten = addThreeCurryTwo(3, 4)(3)

We are using this form of currying to define addTransaction: the function addTransaction isn’t really a BF[Int] per se, but for any value of desc and amount, it gives us a BF[Int], i.e. a function from BAState to (BAState, Int).
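
Concretely, here’s the two-step application this gives us (a tiny illustrative snippet):

// Fixing desc and amount doesn't touch any state yet; we just get a BF[Int].
val deposit: BF[Int] = BankAccount.addTransaction("deposit", 100)

// The state transition only happens when we apply that BF[Int] to a state.
val (newState, balance) = deposit(BAState())  // balance == 100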

Sequencing and composition

Ok, so we’ve saved a little bit of typing in the bank account function signatures. We’ll come back to this later to see how we can save more, but for now let’s think about the call site.

If we have a list of functions that each perform a state transition on our bank account and produce a side result (i.e., a list of BF[T]), we can turn it into a single function that performs the total transition and collects all the side results in a list (i.e., a BF[List[T]]). This operation is normally called sequence:

def sequence[T](l: List[BF[T]]): BF[List[T]] = { st =>
  l match {
    case hd :: tl =>
      val (newState, result) = hd(st)
      val (finalState, results) = sequence(tl)(newState)
      (finalState, result :: results)
    case Nil  => (st, Nil)
  }
}

Don’t worry about the details of this code (we’ll see a simpler way to do something similar later), but let’s look at how we can use sequence. We can now do something like:

val initialState = BAState()
val transform: BF[List[Int]] = sequence(List(
  addTransaction("initial deposit", 100),
  addMonthlyFees,
  addTransaction("paycheck", 50),
))
val (finalState, allResults) = transform(initialState)
return allResults(2) - allResults(0)

What is useful about this is that we can write the state-passing logic just once, in the sequence function, and we can then apply many transforms with very little syntactic boilerplate.

This works, but it’s much less flexible than the manual version from before: we have to collect all the results into a list, and it works poorly with functions returning a different result type (e.g. our reset function). Still, I hope you agree it’s much tidier and less error-prone! We will soon see much more flexible control structures for BF.

How should we think about BF?

At this point I want to pause and propose an intuitive understanding of BF[T], which may seem like a strange way to consider things at first. The point is to consider an object of type BF[T] as an object of type T, except that to actually get to the value, we need to pass a state in and receive a new one back (which we call a BF effect).

A nice property of BF effects is that BF effect + BF effect = still just one BF effect. In particular, imagine we have:

  • a BF[T], aka a T which needs a BF effect to happen if we want to read the value
  • a BF[U], aka a U which needs a BF effect to happen if we want to read the value

Then we should be able to get a BF[(T, U)]: the types “add up” (or get tupled), but the BF effect doesn’t stack; there’s still only one in the end.

Ok, here’s how to actually do this transformation:

def product[T, U](ct: BF[T], cu: BF[U]): BF[(T, U)] = { initialSt =>
  val (tState, t) = ct(initialSt)
  val (uState, u) = cu(tState)
  (uState, (t, u))
}

Note that the order we apply effects in is up to us. We could define another function, rproduct, with exactly the same signature, but defined as:

def rproduct[T, U](ct: BF[T], cu: BF[U]): BF[(T, U)] = { initialSt =>
  val (uState, u) = cu(initialSt)
  val (tState, t) = ct(uState)
  (tState, (t, u))
}

In this variant, any state change needed to access the U value is performed before the change associated with the T value. Using rproduct instead of product on the same arguments could yield completely different results!
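
For example, with the functions we defined on BankAccount (a small illustration):

val both  = product(BankAccount.addTransaction("deposit", 100), BankAccount.reset)
val rboth = rproduct(BankAccount.addTransaction("deposit", 100), BankAccount.reset)

both(BAState())   // deposit, then reset: the final state is empty again
rboth(BAState())  // reset, then deposit: the final state holds the deposit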

We can define quite a few other helpers with BF[_]. An easy one is map. From a BF[T] and a function T => U, we can get a BF[U]. This makes sense both if you think about the specifics of BF just being BAState => (BAState, T), and hopefully also if you think about the definition above: from a T that needs a BF effect to get to, and a way to turn a T into a U, we can get a U that needs a BF effect to get to.

Anyways, here’s map:

def map[T, U](ct: BF[T], f: T => U): BF[U] = { initialState =>
  val (finalState, t) = ct(initialState)
  (finalState, f(t))
}

flatMap

The last helper I want to point out is traditionally called flatMap. It takes a BF[T] and a function T => BF[U] and returns a BF[U].

Clearly, flatMap is more powerful than map. It’s easy to implement map when given flatMap, because we can turn any U into a BF[U] by just passing the state through untouched:

def addBF[U](u: U): BF[U] =
  { st: BAState => (st, u) }
  
def map[T, U](bt: BF[T], f: T => U): BF[U] =
  flatMap(bt, { t => addBF(f(t)) })

On the other hand, there is no way to turn a BF[U] back into a plain U, so it’s not clear how to implement flatMap in terms of map (in fact, it’s not generally possible). To take a higher-level view, the reason flatMap is able to give us a BF[U] instead of the BF[BF[U]] that we would have gotten using map is that, as we’ve discussed above, BF effects don’t stack.
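
To make that concrete, here’s what we’d be stuck with if we only had map, and the hand-written collapse that flatMap saves us from (flattenBF is a name I’m making up here):

// With only map, a function that itself returns a BF leaves us with a nested value.
val nested: BF[BF[Int]] =
  map(BankAccount.addTransaction("deposit", 100), { firstBalance: Int =>
    BankAccount.addTransaction("bonus", firstBalance / 10)  // this is a BF[Int]
  })

// Collapsing the two layers means threading the state from the outer
// computation into the inner one by hand:
def flattenBF[T](bbf: BF[BF[T]]): BF[T] = { initialState =>
  val (midState, inner) = bbf(initialState)
  inner(midState)
}

val collapsed: BF[Int] = flattenBF(nested)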

Here’s flatMap:

def flatMap[T, U](ct: BF[T], f: T => BF[U]): BF[U] = { initialState =>
  val (tState, t) = ct(initialState)
  val (uState, u) = f(t)(tState)
  (uState, u)
}

flatMap may seem a bit arbitrary, but it’s actually very practically useful. We’ve seen above how you can use sequence to compose stateful operations, but it’s clunky because you have to collect the results in a list, and they all have to be the same type. We can use flatMap to do better!

val transformer: BF[Int] =
  flatMap(reset, { _: Unit =>
    flatMap(addTransaction("deposit", 100), { firstBalance: Int =>
      map(addMonthlyFees, { lastBalance: Int =>
        lastBalance - firstBalance })})})

Note a few things here:

  • We are able to mix functions that return different types, such as reset (which returns Unit) and addTransaction (which returns Int).
  • We can give names to the return values we care about, and we could add any non-state-dependent code we want in the lambda bodies, making this very flexible.
  • The state is passed from each BF[_] function to the next without having to do anything.

But there are some drawbacks too. The first is increased stack usage from the nested flatMap calls. Fortunately, this has little effect in a lazy, tail-call-optimizing language like Haskell, and in Scala the flatMap implementations used in practice are trampolined so that they evaluate in constant stack space.

Perhaps a bigger problem is that this is still not very nice to write. We’ve eliminated having to pass state manually, but we’ve gained runaway indentation and lots of nested lambdas and calls to flatMap.

Fortunately, Scala is kind enough to provide alternate syntax for flatMap. We can rewrite the transformer above as1:

val transformer: BF[Int] = for {
  _ <- reset
  firstBalance <- addTransaction("deposit", 100)
  lastBalance <- addMonthlyFees
  result = lastBalance - firstBalance
} yield result

The rewrite method is not that obvious. Any line whose assignment is done with an arrow (<-) will be replaced by a flatMap whose first argument is the right-hand side of the arrow, and whose second argument is a lambda, whose single argument is the left-hand side of the arrow, and whose body is the rest of the for body. As an exception to this rule, the last arrow-assignment will become a map instead; this is because anything after the last arrow will not produce monadic values, just transformations on the existing encapsulated values.

Lines with =-assignments can be freely mixed in, and will be inserted into the body of the generated lambdas.
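
Spelled out, the compiler rewrites the comprehension above into roughly this (using method-call syntax; as the footnote notes, our standalone flatMap and map would have to become methods on BF values for this exact form to compile):

val transformer: BF[Int] =
  reset.flatMap { _ =>
    addTransaction("deposit", 100).flatMap { firstBalance =>
      addMonthlyFees.map { lastBalance =>
        val result = lastBalance - firstBalance  // the `=` line lands in the lambda body
        result
      }
    }
  }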

Let’s summarize what we’ve done:

  1. Passing the state around explicitly is nice for lots of reasons, but is very annoying to do in real code.
  2. We can use BF[T] as an abstraction for functions that take and return state and a result. This abstraction can be understood as “a BF[T] is a value of T gated by a BF effect”.
  3. BF[_] lets us define various interesting combinators, such as sequence, product, and map, to combine our stateful functions in predefined ways.
  4. flatMap, another BF[_] combinator, lets us recover almost all of the flexibility we had with manual state-passing. The only real constraint with flatMap is that the state objects are totally hidden from the programmer, and they are forced to flow from one stateful function to the next.
  5. There is special syntax and support for flatMap in both Scala and Haskell (which uses do instead of for).

The monad abstraction

You might have seen it coming: BF[_] is a monadic type. All a type M must do to be monadic is to have a flatMap function and a pure function, with signatures:

def flatMap[T, U](m: M[T], f: T => M[U]): M[U]
def pure[T](t: T): M[T]

and which obey the following laws:

flatMap(m, pure) == m
flatMap(pure(t), f) == f(t)
flatMap(flatMap(m, f), g) == flatMap(m, { t => flatMap(f(t), g) })

Don’t worry about the laws too much; they state properties which everybody expects to be true. The point is that once you have flatMap and pure, you can use all the other combinators we talked about above (e.g. sequence, product, map, and many, many more) on the type.

We’ve already seen a definition of flatMap for BF[_], and here’s pure:

def pure[T](t: T): BF[T] = { state =>
  (state, t)
}
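
As a quick sanity check (not a proof), here’s the left-identity law instantiated for BF with a concrete state:

val st = BAState()
val f: Int => BF[Int] = { amount => BankAccount.addTransaction("deposit", amount) }

// flatMap(pure(t), f) == f(t): applying both sides to the same state
// yields the same (state, result) pair.
assert(flatMap(pure(100), f)(st) == f(100)(st))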

An alternative definition for monads uses three functions, pure, map, and flatten, with signatures:

def map[T, U](m: M[T], f: T => U): M[U]
def flatten[T](m: M[M[T]]): M[T]
def pure[T](t: T): M[T]

This description makes it more obvious that the monadic interface requires effects to merge, which we can force with flatten. Both descriptions are equivalent: pure is obviously the same function in both, and we have:

def map[T, U](m: M[T], f: T => U): M[U] =
  flatMap(m, x => pure(f(x)))
def flatten[T](m: M[M[T]]): M[T] =
  flatMap(m, x => x)

def flatMap[T, U](m: M[T], f: T => M[U]): M[U] =
  flatten(map(m, f))

What are monads used for?

First, let’s look at what the paper that introduced monads to Haskell has to say:

Say I write an interpreter in a pure functional language.

To add error handling to it, I need to modify the result type to include error values, and at each recursive call to check for and handle errors appropriately. Had I used an impure language with exceptions, no such restructuring would be needed.

To add an execution count to it, I need to modify the result type to include such a count, and modify each recursive call to pass around such counts appropriately. Had I used an impure language with a global variable that could be incremented, no such restructuring would be needed.

To add an output instruction to it, I need to modify the result type to include an output list, and to modify each recursive call to pass around this list appropriately. Had I used an impure language that performed output as a side effect, no such restructuring would be needed.

Or I could use a monad.

This paper shows how to use monads to structure an interpreter so that the changes mentioned above are simple to make. In each case, all that is required is to redefine the monad and to make a few local changes. This programming style regains some of the flexibility provided by various features of impure languages. It also may apply when there is no corresponding impure feature.

The technique applies not just to interpreters, but to a wide range of functional programs.

Wadler, Philip. "The essence of functional programming." Proceedings of the 19th ACM SIGPLAN-SIGACT symposium on Principles of programming languages. 1992.

Just like we’ve seen with BF, monads are in large part used with for or do comprehensions to make functional programming idioms easier to compose.

Contrary to what Haskellers might want you to believe, monads are generally dispensable: it always remains possible to rewrite code by manually passing state, or manually composing Either values. (This is not fully true in Haskell for the specific case of the IO monad, which is handled by the compiler).

All monadic types have one type parameter, and you can read M[T] as “a value of type T, but that can be read only if an M effect is run”. The additional constraint is that stacked effects must collapse, i.e. M[M[T]] must be reducible to M[T].2

A difficulty with the monad abstraction is that its implementations seem much more varied than those of, say, List. This is true, and maybe Monad should be compared to something like AutoCloseable, which describes a useful aspect of an object, but can be applied to completely different types. Contrary to List, and like AutoCloseable, the specific type of monad that is being used matters 90% of the time. The reason it makes sense to have Monad as an abstraction in the first place is mainly to share the convenient syntax and combinator functions.

All that monadic types really have in common is the idea of wrapping an underlying type T in some effects that must be dealt with before accessing T. The nature of these effects vary widely, but they are required to merge, i.e. M[M[T]] must be reducible to M[T]. Again, the way these effects merge in practice is highly monad-dependent.

And just like you don’t write new List implementations every day, you’re not too likely to write your own custom monads (although composing existing ones is common, and I’ll talk about it more below).

In real code, you are likely to be exposed to these monadic types:

  • State[S, T] = (S) => (S, T). While very similar to BF[T], there’s a possible source of confusion: a monadic type must have a single type parameter, but State has two. However, any type of the form MyState[T] = State[StateCarrier, T] for a given, fixed StateCarrier is monadic, and this is what we mean when we say State is monadic. We’ve discussed the interpretation of State at length.
  • Either[E, T]. The same remark applies as for State: this is monadic for a fixed E. Read it as “a value of type T, but it may have been replaced by an error of type E”. This is commonly used to replace exceptions in pure FP settings (see the short sketch after this list).
  • Reader[R, T]. Monadic for a fixed R. Much like a one-sided state, where the functions are automatically passed an R but do not get a chance to update it. “A value of type T, but an R needs to be provided to get access to it.”
  • Writer[W, T]. Monadic for a fixed W. The other side of a one-sided state; functions have the opportunity to return a W in addition to their normal result type. This is generally used for logging.
  • IO[T]. This is particularly used in Haskell, where it represents a value with I/O side effects. In Haskell, the entire program is an IO[()] value (main), and the runtime takes care of performing the I/O effects that this value describes.

    IO also exists in Scala Cats, where it represents a suspended task; it doesn’t need any special compiler support because Scala is actually impure and is fine with I/O being done anywhere.

  • Future[T]: Futures are also monadic; they are “a T, except we might have to wait to get to it”. It might be easier to see that futures have map and flatten than to see how they have flatMap, but as we’ve seen, the two formulations are equivalent.
  • List[T] and Option[T]. These are used as monads less frequently in practice. They are “a value of type T, but there might be none (or several)”.
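
Going back to the Either bullet, here is a short sketch of how the same for-comprehension syntax carries over (parseAge is a made-up helper; this assumes Scala 2.13’s right-biased Either):

def parseAge(s: String): Either[String, Int] =
  s.toIntOption.toRight(s"'$s' is not a number")

// The first Left short-circuits the rest of the comprehension.
val total: Either[String, Int] = for {
  a <- parseAge("32")
  b <- parseAge("ten")  // Left("'ten' is not a number")
} yield a + b           // total is that Left, not a sum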

More about Writer

We’ve looked at a hand-rolled State monad in the first part. Let’s now look at Cats’ Writer monad in more detail to help illustrate some of the points made above.

A bit like State, Writer is designed to allow light logging or message reporting, without passing a logging object around explicitly.

A Writer[W, T] is effectively a pair of two values, a log of type W, and a value of type T. As we’ve said before, Writer has two type parameters, so it can’t be a monad out of the box. We need to pick a W to create a new type that will be monadic. There’s the added constraint that the W must be a monoid. You can look at the Wikipedia page for more information, but basically it just means a data structure for which we can create an empty element and append values (immutably). We’ll pick Vector[String].
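
Roughly, a monoid interface looks like this (a simplified sketch, not Cats’ exact definition):

trait Monoid[A] {
  def empty: A                 // identity element, e.g. Vector.empty
  def combine(x: A, y: A): A   // associative append, e.g. ++ for Vectors
}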

Next, we can define

import cats.data.Writer  
import cats.implicits._

type Logged[T] = Writer[Vector[String], T]

Because of the way Writer is designed, our type Logged[T] is monadic. If we want a function to perform logging, we will make it return a Logged[T] value instead of a T. For instance, we can define

def double(x: Int): Logged[Int] = {  
  val messages = if (x > 10) Vector(s"warning, $x is a big number")
                 else Vector.empty  
  Writer(messages, x * 2)  
}

Note how we return a Writer containing both our log messages and the actual result. We could also have written this using for-comprehensions in a more imperative style:

def double(x: Int): Logged[Int] = for {
  _ <- Writer.tell(if (x > 10) Vector(s"warning, $x is a big number")
                   else Vector.empty)
} yield x * 2

Writer.tell is a helper that constructs a new Writer with the given messages, and a Unit (empty) payload. We are of course free to define our own functions that create Writers; for instance, we could create a variant of tell that accepts a nullable string.
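
For instance, such a variant might look like this (tellMaybe is a made-up name, not a Cats function):

// Log the message only if it's non-null; otherwise log nothing.
def tellMaybe(msg: String): Logged[Unit] =
  Writer.tell(if (msg != null) Vector(msg) else Vector.empty)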

Here is yet another variant of double:

def double(x: Int): Logged[Int] = for {
  _ <- if (x > 10) Writer.tell(Vector(s"warning, $x is a big number"))
       else ().pure[Logged]
} yield x * 2

().pure[Logged] constructs a new Writer value with no messages and a payload of () (Scala’s unit value).

What I want to stress by showing various examples is that using a monad is mostly transparent insofar as the values wrapped by the monad go (you should only need to use pure and flatMap or for comprehensions), but you will need to use the monad’s unique functions (here, tell or the Writer constructor) in order to manipulate the associated effects.

So far, we haven’t done much more than create a pair of a vector of messages and some data. We’ve used a fancy syntax, but we could have done it manually. Why do we need to care about it being a monad? It’s for much the same reason as we needed it for BF[T]: it’s a pain to deal with just pairs, because when you turn a function from an Int => Int to an Int => Logged[Int] (or Int => (Vector[String], Int)), it no longer composes with other Int => Int functions.

But we can leverage any of the monadic combinators, such as map or sequence, as well as for comprehensions, to make this much less painful.

val resultWithLog: Logged[Int] = for {
  x <- double(2)
  y = x + 10
  z <- double(y)
  _ <- Writer.tell(Vector("all done"))
} yield z
println(s"Result: ${transformer.run._2}, " +
        s"messages: ${transformer.run._1}")

Gives us

Result: 28, messages: Vector(warning, 14 is a big number, all done)

Monad transformers

So you really like Logged, and you’ve rewritten a bunch of functions in your app that need to report data on what they did, so they return various Logged[T] values. But now some of the functions also want to use some State! You could make the functions return State[S, Logged[T]] or Logged[State[S, T]] values, but neither of these types are monadic; in general, composing two monads does not give you a single big monad.

This is where monad transformers come into play. As I’ve said before, it’s common to have to combine existing monads. In short, this is done using monad transformers. For instance, WriterT lets us turn an already-monadic type into a new monadic type which also supports logging. Note that WriterT is not a monad itself. Here’s how we would create a new monad combining State and Logged:

type MyState[T] = State[StateCarrier, T]
type StateLog[T] = WriterT[MyState, Vector[String], T]

You can use the resulting StateLog like any other monad, and it gives you both Writer methods (like tell) and State methods (like inspect and modify). While it’s outside the scope of this post, there’s a lot more to say about monad transformers, and they’re commonly used in monad-heavy programs.
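
As a rough sketch of what using such a stack looks like (fixing the state type to Int for the example, and lifting plain State actions with WriterT.liftF; helper names may vary slightly across Cats versions):

import cats.data.{State, WriterT}
import cats.implicits._

type IntState[T] = State[Int, T]
type IntStateLog[T] = WriterT[IntState, Vector[String], T]

// A small program that both logs and updates the state.
val program: IntStateLog[Int] = for {
  _ <- WriterT.tell[IntState, Vector[String]](Vector("incrementing"))
  _ <- WriterT.liftF[IntState, Vector[String], Unit](State.modify[Int](_ + 1))
  n <- WriterT.liftF[IntState, Vector[String], Int](State.get[Int])
} yield n

// program.run is an IntState[(Vector[String], Int)]: running it with an
// initial state of 41 produces the final state 42, the log, and the value 42.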

There are other paradigms to combine monads, including MTL / final-tagless and Free monads. Both are definitely worth looking into for larger programs using monads.

Conclusion

I hope the examples in this post have given you a better idea of why the monad abstraction is useful, and how it’s used in practice.

If you decide to use monads in your next Scala project, be aware that I haven’t touched on type classes at all, even though this is the way the monad interface is applied to monadic types in Scala. This comes with its own complexities, which I would recommend looking into.

  1. This won’t directly work in our case, because flatMap is an independent function instead of being defined in BF instances, but it works fine for properly-defined monadic types. 

  2. This is always true for a monad, because we can define flatten[T](t: M[M[T]]): M[T] = t.flatMap(x => x) 
