module Amb where

-- A non-determinism monad implemented on top of a delimited continuation monad.

import Control.Applicative
import Control.Monad
import Control.Monad.Trans

import Data.Maybe

newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }

type AmbT r m a = ContT (Maybe r) m a

instance (Monad m) => Monad (ContT r m) where
  return a = ContT $ \k -> k a
  m >>= f  = ContT $ \k -> runContT m (\a -> runContT (f a) k)

instance MonadTrans (ContT r) where
  lift m = ContT $ \k -> m >>= k

reset :: Monad m => ContT a m a -> ContT r m a
reset = lift . flip runContT return

shift :: Monad m => ((a -> ContT s m r) -> ContT r m r) -> ContT r m a
shift e = ContT $ \k -> e (lift . k) `runContT` return

none :: Monad m => AmbT r m a
none = shift $ \_ -> return Nothing

choose :: Monad m => AmbT r m a -> AmbT r m a -> AmbT r m a
choose m m' = shift $ \k -> m >>= k >>= maybe (m' >>= k) (return . Just)

amb :: Monad m => [a] -> AmbT r m a
amb = foldr (choose . return) none

reify :: Monad m => AmbT a m a -> ContT r m (Maybe a)
reify = reset . liftM Just

runAmbT :: Monad m => AmbT a m a -> m (Maybe a)
runAmbT = flip runContT return . reify

example :: Monad m => AmbT r m (Integer, Integer)
example = do x <- amb [1,2,3]
             y <- amb [4,5,6]
             if x * y == 8
               then return (x,y)
               else amb []

-- factor :: Integer -> Amb r (Integer, Integer)
-- factor a = do x <- amb [2..]
--               y <- amb [2..]
--               if x * y == a
--                 then return (x,y)
--                 else amb []
