{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{- |
Module      : $Header$
Description : Class connecting mutable variables and the monads they exist in.
Copyright   : (c) Maciej Piechotka
License     : MIT

Stability   : none
Portability : portable

'Reference' is a class which generalizes references and the monads they exist in, so that IORef, STRef and others can be accessed through a common interface.
-}
module Data.Reference
  (
    Reference(..),
    AtomicReference(..),
    SyncReference(..),
    -- 'atomicModifyRef' is documented public API; without this export it
    -- would be unreachable from importing modules.
    atomicModifyRef,
  )
where
import Control.Arrow
import Control.Concurrent.MVar
import Control.Concurrent.STM
import Control.Concurrent.STM.TVar
import Control.Monad
import Control.Monad.ST
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Class
import Control.Monad.Trans.Error
import Control.Monad.Trans.Identity
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.RWS.Lazy as RWSTL
import Control.Monad.Trans.RWS.Strict as RWSTS
import Control.Monad.Trans.Reader
import Control.Monad.Trans.State.Lazy as SL
import Control.Monad.Trans.State.Strict as StS
import Control.Monad.Trans.Writer.Lazy as WL
import Control.Monad.Trans.Writer.Strict as WS
import Data.IORef
import Data.Maybe
import Data.Monoid
import Data.STRef

-- | Class connecting a mutable variable type @r@ with the monad @m@ in which
-- it can be created, read and written.
class Monad m => Reference r m where
    -- | Create a new reference holding the given value.
    newRef :: a -- ^ Initial value
           -> m (r a) -- ^ The freshly created reference
    -- | Read the current value of a reference.
    readRef :: r a -- ^ Reference
            -> m a
    -- | Overwrite the value stored in a reference.
    writeRef :: r a -- ^ Reference
             -> a -- ^ New value
             -> m ()
    -- | Run a computation on the current value of the reference and
    -- optionally store a new value.
    --
    -- The computation returns a pair. If its first component is @'Just' x@,
    -- @x@ is written back. If it is 'Nothing', the reference does not have
    -- to be updated. The second component is the overall result.
    --
    -- The default implementation reads, runs the computation and then
    -- writes, so it is not atomic. Instances whose references are atomic
    -- MUST override it to provide an atomic update.
    modifyRef :: r a -- ^ Reference
              -> (a -> m (Maybe a, b)) -- ^ Computation
              -> m b -- ^ Result of computation
    modifyRef r f = do
        (new, res) <- readRef r >>= f
        case new of
            Nothing -> return res
            Just v -> writeRef r v >> return res

-- | Marks an atomic reference, i.e. one whose value can be changed
-- atomically by a pure function.
class Reference r m => AtomicReference r m where
    -- | Replace the value of the reference with the result of applying a
    -- pure function to it, returning the previous value.
    --
    -- The default implementation delegates to 'modifyRef', so it is only
    -- atomic when that instance's 'modifyRef' is atomic.
    modifyAndSwap :: r a -- ^ Reference
                  -> (a -> a) -- ^ Function to apply atomically
                  -> m a -- ^ Returns the old value
    modifyAndSwap r f = modifyRef r (\old -> return $! (Just $! f old, old))

-- | Marks references for which 'modifyRef' itself is atomic: the whole
-- read, compute and write cycle of 'modifyRef' behaves as one atomic step.
class AtomicReference r m => SyncReference r m

-- | 'modifyRef' restricted to 'SyncReference' instances, so that the type
-- guarantees the update is atomic.
atomicModifyRef :: SyncReference r m
                => r a -- ^ Reference
                -> (a -> m (Maybe a, b)) -- ^ Computation
                -> m b -- ^ Result of computation
atomicModifyRef ref computation = modifyRef ref computation

-- 'IORef' in 'IO'. 'modifyRef' keeps the class default (read, run the
-- computation, then write), so it is not atomic and there is no
-- 'SyncReference' instance.
instance Reference IORef IO where
    newRef = newIORef
    readRef = readIORef
    writeRef = writeIORef

-- A pure swap maps directly onto 'atomicModifyIORef', which applies the
-- function atomically. Here it stores @f old@ and returns @old@.
instance AtomicReference IORef IO where
    modifyAndSwap r f = atomicModifyIORef r (\old -> (f old, old))

-- 'STRef' in 'ST': each method is the corresponding Data.STRef primitive.
instance Reference (STRef s) (ST s) where
    writeRef ref v = writeSTRef ref v
    readRef ref = readSTRef ref
    newRef v = newSTRef v

-- 'MVar' in 'IO'. References created through this interface stay full:
-- 'newRef' uses 'newMVar' and 'readRef' uses 'readMVar', which leaves the
-- value in place.
instance Reference MVar IO where
    newRef = newMVar
    readRef = readMVar
    -- Replace the contents instead of calling 'putMVar'. 'putMVar' blocks on
    -- a full 'MVar', and every reference created by 'newRef' is full, so
    -- 'putMVar' here would never return.
    writeRef r a = modifyMVar_ r (\_ -> return a)
    -- 'modifyMVar' holds the 'MVar' for the whole computation. A 'Nothing'
    -- result puts the old value back unchanged.
    modifyRef r f = modifyMVar r $ \a -> return . first (fromMaybe a) =<< f a

instance AtomicReference MVar IO

instance SyncReference MVar IO

-- 'TVar' used directly from 'IO'. Each operation runs in its own
-- transaction. 'modifyRef' keeps the non-atomic class default, so there is
-- no 'SyncReference' instance here.
instance Reference TVar IO where
    newRef v = newTVarIO v
#ifdef OLD_STM
    -- NOTE(review): presumably for stm versions without 'readTVarIO'.
    readRef ref = atomically (readTVar ref)
#else
    readRef ref = readTVarIO ref
#endif
    writeRef ref v = atomically (writeTVar ref v)

-- The read and the write share a single 'atomically' transaction, so the
-- swap is atomic. The old value is forced before being returned.
instance AtomicReference TVar IO where
    modifyAndSwap ref f = atomically (readTVar ref >>= swap)
      where
        swap old = writeTVar ref (f old) >> (return $! old)

-- 'TVar' inside 'STM'. The default 'modifyRef' runs entirely within the
-- caller's transaction, which is why this pair is a 'SyncReference'.
instance Reference TVar STM where
    writeRef ref v = writeTVar ref v
    readRef ref = readTVar ref
    newRef v = newTVar v

instance AtomicReference TVar STM

instance SyncReference TVar STM

-- Operations lifted through 'ContT'. 'modifyRef' keeps the class default
-- (read, then write), and no 'SyncReference' instance is provided.
instance (Reference r m) => Reference r (ContT r' m) where
    newRef v = lift (newRef v)
    readRef ref = lift (readRef ref)
    writeRef ref v = lift (writeRef ref v)

instance (AtomicReference r m) => AtomicReference r (ContT r' m) where
    modifyAndSwap ref f = lift (modifyAndSwap ref f)

-- Operations lifted through 'ErrorT'. 'modifyRef' runs the computation in
-- the underlying monad. On an error the reference is not updated, and the
-- error is re-raised in 'ErrorT' afterwards.
instance (Error e, Reference r m) => Reference r (ErrorT e m) where
    newRef v = lift (newRef v)
    readRef ref = lift (readRef ref)
    writeRef ref v = lift (writeRef ref v)
    modifyRef ref f = ErrorT (modifyRef ref (liftM reshape . runErrorT . f))
      where
        -- Move the 'Either' from around the pair into the result slot. The
        -- inner pair is matched lazily.
        reshape (Left e) = (Nothing, Left e)
        reshape (Right ~(new, res)) = (new, Right res)

instance (Error e, AtomicReference r m) => AtomicReference r (ErrorT e m) where
    modifyAndSwap ref f = lift (modifyAndSwap ref f)

instance (Error e, SyncReference r m) => SyncReference r (ErrorT e m)

-- Operations lifted through 'IdentityT'. 'modifyRef' simply unwraps the
-- computation and delegates to the underlying monad.
instance (Reference r m) => Reference r (IdentityT m) where
    newRef v = lift (newRef v)
    readRef ref = lift (readRef ref)
    writeRef ref v = lift (writeRef ref v)
    modifyRef ref f = IdentityT (modifyRef ref (\a -> runIdentityT (f a)))

instance (AtomicReference r m) => AtomicReference r (IdentityT m) where
    modifyAndSwap ref f = lift (modifyAndSwap ref f)

instance (SyncReference r m) => SyncReference r (IdentityT m)

-- Operations lifted through 'MaybeT'. If the computation given to
-- 'modifyRef' fails, the reference is not updated and the failure is
-- re-raised in 'MaybeT'.
instance (Reference r m) => Reference r (MaybeT m) where
    newRef v = lift (newRef v)
    readRef ref = lift (readRef ref)
    writeRef ref v = lift (writeRef ref v)
    modifyRef ref f = MaybeT (modifyRef ref (liftM reshape . runMaybeT . f))
      where
        -- Move the 'Maybe' from around the pair into the result slot. The
        -- inner pair is matched lazily.
        reshape Nothing = (Nothing, Nothing)
        reshape (Just ~(new, res)) = (new, Just res)

instance (AtomicReference r m) => AtomicReference r (MaybeT m) where
    modifyAndSwap ref f = lift (modifyAndSwap ref f)

instance (SyncReference r m) => SyncReference r (MaybeT m)

-- Operations lifted through the lazy 'RWSTL.RWST'. 'modifyRef' runs the
-- computation with the current environment and state, and threads the new
-- state and output out alongside the result.
instance (Monoid w, Reference r m) => Reference r (RWSTL.RWST r' w s m) where
    newRef v = lift (newRef v)
    readRef ref = lift (readRef ref)
    writeRef ref v = lift (writeRef ref v)
    modifyRef ref f = RWSTL.RWST $ \env st -> modifyRef ref $ \a -> do
        ((new, res), st', out) <- RWSTL.runRWST (f a) env st
        return $! (new, (res, st', out))

instance (Monoid w, AtomicReference r m)
      => AtomicReference r (RWSTL.RWST r' w s m) where
    modifyAndSwap ref f = lift (modifyAndSwap ref f)

instance (Monoid w, SyncReference r m)
      => SyncReference r (RWSTL.RWST r' w s m)

-- Operations lifted through the strict 'RWSTS.RWST'. 'modifyRef' runs the
-- computation with the current environment and state, and threads the new
-- state and output out alongside the result.
instance (Monoid w, Reference r m) => Reference r (RWSTS.RWST r' w s m) where
    newRef v = lift (newRef v)
    readRef ref = lift (readRef ref)
    writeRef ref v = lift (writeRef ref v)
    modifyRef ref f = RWSTS.RWST $ \env st -> modifyRef ref $ \a -> do
        ((new, res), st', out) <- RWSTS.runRWST (f a) env st
        return $! (new, (res, st', out))

instance (Monoid w, AtomicReference r m)
      => AtomicReference r (RWSTS.RWST r' w s m) where
    modifyAndSwap ref f = lift (modifyAndSwap ref f)

instance (Monoid w, SyncReference r m)
      => SyncReference r (RWSTS.RWST r' w s m)

-- Operations lifted through 'ReaderT'. 'modifyRef' supplies the current
-- environment to the computation and delegates to the underlying monad.
instance (Reference r m) => Reference r (ReaderT r' m) where
    newRef v = lift (newRef v)
    readRef ref = lift (readRef ref)
    writeRef ref v = lift (writeRef ref v)
    modifyRef ref f = ReaderT (\env -> modifyRef ref (flip runReaderT env . f))

instance (AtomicReference r m) => AtomicReference r (ReaderT r' m) where
    modifyAndSwap ref f = lift (modifyAndSwap ref f)

instance (SyncReference r m) => SyncReference r (ReaderT r' m)

-- Operations lifted through the lazy 'SL.StateT'. 'modifyRef' runs the
-- computation from the current state and threads the new state out
-- alongside the result.
instance (Reference r m) => Reference r (SL.StateT s m) where
    newRef v = lift (newRef v)
    readRef ref = lift (readRef ref)
    writeRef ref v = lift (writeRef ref v)
    modifyRef ref f = SL.StateT $ \st -> modifyRef ref $ \a -> do
        ((new, res), st') <- SL.runStateT (f a) st
        return $! (new, (res, st'))

instance (AtomicReference r m) => AtomicReference r (SL.StateT s m) where
    modifyAndSwap ref f = lift (modifyAndSwap ref f)

instance (SyncReference r m) => SyncReference r (SL.StateT s m)

-- Operations lifted through the strict 'StS.StateT'. 'modifyRef' runs the
-- computation from the current state and threads the new state out
-- alongside the result.
instance (Reference r m) => Reference r (StS.StateT s m) where
    newRef v = lift (newRef v)
    readRef ref = lift (readRef ref)
    writeRef ref v = lift (writeRef ref v)
    modifyRef ref f = StS.StateT $ \st -> modifyRef ref $ \a -> do
        ((new, res), st') <- StS.runStateT (f a) st
        return $! (new, (res, st'))

instance (AtomicReference r m) => AtomicReference r (StS.StateT s m) where
    modifyAndSwap ref f = lift (modifyAndSwap ref f)

instance (SyncReference r m) => SyncReference r (StS.StateT s m)

-- Operations lifted through the lazy 'WL.WriterT'. 'modifyRef' runs the
-- computation and threads its output out alongside the result.
instance (Monoid w, Reference r m) => Reference r (WL.WriterT w m) where
    newRef v = lift (newRef v)
    readRef ref = lift (readRef ref)
    writeRef ref v = lift (writeRef ref v)
    modifyRef ref f = WL.WriterT $ modifyRef ref $ \a -> do
        ((new, res), out) <- WL.runWriterT (f a)
        return $! (new, (res, out))

instance (Monoid w, AtomicReference r m)
      => AtomicReference r (WL.WriterT w m) where
    modifyAndSwap ref f = lift (modifyAndSwap ref f)

instance (Monoid w, SyncReference r m) => SyncReference r (WL.WriterT w m)

-- Operations lifted through the strict 'WS.WriterT'. 'modifyRef' runs the
-- computation and threads its output out alongside the result.
instance (Monoid w, Reference r m) => Reference r (WS.WriterT w m) where
    newRef v = lift (newRef v)
    readRef ref = lift (readRef ref)
    writeRef ref v = lift (writeRef ref v)
    modifyRef ref f = WS.WriterT $ modifyRef ref $ \a -> do
        ((new, res), out) <- WS.runWriterT (f a)
        return $! (new, (res, out))

instance (Monoid w, AtomicReference r m)
      => AtomicReference r (WS.WriterT w m) where
    modifyAndSwap ref f = lift (modifyAndSwap ref f)

instance (Monoid w, SyncReference r m) => SyncReference r (WS.WriterT w m)

