Examples of the Cont Monad

Albert Y. C. Lai, trebla [at] vex [dot] net

In these examples we stack other monad transformers on top of Cont to get interesting features such as state. We use StateT s (Cont y), although ContT y (State s) behaves similarly because the instance-lifting code is written so that the two orders commute.

> import Control.Monad.Cont
> import Control.Monad.State.Strict
> import Control.Monad.Reader
> import Control.Monad
> import Data.Function (fix)

Basic Principle

callCC :: ((a -> m b) -> m a) -> m a

Or, with a rank-3 polymorphic type (this needs the RankNTypes extension):

callCC :: ((forall b. a -> m b) -> m a) -> m a
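For Cont itself the rank-3 type can be implemented directly. A minimal sketch, assuming the transformers package for the ContT constructor; callCC3 and demo3 are hypothetical names, not part of mtl. The escape discards its own continuation, so it can be given any result type, and a single c serves both at m () and at m String.

```haskell
{-# LANGUAGE RankNTypes #-}
import Control.Monad (when)
import Control.Monad.Trans.Cont (ContT (..), runCont)

-- Hypothetical helper: callCC with the rank-3 type.
-- The escape is polymorphic in its (never produced) result type.
callCC3 :: ((forall b. a -> ContT r m b) -> ContT r m a) -> ContT r m a
callCC3 f = ContT $ \k ->
  -- The escape throws away its own continuation (_) and resumes k,
  -- the continuation of the whole callCC3 expression.
  runContT (f (\a -> ContT (\_ -> k a))) k

-- Hypothetical example: the same escape c is used at two different types.
demo3 :: Int -> Int
demo3 n = runCont (callCC3 body) id
  where
    body :: (forall b. Int -> ContT r m b) -> ContT r m Int
    body c = do
      when (n < 0) (c 0)                                -- c at ContT r m ()
      s <- if n > 100 then c 100 else return (show n)   -- c at ContT r m String
      return (length s)
```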

Typical usage:

do
  ...
  x <- callCC (\c -> .... c x1 ... return x0)
  ...

In this code, c :: a -> m b and x0, x1, x :: a.

Think of c x1 as an action for early exit. If execution reaches c x1, we exit early and x becomes x1. Otherwise, the normal exit makes x become x0. Note that c x1 :: m b with b unconstrained, so it can blend into its context. (It never actually returns a value of type b; execution jumps elsewhere, and b only matters for type checking.)
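A minimal runnable instance of the typical usage above, assuming plain Cont from mtl, with x1 = 1 and x0 = 0; pick is a hypothetical name. The code after callCC runs in both cases, only with a different x.

```haskell
import Control.Monad (when)
import Control.Monad.Cont (callCC, runCont)

-- Hypothetical example: x <- callCC (\c -> ... c x1 ... return x0)
pick :: Bool -> Int
pick early = runCont body id
  where
    body = do
      x <- callCC $ \c -> do
        when early (c 1)   -- early exit: x becomes 1
        return 0           -- normal exit: x becomes 0
      return (x * 10)      -- runs after either kind of exit
```

So pick True gives 10 and pick False gives 0.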

Exiting a Loop

In this example, starting from initial state n, the loop increments the state until it reaches 5, then exits and returns that value. (So n should be below 5; otherwise the loop never exits.) Note that c s :: m (), as required by the type of when.

> loop :: (Monad m) => m a -> m b
> loop p = fix (p >>)

> fun :: Int -> Int
> fun n = runCont (evalStateT p n) id
>     where
>       p = do
>         { callCC $ \c -> loop $ do
>                               { modify (+ 1)
>                               ; s <- get
>                               ; when (s == 5) (c s)
>                               }
>         }
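To check the opening remark that the two stackings behave alike, here is a sketch of fun with the layers swapped: ContT y (State s) instead of StateT s (Cont y). funSwapped is a hypothetical name; the sketch assumes mtl's MonadState instance for ContT, which lifts get and modify into the underlying State.

```haskell
import Control.Monad (when)
import Control.Monad.Cont (ContT, callCC, runContT)
import Control.Monad.State.Strict (State, evalState, get, modify)
import Data.Function (fix)

loop :: Monad m => m a -> m b
loop p = fix (p >>)

-- Hypothetical variant of fun with the transformer order reversed.
funSwapped :: Int -> Int
funSwapped n = evalState (runContT p return) n
  where
    p :: ContT Int (State Int) Int
    p = callCC $ \c -> loop $ do
      modify (+ 1)           -- increment the state held in the underlying State
      s <- get
      when (s == 5) (c s)    -- early exit with the current state
```

For starting states below 5 it returns 5, just like fun.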

Setjmp

This example is more advanced. Can we "leak" c to the outer scope and use it elsewhere? Yes.

c still needs to be applied to some argument; let's call it x for now (we will solve for x soon). It is more convenient to return c x: the user binds it to l and executes it later. When c x is executed, execution jumps back to its callCC origin, x is returned and bound to the user's l, and l should be c x again so that the whole thing is repeatable. Therefore x = c x, i.e., x = fix c.

Typing: l, fix c :: m b and c :: m b -> m b. In this example b = Int, because l has to blend into the if-then-else alongside return s.

> setjmp :: (MonadCont m) => m (m b)
> setjmp = callCC (\c -> return (fix c))

> jumpy :: Int -> Int
> jumpy n = runCont (evalStateT p n) id
>     where
>       p = do
>         { l <- setjmp
>         ; modify (+ 1)
>         ; s <- get
>         ; if s == 5 then return s else l
>         }
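A further sketch that uses setjmp as a loop and records every state it visits; countdown is a hypothetical name. It relies on mtl's MonadCont instance for StateT (built with liftCallCC') entering a captured continuation with the current state rather than the state at capture time; jumpy depends on the same behaviour, since otherwise every jump would restart from n.

```haskell
import Control.Monad.Cont (Cont, MonadCont, callCC, runCont)
import Control.Monad.State.Strict (StateT, evalStateT, get, put)
import Data.Function (fix)

setjmp :: MonadCont m => m (m b)
setjmp = callCC (\c -> return (fix c))

-- Hypothetical example: count down from n, recording each visited value.
-- Each jump through l lands back at setjmp, but the put is not undone.
countdown :: Int -> [Int]
countdown n = runCont (evalStateT p (n, [])) id
  where
    p :: StateT (Int, [Int]) (Cont [Int]) [Int]
    p = do
      l <- setjmp
      (k, acc) <- get
      if k <= 0
        then return (reverse acc)
        else put (k - 1, k : acc) >> l
```

For example, countdown 3 visits 3, 2 and 1 before stopping.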

Exception

Exception throwing and handling is a close cousin of callCC. Here we implement it. The action to be protected needs to know how to throw an exception; here we assume it takes a thrower as a parameter and calls the provided thrower when it wants to throw. Inside callCC, which provides c, we define the thrower to run the handler and then use c to exit early; then we run the action with this thrower.

We use this to repeatedly take square roots of the state until it is close enough to 1; for a zero or negative state we throw an exception, and the handler returns 0 or NaN respectively.

> catchC :: (MonadCont m) =>
>           ((e -> m b) -> m a)  -- action, takes a thrower parameter
>        -> (e -> m a)           -- handler
>        -> m a
> catchC action handler = callCC (\c -> action (\e -> handler e >>= c))

> data Bad = Zero | Neg deriving Show

> catchme :: Double -> Double
> catchme n = runCont (evalStateT p n) id
>     where
>       p = do
>         { catchC q handler
>         }
>       q throw = do
>         { s <- get
>         ; when (s < 0) (throw Neg)
>         ; when (s <= 0) (throw Zero)
>         ; l <- setjmp
>         ; t <- get
>         ; if abs(t - 1) < 0.01 then return t else modify sqrt >> l
>         }
>       handler Zero = return 0
>       handler Neg = return (sqrt (-1))
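catchC does not depend on StateT. A minimal sketch on plain Cont, where the thrower is used inside mapM at a non-() type; DivErr and sumQuotients are hypothetical names. The thrower's result type b is fixed per catchC call (here Int), so every use of throw within one catchC must share that type.

```haskell
import Control.Monad.Cont (MonadCont, callCC, runCont)

catchC :: MonadCont m => ((e -> m b) -> m a) -> (e -> m a) -> m a
catchC action handler = callCC (\c -> action (\e -> handler e >>= c))

-- Hypothetical error type and example.
data DivErr = DivByZero deriving (Show, Eq)

-- Divide 100 by each element and sum the quotients; a zero element throws.
sumQuotients :: [Int] -> Either DivErr Int
sumQuotients xs = runCont (catchC action handler) id
  where
    action throw = do
      qs <- mapM (\x -> if x == 0 then throw DivByZero else return (100 `div` x)) xs
      return (Right (sum qs))
    handler e = return (Left e)
```

Here sumQuotients [1, 2, 4] is Right 175, while any list containing 0 gives Left DivByZero.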

It is bothersome to require every protected action to take a thrower parameter. Now we add a ReaderT layer to make throwers implicit. As a bonus we also get to install a "top level" handler.

Doing this introduces an infinite type. (The monad is a MonadReader of the thrower, and the thrower's type mentions that monad again.) A lightweight way to break the cycle is to be specific about our monad stack (RSC below) and wrap the thrower in a newtype. There are more advanced solutions.

> type RSC e s y = ReaderT (Thrower e s y) (StateT s (Cont y))
> newtype Thrower e s y = Thrower (e -> RSC e s y ())

The throw command throwI asks for the thrower from the environment and executes it.

The thrower returns (), while the throw command must return an arbitrary b to blend into its caller's context. We bridge this gap by appending a meaningless polymorphic action, undefined, after the thrower; it is never executed, because the thrower does not return. (Alternative solution: make the Thrower type Thrower (forall b. e -> RSC e s y b); then callCC needs the higher-rank type from the Basic Principle section as well.)

> throwI :: e -> RSC e s y b
> throwI e = do
>   { Thrower thrower <- ask
>   ; thrower e
>   ; undefined   -- never reached: the thrower does not return
>   }

The catch command catchI constructs the new thrower from the handler and uses it as the new environment for running the action. The thrower calls the handler and then jumps out, as in the previous example. In addition, the handler must run under the old environment so that it can re-throw exceptions. (Otherwise the handler would run under the new environment, because it is called by the action, which runs under the new environment, and re-throwing would loop.) This is done seamlessly with local, which not only shadows the old environment but also maps the old environment to the new one, so we really are constructing the new thrower from the old.

> catchI :: RSC e s y a
>        -> (e -> RSC e s y a)
>        -> RSC e s y a
> catchI action handler =
>     callCC (\c -> local (mkt c) action)
>     where mkt c r = Thrower (\e -> with r (handler e >>= c))
>           with r m = local (const r) m
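Another way to bridge the gap between the thrower's result and throwI's arbitrary b, without a trailing undefined, is to let the thrower return Data.Void.Void and finish with absurd :: Void -> a. This is a variant sketch, not the code above: only the Thrower's result type changes, and throwV, catchV, runThrower, Bad and safeRoot are hypothetical names.

```haskell
import Control.Monad.Cont (Cont, callCC, runCont)
import Control.Monad.Reader (ReaderT, asks, local, runReaderT)
import Control.Monad.State.Strict (StateT, evalStateT, get)
import Data.Void (Void, absurd)

type RSC e s y = ReaderT (Thrower e s y) (StateT s (Cont y))

-- Variant: the thrower returns Void, so its type says it never returns.
newtype Thrower e s y = Thrower { runThrower :: e -> RSC e s y Void }

-- Hypothetical throw: absurd turns the never-produced Void into any b.
throwV :: e -> RSC e s y b
throwV e = do
  thrower <- asks runThrower
  thrower e >>= absurd

-- Hypothetical catch, shaped like catchI; c's result type becomes Void.
catchV :: RSC e s y a -> (e -> RSC e s y a) -> RSC e s y a
catchV action handler = callCC (\c -> local (mkt c) action)
  where
    mkt c r = Thrower (\e -> local (const r) (handler e >>= c))

data Bad = Neg deriving Show

-- Hypothetical example: square root of the state, 0 for a negative state.
safeRoot :: Double -> Double
safeRoot n = runCont (evalStateT (runReaderT p top) n) id
  where
    top = Thrower (\e -> error ("unhandled exception " ++ show e))
    p = catchV q (\Neg -> return 0)
    q = do
      s <- get
      -- throwV is used at type Double here, not just at ():
      if s < 0 then throwV Neg else return (sqrt s)
```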

This example takes square roots of the state until it is close to 1. For a zero or negative state, we throw an exception. As in the previous example, in the zero case we return 0. Unlike the previous example, in the negative case we re-throw the exception, to demonstrate that re-throwing works its way to the outer handler; here that is the top-level handler, so a negative input fails with "unhandled exception Neg".

> exceptional :: Double -> Double
> exceptional n = runCont (evalStateT (runReaderT p topthrow) n) id
>     where
>       topthrow = Thrower (\e -> error ("unhandled exception " ++ show e))
>       p = do
>         { catchI q handler
>         }
>       q = do
>         { s <- get
>         ; when (s < 0) (throwI Neg)
>         ; when (s <= 0) (throwI Zero)
>         ; l <- setjmp
>         ; t <- get
>         ; if abs(t - 1) < 0.01 then return t else modify sqrt >> l
>         }
>       handler Zero = return 0
>       handler Neg = throwI Neg
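The example above only shows a re-thrown exception reaching the top-level handler. A sketch with two nested catchI calls may make the environment handling more concrete: the inner handler re-throws Neg, and because it runs under the old environment, the re-throw lands in the outer handler instead of looping. It repeats RSC, Thrower, throwI and catchI so that it compiles on its own; nested, inner and outer are hypothetical names.

```haskell
import Control.Monad.Cont (Cont, callCC, runCont)
import Control.Monad.Reader (ReaderT, ask, local, runReaderT)
import Control.Monad.State.Strict (StateT, evalStateT, get)

type RSC e s y = ReaderT (Thrower e s y) (StateT s (Cont y))
newtype Thrower e s y = Thrower (e -> RSC e s y ())

throwI :: e -> RSC e s y b
throwI e = do
  Thrower thrower <- ask
  thrower e
  undefined   -- never reached: the thrower does not return

catchI :: RSC e s y a -> (e -> RSC e s y a) -> RSC e s y a
catchI action handler = callCC (\c -> local (mkt c) action)
  where
    mkt c r = Thrower (\e -> with r (handler e >>= c))
    with r m = local (const r) m

data Bad = Zero | Neg deriving Show

-- Hypothetical example with two nested handlers.
nested :: Double -> Double
nested n = runCont (evalStateT (runReaderT p top) n) id
  where
    top = Thrower (\e -> error ("unhandled exception " ++ show e))
    p = catchI (catchI q inner) outer
    q = do
      s <- get
      case compare s 0 of
        LT -> throwI Neg
        EQ -> throwI Zero
        GT -> return (sqrt s)
    inner Zero = return 0
    inner Neg = throwI Neg   -- re-throw, under the outer environment
    outer _ = return (-1)    -- the enclosing handler catches the re-thrown Neg
```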

Generator and Yield

callCC and mutable references together can implement yield-style generators. I use two mutable references: inside remembers the point in the generator body at which to start or resume, and outside remembers the point in the caller to return to on a yield. So when the body yields to the caller, it uses callCC $ \ki -> ..., stores ki in inside, then jumps to where outside says. Dually, when the caller resumes the body, it uses callCC $ \ko -> ..., stores ko in outside, then jumps to where inside says. The following code sketch and comments may show this idea better:

 1  body = do
 2    ...
 3    ko <- read outside
 4    callCC $ \ki -> do
 5      write inside ki    -- inside stores line 7
 6      ko ()              -- jump to line 15
 7    ...
 8    body

 9  caller = do
10    ...
11    ki <- read inside
12    callCC $ \ko -> do
13      write outside ko   -- outside stores line 15
14      ki ()              -- jump to line 7
15    ...
16    caller

There are still a few loose ends and an extension:

These are implemented below.

{-| Yield-style generators. -}
module Yield where

import Control.Monad.Cont
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.IORef

{-| 
Type of return values of generators. @More@ means the generator yields. @End@
means the generator finishes. See below for examples.
-}
data Dot a b = More a | End b deriving Show

{-|
Create a yield-style generator from a body. Example:

> g <- mkgen (\yield c0 -> do
>                c1 <- yield a0
>                c2 <- yield a1
>                return b
>            )

Then we can use:

> dot0 <- g c0  -- dot0 = More a0
> dot1 <- g c1  -- dot1 = More a1
> dot2 <- g c2  -- dot2 = End b

Further calls to @g@ return @End b@ too.

Although usually @m = n@, i.e., @g@ and the body are in the same monad as the
@mkgen@ call, technically they can be different. The @mkgen@ call runs in
@m@, which needs MonadIO to allocate the IORefs. The body and @g@ run in @n@,
which needs MonadIO and MonadCont to use the IORefs and callCC.
-}
mkgen :: (MonadIO m, MonadIO n, MonadCont n) =>
         ((a -> n c) -> c -> n b) -> m (c -> n (Dot a b))
mkgen body = do
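  -- Placeholders: inside is set to start below; outside is set by the first
  -- call to the returned function, before yield ever reads it.
  inside <- liftIO (newIORef undefined)
  outside <- liftIO (newIORef undefined)
  -- yield: save where the body resumes in inside, then jump to the caller.
  -- next:  save where the caller continues in outside, then jump into the body.
  -- start: run the body; when it finishes, make later calls return End.
  let yield y = do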
        ko <- liftIO (readIORef outside)
        callCC (\ki -> do
                   liftIO (writeIORef inside ki)
                   ko (More y)
               )
      next x = do
        ki <- liftIO (readIORef inside)
        callCC (\ko -> do
                   liftIO (writeIORef outside ko)
                   ki x
               )
      start x = do
        e <- body yield x
        liftIO (writeIORef inside (\_ -> return (End e)))
        ko <- liftIO (readIORef outside)
        ko (End e)
        undefined  -- never reached: ko jumps back to the caller
  liftIO (writeIORef inside start)
  return next

Here is an example usage. When the caller passes True to the body, the body yields a number; otherwise, the body finishes and returns a string. The caller keeps passing True until it receives a number that is at least 5; then it passes False.

import Control.Monad.Cont
import Control.Monad.IO.Class (liftIO)
import Yield

body yield b = bodyloop b 0 where
  bodyloop False n = do
    liftIO (putStrLn "body receives False")
    return (replicate n 'x')
  bodyloop True n = do
    liftIO (putStrLn "body receives True")
    b <- yield n
    bodyloop b $! (n+1)

cmain :: ContT r IO ()
cmain = do
  g <- mkgen body
  let cmainloop b = do
        d <- g b
        case d of
          End s -> do liftIO (putStrLn ("caller receives " ++ s))
                      return ()
          More n -> do liftIO (putStrLn ("caller receives " ++ show n))
                       cmainloop (n < 5)
  cmainloop True

main :: IO ()
main = runContT cmain return
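Another usage sketch: a hypothetical drain helper keeps resuming a generator with () and collects everything it yields until the generator finishes. The block repeats Dot and mkgen (with `_ <-` before the final jump) so that it compiles on its own; drain, threeThenDone and collected are hypothetical names.

```haskell
import Control.Monad.Cont (MonadCont, callCC, runContT)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.IORef (newIORef, readIORef, writeIORef)

data Dot a b = More a | End b deriving Show

-- mkgen, repeated from the Yield module above.
mkgen :: (MonadIO m, MonadIO n, MonadCont n) =>
         ((a -> n c) -> c -> n b) -> m (c -> n (Dot a b))
mkgen body = do
  inside <- liftIO (newIORef undefined)
  outside <- liftIO (newIORef undefined)
  let yield y = do
        ko <- liftIO (readIORef outside)
        callCC (\ki -> do
                   liftIO (writeIORef inside ki)
                   ko (More y)
               )
      next x = do
        ki <- liftIO (readIORef inside)
        callCC (\ko -> do
                   liftIO (writeIORef outside ko)
                   ki x
               )
      start x = do
        e <- body yield x
        liftIO (writeIORef inside (\_ -> return (End e)))
        ko <- liftIO (readIORef outside)
        _ <- ko (End e)
        undefined  -- never reached: ko jumps back to the caller
  liftIO (writeIORef inside start)
  return next

-- Hypothetical helper: resume with () until End, collecting every yield.
drain :: Monad n => (() -> n (Dot a b)) -> n ([a], b)
drain g = go []
  where
    go acc = do
      d <- g ()
      case d of
        More a -> go (a : acc)
        End b -> return (reverse acc, b)

-- Hypothetical body: yields 1, 2, 3, then finishes with "done".
threeThenDone :: Monad n => (Int -> n ()) -> () -> n String
threeThenDone yield () = do
  mapM_ yield [1, 2, 3]
  return "done"

collected :: IO ([Int], String)
collected = runContT (mkgen threeThenDone >>= drain) return
```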
