
I'm using a state transformer to randomly sample a dataset at every point of a 2D recursive walk, which outputs a list of 2D grids of samples that together satisfy a condition. I'd like to pull from the results lazily, but my approach instead exhausts the whole dataset at every point before I can pull the first result.

To be concrete, consider this program:

import Control.Monad ( sequence, liftM2 )
import Data.Functor.Identity
import Control.Monad.State.Lazy ( StateT(..), State(..), runState )
import System.Random ( getStdGen ) -- from the random package, used in main

walk :: Int -> Int -> [State Int [Int]]
walk _ 0 = [return [0]]
walk 0 _ = [return [0]]
walk x y =
  let st :: [State Int Int]
      st = [StateT (\s -> Identity (s, s + 1)), undefined]
      unst :: [State Int Int] -- degenerate state tf
      unst = [return 1, undefined]
  in map (\m_z -> do
      z <- m_z
      fmap concat $ sequence [
          liftM2 (zipWith (\x y -> x + y + z)) a b -- for 1D: map (+z) <$> a
          | a <- walk x (y - 1) -- depth
          , b <- walk (x - 1) y -- breadth -- comment out for 1D
        ]
    ) st -- vs. unst

main :: IO ()
main = do
  std <- getStdGen
  putStrLn $ show $ head $ fst $ (`runState` 0) $ head $ walk 2 2

The program walks the rectangular grid from (x, y) down to (0, 0) and sums the results, adding at each point a value from one of two lists of State actions: either the non-trivial transformers st, which read and advance their state, or the trivial transformers unst. Of interest is whether the algorithm explores past the heads of st and unst.

In the code as presented, it throws undefined. I chalked this up to a misdesign of my order of chaining the transformations, and in particular, a problem with the state handling, as using unst instead (i.e. decoupling the result from state transitions) does produce a result. However, I then found that a 1D recursion also preserves laziness even with the state transformer (remove the breadth step b <- walk... and swap the liftM2 block for fmap).
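For comparison, here is a minimal sketch of the 1D variant described above, assuming the breadth step is dropped and liftM2 is replaced by fmap as the code comments suggest (walk1D is a hypothetical name). In 1D nothing after the inner sequence reads the state it leaves behind, so the undefined alternative is never forced:

```haskell
import Control.Monad.State.Lazy (State, evalState, state)

-- Hypothetical 1D variant of walk: no breadth step, fmap instead of liftM2.
walk1D :: Int -> [State Int [Int]]
walk1D 0 = [return [0]]
walk1D y =
  let st :: [State Int Int]
      st = [state (\s -> (s, s + 1)), undefined] -- second alternative must stay unforced
  in map (\m_z -> do
        z <- m_z
        -- only the depth step remains
        fmap concat $ sequence [ map (+ z) <$> a | a <- walk1D (y - 1) ]
      ) st

main :: IO ()
main = print $ head $ evalState (head $ walk1D 2) 0
```

With the lazy State monad this prints 1.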

If we trace (show (x, y)) on each call, we also see that it walks the whole grid before hitting undefined:

$ cabal run
Build profile: -w ghc-8.6.5 -O1
...
(2,2)
(2,1)
(1,2)
(1,1)
(1,1)
sandbox: Prelude.undefined

I suspect that my use of sequence is at fault here, but since the choice of monad and the dimensionality of the walk both affect whether it succeeds, I can't say that sequencing the transformations is by itself the source of the strictness.

What's causing the difference in strictness between 1D and 2D recursion here, and how can I achieve the laziness I want?

concat

2 Answers


Consider the following simplified example:

import Control.Monad.State.Lazy

st :: [State Int Int]
st = [state (\s -> (s, s + 1)), undefined]

action1d = do
  a <- sequence st
  return $ map (2*) a

action2d = do
  a <- sequence st
  b <- sequence st
  return $ zipWith (+) a b

main :: IO ()
main = do
  print $ head $ evalState action1d 0
  print $ head $ evalState action2d 0

Here, in both the 1D and 2D calculations, the head of the result depends explicitly only on the heads of the inputs (just head a for the 1D action and both head a and head b for the 2D action). However, in the 2D calculation, there's an implicit dependency of b (even just its head) on the current state, and that state depends on the evaluation of the entirety of a, not just its head.
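The difference can be observed directly by separating the value of sequence st from its final state. A minimal sketch reusing the st above, assuming Control.Exception only to catch the error instead of aborting:

```haskell
import Control.Exception (SomeException, evaluate, try)
import Control.Monad.State.Lazy (State, evalState, execState, state)

st :: [State Int Int]
st = [state (\s -> (s, s + 1)), undefined]

main :: IO ()
main = do
  -- The value of `sequence st` is lazy: its head needs only the first action.
  print (head (evalState (sequence st) 0))
  -- Its final state comes out of the undefined action, so forcing it throws.
  -- In action2d, this is the state that `b <- sequence st` starts from.
  r <- try (evaluate (execState (sequence st) 0)) :: IO (Either SomeException Int)
  putStrLn (either (const "final state: undefined") show r)
```

This prints 0 and then "final state: undefined".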

You have a similar dependency in your example, though it's obscured by the use of lists of state actions.

Let's say we wanted to run the action walk22_head = head $ walk 2 2 manually and inspect the first integer in the resulting list:

main = print $ head $ evalState walk22_head 0

Writing the elements of the state action list st explicitly:

st1, st2 :: State Int Int
st1 = state (\s -> (s, s+1))
st2 = undefined

we can write walk22_head as:

walk22_head = do
  z <- st1
  a <- walk21_head
  b <- walk12_head
  return $ zipWith (\x y -> x + y + z) a b

Note that this depends only on the defined state action st1 and the heads of walk 2 1 and walk 1 2. Those heads, in turn, can be written:

walk21_head = do
  z <- st1
  a <- return [0] -- walk20_head
  b <- walk11_head
  return $ zipWith (\x y -> x + y + z) a b

walk12_head = do
  z <- st1
  a <- walk11_head
  b <- return [0] -- walk02_head
  return $ zipWith (\x y -> x + y + z) a b

Again, these depend only on the defined state action st1 and the head of walk 1 1.

Now, let's try to write down a definition of walk11_head:

walk11_head = do
  z <- st1
  a <- return [0]
  b <- return [0]
  return $ zipWith (\x y -> x + y + z) a b

This depends only on the defined state action st1, so with these definitions in place, if we run main, we get a defined answer:

> main
10

But these definitions aren't accurate! In each of walk 1 2 and walk 2 1, the head action is a sequence of actions, starting with the action that invokes walk11_head, but continuing with actions based on walk11_tail. So, more accurate definitions would be:

walk21_head = do
  z <- st1
  a <- return [0] -- walk20_head
  b <- walk11_head
  _ <- walk11_tail  -- side effect of the sequence
  return $ zipWith (\x y -> x + y + z) a b

walk12_head = do
  z <- st1
  a <- walk11_head
  b <- return [0] -- walk02_head
  _ <- walk11_tail  -- side effect of the sequence
  return $ zipWith (\x y -> x + y + z) a b

with:

walk11_tail = do
  z <- undefined
  a <- return [0]
  b <- return [0]
  return [zipWith (\x y -> x + y + z) a b]

With these definitions in place, there's no problem running walk12_head and walk21_head in isolation:

> head $ evalState walk12_head 0
1
> head $ evalState walk21_head 0
1

The state side effects here are not needed to calculate the answer, so they are never invoked. However, it's not possible to run both actions in sequence:

> head $ evalState (walk12_head >> walk21_head) 0
*** Exception: Prelude.undefined
CallStack (from HasCallStack):
  error, called at libraries/base/GHC/Err.hs:78:14 in base:GHC.Err
  undefined, called at Lazy2D_2.hs:41:8 in main:Main

Therefore, trying to run main fails for the same reason:

> main
*** Exception: Prelude.undefined
CallStack (from HasCallStack):
  error, called at libraries/base/GHC/Err.hs:78:14 in base:GHC.Err
  undefined, called at Lazy2D_2.hs:41:8 in main:Main

because, in calculating walk22_head, even the very beginning of walk21_head's calculation depends on the state side effect walk11_tail initiated by walk12_head.
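The mockups can be collected into one module to reproduce this. A sketch assuming mtl's lazy State; the [0] lists stand in for the base-case results, and Control.Exception is used only so the failing case prints a message instead of aborting:

```haskell
import Control.Exception (SomeException, evaluate, try)
import Control.Monad.State.Lazy (State, evalState, state)

st1 :: State Int Int
st1 = state (\s -> (s, s + 1))

walk11_head :: State Int [Int]
walk11_head = do
  z <- st1
  return $ zipWith (\x y -> x + y + z) [0] [0]

-- Remaining action of `walk 1 1`: it runs the undefined alternative,
-- so the state it leaves behind is undefined.
walk11_tail :: State Int [[Int]]
walk11_tail = do
  z <- undefined
  return [zipWith (\x y -> x + y + z) [0] [0]]

walk21_head :: State Int [Int]
walk21_head = do
  z <- st1
  b <- walk11_head
  _ <- walk11_tail -- state effect of the sequence
  return $ zipWith (\x y -> x + y + z) [0] b

walk12_head :: State Int [Int]
walk12_head = do
  z <- st1
  a <- walk11_head
  _ <- walk11_tail -- state effect of the sequence
  return $ zipWith (\x y -> x + y + z) a [0]

main :: IO ()
main = do
  print (head (evalState walk12_head 0)) -- each head alone is fine
  print (head (evalState walk21_head 0))
  -- Chained, walk21_head starts from the state left by walk11_tail.
  r <- try (evaluate (head (evalState (walk12_head >> walk21_head) 0))) :: IO (Either SomeException Int)
  putStrLn (either (const "chained: undefined state") show r)
```

This prints 1, 1, and then "chained: undefined state".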

Your original walk definition behaves the same way as these mockups:

> head $ evalState (head $ walk 1 2) 0
1
> head $ evalState (head $ walk 2 1) 0
1
> head $ evalState (head (walk 1 2) >> head (walk 2 1)) 0
*** Exception: Prelude.undefined
CallStack (from HasCallStack):
  error, called at libraries/base/GHC/Err.hs:78:14 in base:GHC.Err
  undefined, called at Lazy2D_0.hs:15:49 in main:Main
> head $ evalState (head (walk 2 2)) 0
*** Exception: Prelude.undefined
CallStack (from HasCallStack):
  error, called at libraries/base/GHC/Err.hs:78:14 in base:GHC.Err
  undefined, called at Lazy2D_0.hs:15:49 in main:Main

It's hard to say how to fix this. Your toy example was excellent for the purposes of illustrating the problem, but it's not clear how the state is used in your "real" problem and if head $ walk 2 1 really has a state dependency on the sequence of walk 1 1 actions induced by head $ walk 1 2.

K. A. Buhr
Excellent points. I've added an answer to explain the dependencies in my "real" problem if you're interested. Thank you for bringing clarity to this! – concat Mar 12 '20 at 06:48

The accepted answer by K. A. Buhr is right: while getting the head of one step in each direction is fine (try walk with either x < 2 or y < 2), the combination of the implicit >>= in liftM2, the sequence in the value of a, and the state dependency in the value of b makes b depend on all side effects of a. As he also pointed out, a working solution depends on which dependencies are actually wanted.

I'll share a solution for my particular case: each walk call depends at least on the state of its caller, and perhaps on some other states, based on a pre-order traversal of the grid and the alternatives in st. In addition, as the question suggests, I want to produce a full result before testing any unneeded alternatives in st. This is a little difficult to explain visually, but here's the best I could do: the left shows the variable number of st alternatives at each coordinate (which is what I have in my actual use case), and the right shows a [rather messy] map of the desired dependency order of the state. It traverses x-y first in a 3D DFS, with "x" as depth (fastest axis), "y" as breadth (middle axis), and finally the alternatives as the slowest axis (shown as dashed lines with open circles).

[Image: left, the number of st alternatives at each coordinate; right, the desired dependency order of the state]

The central issue in the original implementation came from sequencing lists of state transitions to accommodate the non-recursive return type. Let's replace the list type altogether with a type that's recursive in the monad parameter, so the caller can better control the dependency order:

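import Control.Monad (liftM2)
import Control.Monad.State.Lazy (State, StateT (..), runState)
import Data.Coerce (coerce)
import Data.Functor.Identity (Identity (..))

data ML m a = MCons a (MML m a) | MNil -- recursive monadic list
newtype MML m a = MML (m (ML m a)) -- base case wrapper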

An example of [1, 2]:

MCons 1 (MML (return (MCons 2 (MML (return MNil)))))
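To check that this encodes [1, 2], the list can be run back into an ordinary list. A sketch with a hypothetical helper toListML, using Identity as the underlying monad since the example has no effects:

```haskell
import Data.Functor.Identity (Identity (..))

data ML m a = MCons a (MML m a) | MNil -- recursive monadic list
newtype MML m a = MML (m (ML m a))     -- monadic tail wrapper

-- Hypothetical helper: run each tail action in turn, collecting the elements.
toListML :: Monad m => ML m a -> m [a]
toListML MNil = return []
toListML (MCons a (MML rest)) = (a :) <$> (rest >>= toListML)

example :: Monad m => ML m Int
example = MCons 1 (MML (return (MCons 2 (MML (return MNil)))))

main :: IO ()
main = print (runIdentity (toListML example))
```

This prints [1,2].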

Functor and Monoid behaviors are used often, so here are the relevant implementations:

instance Functor m => Functor (ML m) where
  fmap f (MCons a m) = MCons (f a) (MML $ (fmap f) <$> coerce m)
  fmap _ MNil = MNil

instance Monad m => Semigroup (MML m a) where
  (MML l) <> (MML r) = MML $ l >>= mapper where
    mapper (MCons la lm) = return $ MCons la (lm <> (MML r))
    mapper MNil = r

instance Monad m => Monoid (MML m a) where
  mempty = MML (pure MNil)
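The property of <> that matters here is when the right-hand list's effects run: only once the left-hand list reaches MNil. A sketch to observe this, assuming a State [String] log as the monad; cell, runMML, toListMML, left and right are hypothetical names for illustration:

```haskell
import Control.Monad.State.Lazy (State, runState, state)

data ML m a = MCons a (MML m a) | MNil
newtype MML m a = MML (m (ML m a))

instance Monad m => Semigroup (MML m a) where
  (MML l) <> (MML r) = MML $ l >>= mapper where
    mapper (MCons la lm) = return $ MCons la (lm <> (MML r))
    mapper MNil = r

instance Monad m => Monoid (MML m a) where
  mempty = MML (pure MNil)

-- Hypothetical helper: unwrap the first step of a monadic list.
runMML :: MML m a -> m (ML m a)
runMML (MML m) = m

-- Hypothetical helper: run every tail action in order, collecting the elements.
toListMML :: Monad m => MML m a -> m [a]
toListMML (MML m) = m >>= go
  where
    go MNil = return []
    go (MCons x rest) = (x :) <$> toListMML rest

-- Hypothetical helper: a one-cell list whose action logs `name` when it runs.
cell :: String -> a -> MML (State [String]) a -> MML (State [String]) a
cell name x rest = MML (state (\logs -> (MCons x rest, logs ++ [name])))

left, right :: MML (State [String]) Int
left = cell "l1" 1 (cell "l2" 2 mempty)
right = cell "r1" 3 mempty

main :: IO ()
main = do
  -- Whole list: all of left's effects run before any of right's.
  print (runState (toListMML (left <> right)) [])
  -- First step only: right's effect has not run yet.
  print (snd (runState (runMML (left <> right)) []))
```

The first line should print ([1,2,3],["l1","l2","r1"]) and the second ["l1"].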

There are two critical operations: combining steps in two different axes, and combining lists from different alternatives at the same coordinate. Respectively:

  1. Based on the diagram, we want to get a single full result from the x step first, then a full result from the y step. Each step returns a list of results from all combinations of viable alternatives from inner coordinates, so we take a Cartesian product over both lists, also biased in one direction (in this case y fastest). First we define a "concatenation" that applies a base case wrapper MML at the end of a bare list ML:

    nest :: Monad m => MML m a -> ML m a -> ML m a
    nest ma (MCons a mb) = MCons a (mb <> ma) -- splice ma in where the bare list ends
    nest _ MNil = error "nest: expects a non-empty bare list" -- prodML never passes MNil
    

    then a Cartesian product:

    prodML :: Monad m => (a -> a -> a) -> ML m a -> ML m a -> ML m a
    prodML _ MNil _ = MNil
    prodML _ _ MNil = MNil
    prodML f x (MCons ya ym) = (MML $ prodML f x <$> coerce ym) `nest` ((f ya) <$> x)
    
  2. We want to smash the lists from different alternatives into one list and we don't care that this introduces dependencies between alternatives. This is where we use mconcat from the Monoid instance.
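The two list operations can be tried on small inputs. A sketch assuming Identity as the monad and hypothetical helpers fromListML and toListML; it defines nest through the Semigroup instance and gives prodML explicit empty-list cases, so both terminate cleanly when a list runs out:

```haskell
import Data.Functor.Identity (Identity (..))

data ML m a = MCons a (MML m a) | MNil
newtype MML m a = MML (m (ML m a))

instance Functor m => Functor (ML m) where
  fmap f (MCons a (MML rest)) = MCons (f a) (MML (fmap f <$> rest))
  fmap _ MNil = MNil

instance Monad m => Semigroup (MML m a) where
  MML l <> MML r = MML (l >>= mapper)
    where
      mapper (MCons la lm) = return (MCons la (lm <> MML r))
      mapper MNil = r

-- Splice a monadic list onto the end of a non-empty bare list.
nest :: Monad m => MML m a -> ML m a -> ML m a
nest ma (MCons a mb) = MCons a (mb <> ma)
nest _ MNil = error "nest: expects a non-empty bare list"

-- Cartesian product: for each element of y (outer loop), all of x (inner loop).
prodML :: Monad m => (a -> a -> a) -> ML m a -> ML m a -> ML m a
prodML _ MNil _ = MNil
prodML _ _ MNil = MNil
prodML f x (MCons ya (MML ym)) = nest (MML (prodML f x <$> ym)) (f ya <$> x)

-- Hypothetical helpers for building and reading back bare lists.
fromListML :: Monad m => [a] -> ML m a
fromListML [] = MNil
fromListML (a : more) = MCons a (MML (return (fromListML more)))

toListML :: Monad m => ML m a -> m [a]
toListML MNil = return []
toListML (MCons a (MML rest)) = (a :) <$> (rest >>= toListML)

main :: IO ()
main = print (runIdentity (toListML (prodML (+) (fromListML [1, 2]) (fromListML [10, 20 :: Int]))))
```

With prodML f x y, the elements of the first list vary fastest, so this prints [11,12,21,22].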

All in all, it looks like this:

walk :: Int -> Int -> MML (State Int) Int
-- base cases
walk _ 0 = MML $ return $ MCons 1 (MML $ return MNil)
walk 0 _ = walk 0 0

walk x y =
  let st :: [State Int Int]
      st = [StateT (\s -> Identity (s, s + 1)), undefined]
      xstep = coerce $ walk (x-1) y
      ystep = coerce $ walk x (y-1)
     -- point 2: smash lists with mconcat
  in mconcat $ map (\mz -> MML $ do
      z <- mz
                              -- point 1: product over results
      liftM2 ((fmap (z+) .) . prodML (+)) xstep ystep
    ) st

headML (MCons a _) = a
headML _ = undefined

main :: IO ()
main = putStrLn $ show $ headML $ fst $ (`runState` 0) $ (\(MML m) -> m) $ walk 2 2

Note that the result has changed along with the semantics. That doesn't matter to me, since my goal only needed to pull random numbers from state, and whatever dependency order is needed can be controlled with the right shepherding of list elements into the final result.

(I'll also warn that without memoization or attention to strictness, this implementation is very inefficient for large x and y.)

concat