{-# LANGUAGE DeriveFunctor #-}
import Control.Monad ((>=>))
import Debug.Trace (trace, traceIO)
-- | A binary tree whose values live only at the leaves; internal 'Node's
-- carry no payload, just two subtrees.
data Tree a
  = Node (Tree a) (Tree a)
  | Leaf a
  deriving (Functor)
-- | Applicative structure that agrees with the substitution-style bind:
-- @tf <*> tx@ replaces every leaf @h@ of @tf@ with @fmap h tx@.
--
-- This is exactly @ap tf tx@ for the bind defined by '<-/', so the
-- Monad law @(<*>) = ap@ holds. A zip-like definition for the
-- @Node <*> Node@ case would disagree with '>>=' / do-notation.
-- Unlike '<-/', this definition emits no trace output.
instance Applicative Tree where
  pure = Leaf
  -- A single function leaf is mapped over the whole argument tree.
  Leaf h <*> tx = fmap h tx
  -- Each function leaf receives its own copy of the argument tree.
  Node l r <*> tx = Node (l <*> tx) (r <*> tx)
-- | Bind is tree substitution, delegated to '<-/'.
-- 'return' is left at its class default, which is 'pure' (i.e. 'Leaf').
instance Monad Tree where
  tree >>= k = tree <-/ k
-- | Tree substitution: every leaf @a@ is replaced by the tree @k a@.
--
-- Each internal 'Node' of the input emits a @"<-/"@ message via 'trace'
-- when its result is forced, so the number of messages counts the nodes
-- walked. Leaves emit nothing.
(<-/) :: Tree a -> (a -> Tree b) -> Tree b
tree <-/ k = case tree of
  Leaf a -> k a
  Node l r -> trace "<-/" $ Node (l <-/ k) (r <-/ k)
-- | Renders a tree as Haskell source, e.g. @Node (Leaf 9) (Leaf 2)@.
--
-- Implemented with 'showsPrec' at application precedence (10), so operands
-- (the payload of a 'Leaf', the subtrees of a 'Node') are shown at
-- precedence 11. That parenthesises them when needed, so @Leaf (-1)@ and
-- @Leaf (Leaf 1)@ render correctly. The text is built with 'ShowS'
-- composition rather than '++', so a subtree's text is not re-copied at
-- every enclosing level.
instance Show a => Show (Tree a) where
  showsPrec d (Leaf a) =
    showParen (d > 10) $ showString "Leaf " . showsPrec 11 a
  showsPrec d (Node l r) =
    showParen (d > 10) $
      showString "Node "
        . showsPrec 11 l
        . showChar ' '
        . showsPrec 11 r
-- | Sample tree for the demo: leaves 9, 8, 7 and 2 under three internal
-- nodes, with the deeper structure on the left.
x :: Tree Int
x = Node left (Leaf 2)
  where
    left = Node (Leaf 9) (Node (Leaf 8) (Leaf 7))
-- | Demo Kleisli arrow: a single 'Leaf' holding five times the input.
-- The parameter is named @n@ so it does not shadow the top-level 'x'.
f :: Num a => a -> Tree a
f n = Leaf (n * 5)
-- | Demo Kleisli arrow: a single 'Leaf' holding the input plus one.
-- The parameter is named @n@ so it does not shadow the top-level 'x'.
g :: Num a => a -> Tree a
g n = Leaf (n + 1)
-- | Compares how much tree-walking each association of binds performs.
--
-- Every 'Node' forced inside '<-/' writes a @"<-/"@ trace line to stderr.
-- Counting those lines between the dashed separators shows how many nodes
-- each expression traverses. All three expressions print the same tree;
-- only the trace counts differ.
main :: IO ()
main = do
  -- time: |x| + |(x >>= f)| steps
  print $ (x >>= f) >>= g
  separator
  -- time: |x >>= f| steps
  print $ x >>= (\n -> f n >>= g)
  separator
  -- same cost as the previous line: (f >=> g) is \n -> f n >>= g
  print $ x >>= (f >=> g)
  where
    -- 'traceIO' is an IO action, so the separator is emitted in sequence
    -- each time it runs. A pure 'trace' over a constant expression may be
    -- floated or shared by the optimiser.
    separator = traceIO (replicate 10 '-')