-- from https://gist.github.com/ppetr/3552390
import Control.Monad
import Control.Monad.Cont
import Control.Monad.State
import Control.Monad.Trans
-- | A computation over the monad @m@ that can be run one step at a time.
--
-- Running 'step' performs one stage in @m@ and yields either @Left rest@,
-- the remainder of a computation that paused, or @Right a@, the final result.
--
-- Declared as a @newtype@ because there is exactly one constructor with
-- exactly one field, so no extra wrapper is needed at runtime.
newtype Pause m a = Pause {step :: m (Either (Pause m a) a)}
-- | Embeds an action of the underlying monad as a computation that completes
-- in a single step: its result is always tagged 'Right', so it never pauses.
instance MonadTrans Pause where
  lift action = Pause $ Right <$> action
-- | Runs a result-less action of the underlying monad inside 'Pause'.
-- This is 'lift' specialised to actions returning @()@.
mutate :: (Monad m) => m () -> Pause m ()
mutate action = lift action
-- | Builds a computation whose first step reports a pause and hands back the
-- given computation as the remainder to resume with.
suspend :: (Monad m) => Pause m a -> Pause m a
suspend rest = Pause (return (Left rest))
-- | Pauses once; when resumed, finishes immediately with @()@.
yield :: (Monad m) => Pause m ()
yield = suspend finished
  where
    -- The remainder handed back at the pause: nothing left to do.
    finished = return ()
-- | Applies a function to the final result of a pausable computation.
-- Implemented through the 'Monad' instance: bind the result, then return the
-- transformed value.
instance (Monad m) => Functor (Pause m) where
  fmap f action = action >>= \x -> return (f x)
-- | Applicative interface for 'Pause'.
--
-- 'pure' is defined here directly (rather than as @pure = return@) so that the
-- Applicative/Monad pair is in canonical form; '<*>' is derived from the
-- 'Monad' instance through 'ap'.
instance (Monad m) => Applicative (Pause m) where
  -- A computation that finishes in its first step without pausing.
  pure x = Pause (return (Right x))
  (<*>) = ap

-- | Sequencing for 'Pause'.
--
-- Binding runs the first step of the left-hand computation. If that step
-- finished, the continuation @f@ starts within the same step; if it paused,
-- the pause is passed on and @f@ is attached to the remainder.
instance (Monad m) => Monad (Pause m) where
  return = pure
  Pause s >>= f = Pause $ s >>= \outcome -> case outcome of
    Right y -> step (f y)
    Left rest -> return (Left (rest >>= f))
-- | Demo computation over an 'Int' state with two pause points.
--
-- Doubles the state, pauses, adds 5 to the state, pauses again, and finally
-- returns the argument plus the state value read at the very beginning.
test1 :: Int -> Pause (State Int) Int
test1 y = do
  initial <- lift get
  mutate (put (initial * 2))
  yield
  mutate (modify (+ 5))
  yield
  -- The result uses the state as it was before any modification.
  return (y + initial)
-- | Drives a pausable 'State' computation to completion from the given
-- initial state.
--
-- Each step is run with 'runState'. Whenever the computation pauses, the
-- current state is printed and the remainder is resumed from that state.
-- Returns the final state paired with the computation's result.
debug :: Show s => s -> Pause (State s) a -> IO (s, a)
debug s p =
  let (outcome, s') = runState (step p) s
   in case outcome of
        Left next -> do
          putStrLn ("Paused with " ++ show s')
          debug s' next
        Right r -> return (s', r)
-- | Runs the demo: steps through @test1 1@ starting from state 1000, then
-- prints the final @(state, result)@ pair.
main :: IO ()
main = do
  outcome <- debug 1000 (test1 1)
  putStrLn ("Finished with " ++ show outcome)