
I have been trying to understand the State Monad. Not so much how it is used, though that is not always easy to find, either. But every discussion I find of the State Monad has basically the same information and there is always something I don't understand.

Take this post, for example. In it the author has the following:

case class State[S, A](run: S => (A, S)) {
...
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      f(a) run t
    })
...
}

I can see that the types line up correctly. However, I don't understand the second run at all.

Perhaps I am looking at the whole purpose of this monad incorrectly. I got the impression from the HaskellWiki that the State monad was kind of like a state-machine with the run allowing for transitions (though, in this case, the state-machine doesn't really have fixed state transitions like most state machines). If that is the case then in the above code (a, t) would represent a single transition. The application of f would represent a modification of that value and State (generating a new State object). That leaves me completely confused as to what the second run is all about. It would appear to be a second 'transition'. But that doesn't make any sense to me.

I can see that calling run on the resulting State object produces a new (A, S) pair which, of course, is required for the types to line up. But I don't really see what this is supposed to be doing.

So, what is really going on here? What is the concept being modeled here?

Edit: 12/22/2015

So, it appears I am not expressing my issue very well. Let me try this.

In the same blog post we see the following code for map:

def map[B](f: A => B): State[S, B] =
  State(s => {
    val (a, t) = run(s)
    (f(a), t)
  })

Obviously there is only a single call to run here.

The model I have been trying to reconcile is that a call to run moves the state we are keeping forward by a single state-change. This seems to be the case in map. However, in flatMap we have two calls to run. If my model was correct that would result in 'skipping over' a state change.

To make use of the example @Filippo provided below, the first call to run would result in returning (1, List(2,3,4,5)) and the second would result in (2, List(3,4,5)), effectively skipping over the first one. Since, in his example, this was followed immediately by a call to map, this would have resulted in (Map(a->2, b->3), List(4,5)).

Apparently that is not what is happening. So my whole model is incorrect. What is the correct way to reason about this?

2nd Edit: 12/22/2015

I just tried doing what I said in the REPL. My instincts were correct, which leaves me even more confused.

scala> val v = State(head[Int]).flatMap { a => State(head[Int]) }
v: State[List[Int],Int] = State(<function1>)

scala> v.run(List(1,2,3,4,5))
res2: (Int, List[Int]) = (2,List(3, 4, 5))

So, this implementation of flatMap does skip over a state. Yet when I run @Filippo's example I get the same answer he does. What is really happening here?

melston
  • 1
    BTW Michael Pilquist made a great presentation about `State Monad`, there is also a video online: https://www.youtube.com/watch?v=Jg3Uv_YWJqI – Filippo Vitale Dec 22 '15 at 06:32
  • 1
    I just watched part of that video. Unfortunately, he did the same kind of 'hand-waving' about the second run as I have seen elsewhere. – melston Dec 22 '15 at 07:04

3 Answers


To understand the "second run" let's analyse it "backwards".

The signature def flatMap[B](f: A => State[S, B]): State[S, B] suggests that we need to run a function f and return its result.

To execute function f we need to give it an A. Where do we get one?
Well, we have run that can give us A out of S, so we need an S.

Because of that we do: s => val (a, t) = run(s) .... We read it as "given an S, execute the run function, which produces an A and a new S". This is our "first" run.

Now we have an A and we can execute f. That's what we wanted, and f(a) gives us a new State[S, B]. If we do that, then we have a function which takes S and returns State[S, B]:

(s: S) => {
  val (a, t) = run(s)
  f(a) // State[S, B]
}

But function S => State[S, B] isn't what we want to return! We want to return just State[S, B].

How do we do that? We can wrap this function into State:

State(s => ... f(a))

But it doesn't work because State takes S => (B, S), not S => State[S, B]. So we need to get (B, S) out of State[S, B].
We do it by just calling its run method and providing it with the state we just produced in the previous step! And that is our "second" run.

So as a result we have the following transformation performed by a flatMap:

s => {                   // when a state is provided
  val (a, t) = run(s)    // produce an `A` and a new state value
  val resState = f(a)    // produce a new `State[S, B]`
  resState.run(t)        // return `(B, S)`
}

This gives us S => (B, S) and we just wrap it with the State constructor.

Another way of looking at these "two runs" is:
first - we transform the state ourselves with "our" run function
second - we pass that transformed state to the function f and let it do its own transformation.

So we are, in effect, "chaining" state transformations one after another. And that's exactly what monads do: they provide us with the ability to schedule computations sequentially.
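To make the chaining concrete, here is a minimal sketch that threads the state by hand and compares the outcome with flatMap. The object name TwoRunsDemo and the helper pop are made up for illustration, and State is a cut-down copy of the definition from the question, not the blog post's full code.

```scala
object TwoRunsDemo {
  // Cut-down copy of the State from the question (only flatMap is needed here).
  final case class State[S, A](run: S => (A, S)) {
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State { s =>
        val (a, t) = run(s) // first run: this computation, on the incoming state
        f(a).run(t)         // second run: the computation built by f, on the leftover state
      }
  }

  // Hypothetical helper: the head is the value, the tail is the new state.
  val pop: State[List[Int], Int] = State(xs => (xs.head, xs.tail))

  val start = List(1, 2, 3, 4, 5)

  // Threading the state by hand: two different computations, one after the other.
  val (first, afterFirst) = pop.run(start) // first = 1, afterFirst = List(2, 3, 4, 5)
  val byHand = pop.run(afterFirst)         // (2, List(3, 4, 5))

  // flatMap does the same threading for us; the value 1 is produced but ignored by `_ => pop`.
  val viaFlatMap = pop.flatMap(_ => pop).run(start)
}
```

Under these assumptions byHand and viaFlatMap are both (2, List(3, 4, 5)). Nothing is skipped: the 1 was handed to the function passed to flatMap, which chose to discard it.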

Alexey Raga
  • 1
    Thanks. I went through a similar reasoning and came to the same conclusion. But that still doesn't explain what the second `run` *does* (other than make the types work out right). To spring off of your last comment why do we chain two transformations into a single mapping operation? How can we justify that? – melston Dec 22 '15 at 06:35
  • 1
    What does it do? It runs the "second" state using the "intermediate" state provided. When the computation is executed, you want the state value to be transformed by "your" state monad and by the second one that is constructed by the function. This is exactly what happens. You return a computation that runs "this" state transformation and then the "provided" state transformation. It needs to be run at the end of the day. The State that is returned from the `flatMap` just executes these two one-by-one. – Alexey Raga Dec 22 '15 at 08:54
  • take a look at my edit. Hopefully that will better explain what I am missing. – melston Dec 22 '15 at 15:10
  • 1
    I don't see why you say it skips anything. First you create a state with the `head` which, when evaluated, would give you `(1, List(2,3,4,5))`. Then you flatMap it. You chose to ignore the value `1` and further transform the state with the `head` function again, which gives you `(2, List(3,4,5))`. I don't see what is skipped and where. – Alexey Raga Dec 23 '15 at 04:29
  • You are probably missing that `State(..)` doesn't do anything by itself. It is just a function (it actually just wraps a function). All `flatMap` does is create a function that will run the current state (1st `run`), then the stateful function you pass as an argument (the 2nd `run`), and return its result. But it is still a function. You need to give it an initial `S` to trigger the computation. – Alexey Raga Dec 23 '15 at 04:37
  • The problem is more conceptual. `flatMap` is often described as `map` followed by `flatten` (whatever that means in a given context). However, as it is implemented in the State monad it appears to be `map` followed by `map` followed by a kind of `flatten`. When you say it 'runs the current state' that implies (to me) that it increments to the 'next' state. The next `run` then increments to the one following that - skipping the intermediate state. – melston Dec 23 '15 at 06:48
  • 3
    A `map` operation would run the state once, get the result and wrap it into `State`. A `flatMap` is given a function that results in yet another `State`. If you just `map` this function then you will have a `State[S, State[S, B]]` that you will have to flatten. How would you do it? Well, you would have to execute the "inner" state to take the value from it, wouldn't you? That's your second `run` and that's your `flatten`. – Alexey Raga Dec 24 '15 at 05:57
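The `map`-then-`flatten` reading from the comment above can be sketched roughly as follows. The names FlattenDemo, flatten, flatMapViaMap and pop are hypothetical and do not come from the blog post; State here carries only map, so that flatMap has to be rebuilt from it.

```scala
object FlattenDemo {
  // Cut-down State with only `map`.
  final case class State[S, A](run: S => (A, S)) {
    def map[B](f: A => B): State[S, B] =
      State { s =>
        val (a, t) = run(s)
        (f(a), t)
      }
  }

  // Hypothetical flatten: the outer computation yields an inner computation as its value.
  def flatten[S, B](nested: State[S, State[S, B]]): State[S, B] =
    State { s =>
      val (inner, t) = nested.run(s) // one run: get the inner computation and the new state
      inner.run(t)                   // the "second" run: execute the inner computation
    }

  // flatMap expressed as map followed by flatten.
  def flatMapViaMap[S, A, B](sa: State[S, A])(f: A => State[S, B]): State[S, B] =
    flatten(sa.map(f))

  // Hypothetical helper: the head is the value, the tail is the new state.
  val pop: State[List[Int], Int] = State(xs => (xs.head, xs.tail))

  val result = flatMapViaMap(pop)(_ => pop).run(List(1, 2, 3, 4, 5))
}
```

With these definitions result is (2, List(3, 4, 5)), the same pair as in the REPL session in the question. Seen this way there is one `run` inside `map` and one inside `flatten`, rather than two `map`s.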

The state monad boils down to this function from one state to another state (plus A):

type StatefulComputation[S, +A] = S => (A, S)

The implementation mentioned by Tony in that blog post "captures" that function in the run field of the case class:

case class State[S, A](run: S => (A, S))

The flatMap implementation, which binds one State to another, calls two different runs:

    // the `run` on the actual `state`
    val (a: A, nextState: S) = run(s)

    // the `run` on the bound `state`
    f(a).run(nextState)

EDIT: Example of flatMap between 2 States

Consider a function that simply calls .head on a List to get A, and .tail to get the next state S:

// stateful computation: `S => (A, S)` where `S` is `List[A]`
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

A simple binding of 2 State(head[Int]):

// flatmap example
val result = for {
  a <- State(head[Int])
  b <- State(head[Int])
} yield Map('a' -> a,
            'b' -> b)

The expected behaviour of the for-comprehension is to "extract" the first element of a list into a and the second one into b. The resulting state S would be the remaining tail of the list passed to run:

scala> result.run(List(1, 2, 3, 4, 5))
(Map(a -> 1, b -> 2),List(3, 4, 5))

How? Calling the "stateful computation" head[Int] that is in run on some state s:

s => run(s)

That gives the head (A) and the tail (S) of the list. Now we need to pass the tail to the next State(head[Int]):

f(a).run(t)

Where f is in the flatmap signature:

def flatMap[B](f: A => State[S, B]): State[S, B]

Maybe to better understand what is f in this example, we should de-sugar the for-comprehension to:

val result = State(head[Int]).flatMap {
  a => State(head[Int]).map {
    b => Map('a' -> a, 'b' -> b)
  }
}

With f(a) we pass a into the function and with run(t) we pass the modified state.
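As a self-contained check, the following sketch puts the pieces of this answer together and runs both the for-comprehension and its de-sugared form on the same list. The wrapper object DesugarDemo and the names sugared, desugared and input are only for illustration; State is reduced to map and flatMap as quoted in the question.

```scala
object DesugarDemo {
  // State with just the two methods a for-comprehension needs.
  final case class State[S, A](run: S => (A, S)) {
    def map[B](f: A => B): State[S, B] =
      State { s =>
        val (a, t) = run(s)
        (f(a), t)
      }

    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State { s =>
        val (a, t) = run(s) // run on the actual state
        f(a).run(t)         // run on the bound state, fed with the modified state
      }
  }

  // Stateful computation from the answer: `S => (A, S)` where `S` is `List[A]`.
  def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

  val sugared: State[List[Int], Map[Char, Int]] =
    for {
      a <- State(head[Int])
      b <- State(head[Int])
    } yield Map('a' -> a, 'b' -> b)

  val desugared: State[List[Int], Map[Char, Int]] =
    State(head[Int]).flatMap { a =>
      State(head[Int]).map { b =>
        Map('a' -> a, 'b' -> b)
      }
    }

  val input = List(1, 2, 3, 4, 5)
}
```

Both sugared.run(input) and desugared.run(input) give (Map(a -> 1, b -> 2), List(3, 4, 5)), matching the output shown earlier.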

Filippo Vitale
  • I still don't understand when you say the second run is on the 'bound' state. What does that mean? – melston Dec 22 '15 at 06:12
  • I mean the `run` called on the second `State(head[Int])` – Filippo Vitale Dec 22 '15 at 06:24
  • But that is the exact issue I am trying to get clear. Why do we need the second `run` at all? What problem does it solve? What is the concept behind it? – melston Dec 22 '15 at 06:26
  • Because the `state monad` is just that function `S => (A, S)`, within `flatMap` we need to pass the result state after our run to the next state. BTW the second `run` is not running anything until a `.run` from "outside" gets invoked. The flatMap implementation is just binding them. – Filippo Vitale Dec 22 '15 at 06:31
  • I realize it is a deferred operation. But when it is finally applied we seem to get two 'transformations' for a single operation. I just don't understand how we can justify this. – melston Dec 22 '15 at 06:36
  • 1
    It is exactly what you described: two transformations in a single operation. That is what any `flatMap` would do. You must return the result of the second transformation, and to perform it you need a result of the first one. So you execute the first, then you execute the second, then you return the result. That's what monads do: they represent computations in sequence. – Alexey Raga Dec 24 '15 at 05:53

I have accepted @AlexeyRaga's answer to my question. I think @Filippo's answer was very good as well and, in fact, gave me some additional food for thought. Thanks to both of you.

I think the conceptual difficulty I was having was mostly to do with what the run method 'means'. That is, what is its purpose and result? I was looking at it as a 'transition' function (from one state to the next). And, after a fashion, that is what it does. However, it doesn't transition from a given (this) state to the next state. Instead, it takes an initial state and returns the (this) State's value and a new 'current' state (not the next state in the state-transition sequence).

That is why the flatMap method is implemented the way it is. When you generate a new State then you need the current value/state pair from it based on the passed-in initial state which can then be wrapped in a new State object as a function. You are not really transitioning to a new state. Just re-wrapping the generated state in a new State object.
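A small sketch of how I now read run, with made-up names (RunMeaningDemo, pop): a State value holds no state of its own, so running it twice from the same initial state gives the same pair, and running it from a different initial state gives a different one.

```scala
object RunMeaningDemo {
  // Just the wrapper; no map or flatMap needed for this point.
  final case class State[S, A](run: S => (A, S))

  // Hypothetical helper: the head is the value, the tail is the new state.
  val pop: State[List[Int], Int] = State(xs => (xs.head, xs.tail))

  val fromOneTwo    = pop.run(List(1, 2)) // (1, List(2))
  val fromNineEight = pop.run(List(9, 8)) // (9, List(8))
  val again         = pop.run(List(1, 2)) // same as fromOneTwo: pop did not "advance"
}
```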

I was too steeped in traditional state machines to see what was going on here.

Thanks, again, everyone.

melston