So I'm trying to define my own state monad by extending the Monad
trait in scalaz. I know I'm reinventing the wheel, but I'm trying to learn more about scala and scalaz. My code is as follows:
package pure
import scalaz._, Scalaz._
object StateUtil {
sealed trait State[S, A] { self =>
def unit[A](a : A) : State[S, A] = StateM(s => (a, s))
def runState(s : S) : (A, S)
def flatMap[B](f : A => State[S, B]) : State[S, B] = {
val stateFun = (s : S) => {
val (v, ss) = self.runState(s)
f(v).runState(ss)
}
StateM(stateFun)
}
def map[B](f : A => B) : State[S, B] = flatMap(f andThen (unit _))
}
case class StateM[S, A](run : S => (A, S)) extends State[S, A] {
def runState(s : S) : (A, S) = run(s)
}
class StateMonad[S]() extends Monad[({type St[A] = State[S, A]})#St] {
def point[A](a : => A) : State[S, A] = StateM(s => (a, s))
def bind[A, B](prev : State[S, A])(f : A => State[S, B]) : State[S, B] = prev flatMap f
def apply[A](a : => A) : State[S, A] = point(a)
}
def put[S](s : S) : State[S, Unit] = StateM(_ => ((), s))
def get[S]: State[S, S] = StateM(s => (s, s))
}
I'm trying to achieve the same behavior as that of the haskell function called stackyStack
(the code is in the comment below). The problem is that the get
method returns something of type State[S, S]
, and I cannot use the scalaz's >>=
operator. But I've no idea how to make that work.
And I'm not even sure if this is how you define your own monads using scalaz. And if I want to replicate the do
syntax in scala, what else am I missing?
object main {
import pure.StateUtil._
/*
* stackyStack :: State Stack ()
stackyStack = do
stackNow <- get
if stackNow == [1,2,3]
then put [8,3,1]
else put [9,2,1]
*/
type Stack = List[Int]
def stacky : State[Stack, Unit] = {
// won't compile, because get returns something of type State[S, S],
// and the bind operator is not a memeber of the State trait I defined
get >>=
}
}