{-# LANGUAGE GADTs, RankNTypes, EmptyDataDecls, FlexibleInstances
           , TypeSynonymInstances, ScopedTypeVariables
  #-}

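-- Type inference for the all-rules corner of the lambda cube (the Calculus
-- of Constructions). Terms are Church-style, i.e. binders carry their type
-- annotations, which is what makes inference possible. Terms are represented
-- using parametric higher-order abstract syntax (PHOAS).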

module PHOASInf where

import Control.Applicative
import Control.Monad

import Data.Monoid
import Data.List (intercalate)

import Data.Set (Set)
import qualified Data.Set as S

-- | Numeric bound-variable values: 'alphaEq' and 'dedeBruijn' instantiate
-- scopes with de Bruijn levels of this type, and it is the bound-variable
-- type of the type checker ('Checker').
type Index = Integer

-- | A minimal Maybe monad transformer: a computation in the base monad @m@
-- that may fail with 'Nothing'.
newtype MaybeT m a = MT { unMT :: m (Maybe a) }

instance Functor m => Functor (MaybeT m) where
  fmap g = MT . fmap (fmap g) . unMT

-- NOTE(review): '<*>' runs the effects of both arguments even when the
-- first one yields 'Nothing', whereas '>>=' short-circuits. For a
-- reader-like base monad such as @(->) NameSupply@ the two agree.
instance Applicative m => Applicative (MaybeT m) where
  pure x = MT $ pure (Just x)
  MT mf <*> MT mx = MT $ (\maf ma -> maybe Nothing (<$> ma) maf) <$> mf <*> mx

instance Monad m => Monad (MaybeT m) where
  return x   = MT $ return (Just x)
  MT m >>= f = MT $ m >>= maybe (return Nothing) (unMT . f)

-- | 'Alternative' is a superclass of 'MonadPlus' from GHC 7.10 (base 4.8)
-- onwards, so without this instance the 'MonadPlus' instance below is
-- rejected. It is defined via the 'MonadPlus' methods so the two always
-- agree; the context also works on older compilers where 'Monad' does not
-- imply 'Applicative'.
instance (Applicative m, Monad m) => Alternative (MaybeT m) where
  empty = mzero
  (<|>) = mplus

-- | Failure is 'Nothing'; 'mplus' tries the second computation only when the
-- first one fails.
instance Monad m => MonadPlus (MaybeT m) where
  mzero = MT $ return Nothing
  mplus (MT mx) (MT my) = MT $ mx >>= maybe my (return . Just)

-- | Additive monoid on 'Index'. Within this module only 'mempty' appears to
-- be used: it is the dummy bound-variable value handed to 'freeVariables''.
-- NOTE(review): this is an orphan instance on 'Integer', and GHC >= 8.4
-- additionally requires a matching 'Semigroup' instance.
instance Monoid Index where
  mappend a b = a + b
  mempty      = 0

-- | A hierarchical name: a path of (base, counter) components stored
-- innermost component first.
newtype Name = Name [(String, Int)] deriving (Eq, Ord)

-- | Components are printed outermost first and separated by dots. A zero
-- counter is omitted, so @("x", 0)@ prints as @x@ and @("x", 3)@ as @x3@.
instance Show Name where
  show (Name components) =
    intercalate "." [ renderComponent c | c <- reverse components ]
   where
    renderComponent (base, 0) = base
    renderComponent (base, k) = base ++ show k

-- | A supply of fresh names: the current namespace path (innermost first)
-- and the next counter to hand out within it.
data NameSupply = Sup [(String, Int)] Int

-- | Advance the counter, keeping the namespace path.
bumpNS :: NameSupply -> NameSupply
bumpNS (Sup path next) = Sup path (next + 1)

-- | Enter a sub-namespace tagged with @s@ and the current counter, starting
-- its own counter at zero.
freshNameSpace :: String -> NameSupply -> NameSupply
freshNameSpace base (Sup path next) = Sup nested 0
  where nested = (base, next) : path

-- | The name with base @s@ at the current counter and namespace path.
mkName :: String -> NameSupply -> Name
mkName base (Sup path next) = Name ((base, next) : path)

-- | Monads that can hand out fresh 'Name's.
--
-- * @freshName s k@ passes a fresh name with base @s@ to @k@.
-- * @forkNames s m k@ runs @m@ in a fresh sub-namespace tagged @s@ and
--   feeds its result to @k@.
class (Applicative m, Monad m) => Supplier m where
  freshName :: String -> (Name  -> m a) -> m a
  forkNames :: String -> m s -> (s -> m t) -> m t

-- | Reader instance: the supply is passed explicitly. Both the name handed
-- out and the forked sub-namespace are derived from the current counter,
-- so the continuation runs with the counter bumped.
instance Supplier ((->) NameSupply) where
  freshName base k supply = k (mkName base supply) next
    where next = bumpNS supply
  forkNames base sub k supply = k (sub inner) (bumpNS supply)
    where inner = freshNameSpace base supply

-- | Lifting through 'MaybeT': if the forked computation fails, the
-- continuation is not run and the whole computation fails.
instance Supplier m => Supplier (MaybeT m) where
  freshName base k     = MT (freshName base (unMT . k))
  forkNames base sub k = MT (forkNames base (unMT sub) resume)
    where
      resume Nothing  = return Nothing
      resume (Just x) = unMT (k x)

-- | A free variable of a checker term: a 'Name' together with the
-- variable's type. Equality and ordering consider only the name.
data Ref b = Name ::: Normalizer (Ref b) b

-- | Generate a fresh 'Ref' with base name @s@ and type @k@, and pass it to @f@.
freshRef :: Supplier m => String -> Normalizer (Ref b) b -> (Ref b -> m a) -> m a
freshRef s k f = freshName s (f . (::: k))

instance Eq (Ref b) where
  x == y = case (x, y) of
    (n ::: _, n' ::: _) -> n == n'

instance Ord (Ref b) where
  compare x y = case (x, y) of
    (n ::: _, n' ::: _) -> compare n n'

-- | The two binding forms: 'Lam' (abstraction) and 'Pi' (dependent
-- function type).
data Quant = Lam | Pi deriving (Eq, Show)

-- | Terms with free variables of type @v@ and bound variables of type @b@.
-- 'Type' and 'Kind' are the two sorts ('infer'' gives 'Type' the sort
-- 'Kind'). A 'Bind' holds the annotation on the bound variable and a
-- 'Scope' for the body.
data Term' v b where
  Type  :: Term' v b
  Kind  :: Term' v b
  Bind  :: Quant -> Term' v b -> Scope v b -> Term' v b
  Ap    :: Term' v b -> Term' v b -> Term' v b
  Free  :: v -> Term' v b
  Bound :: b -> Term' v b

-- | The body of a binder.
-- * @s :. f@ is a HOAS body with display name @s@, instantiated by applying
--   @f@ to a bound-variable value.
-- * @K e@ is a body that does not take the bound variable. 'abstract'
--   produces it when the variable does not occur.
data Scope v b where
  (:.) :: String -> (b -> Term' v b) -> Scope v b
  K    :: Term' v b -> Scope v b

-- | A term that is polymorphic in the bound-variable type (PHOAS). It can
-- therefore only use bound variables handed to it by its own scopes.
newtype Term v = TM (forall b. Term' v b)

-- | Rename the free variables of a term. Binders and bound variables are
-- left untouched.
mapVar' :: (v -> u) -> Term' v b -> Term' u b
mapVar' g = go
  where
    go Type          = Type
    go Kind          = Kind
    go (Free v)      = Free (g v)
    go (Bound b)     = Bound b
    go (Ap f x)      = Ap (go f) (go x)
    go (Bind q k sc) = Bind q (go k) (mapVarS g sc)

-- | 'mapVar'' lifted to scopes. HOAS bodies are renamed after instantiation.
mapVarS :: (v -> u) -> Scope v b -> Scope u b
mapVarS g sc = case sc of
  K body       -> K (mapVar' g body)
  name :. body -> name :. (mapVar' g . body)

-- | 'fmap' renames free variables.
instance Functor Term where
  fmap g (TM body) = TM (mapVar' g body)

-- | Alpha-equivalence of two terms whose scopes are instantiated with
-- de Bruijn levels. @i@ is the next unused level.
--
-- The final catch-all equation makes the function total. Terms headed by
-- different constructors are simply not equal. Examples are 'Type' vs. a
-- 'Free' variable, or a 'Pi' vs. an application, and both can arise when
-- 'infer'' compares an argument type with a domain. Such cases must return
-- False rather than fail with a pattern-match error.
alphaEq :: (Eq v) => Index -> Term' v Index -> Term' v Index -> Bool
alphaEq _ Type         Type       = True
alphaEq _ Kind         Kind       = True
alphaEq _ (Free v)     (Free v')  = v == v'
alphaEq _ (Bound i)    (Bound i') = i == i'
alphaEq i (Ap f x)     (Ap f' x') = alphaEq i f f' && alphaEq i x x'
alphaEq i (Bind q k s) (Bind q' k' s') =
  q == q' && alphaEq i k k' && alphaEqS i s s'
alphaEq _ _            _          = False

-- | Alpha-equivalence of scopes at level @i@.
-- * Named scopes are opened with level @i@, and their bodies are compared
--   at @i+1@.
-- * When only one side is a constant 'K' scope, the comparison also moves
--   to @i+1@, so binders nested inside the two bodies get the same levels.
alphaEqS :: (Eq v) => Index -> Scope v Index -> Scope v Index -> Bool
alphaEqS i (K    e) (K    e') = alphaEq i e e'
alphaEqS i (_ :. e) (_ :. e') = alphaEq (i+1) (e i) (e' i)
alphaEqS i (_ :. e) (K    e') = alphaEq (i+1) (e i) e'
alphaEqS i (K    e) (_ :. e') = alphaEq (i+1) e     (e' i)

-- | Terms are equal when they are alpha-equivalent.
instance (Eq v) => Eq (Term v) where
  TM t == TM t' = alphaEq 0 t t'

-- | Free variables of a term. The first argument is a dummy bound-variable
-- value used only to instantiate HOAS scopes; it is not inspected here.
freeVariables' :: (Ord v) => b -> Term' v b -> Set v
freeVariables' dummy = go
  where
    go (Free v)        = S.singleton v
    go (Ap f x)        = go f `S.union` go x
    go (Bind _ k body) = go k `S.union` go (open body)
    go _               = S.empty
    open (K e)         = e
    open (_ :. e)      = e dummy

-- | Free variables of a PHOAS term (scopes are instantiated with @()@).
freeVariables :: forall v. (Ord v) => Term v -> Set v
freeVariables (TM body) = freeVariables' () body

-- | Abstract the free variable @v@ out of @e@, producing a scope named @s@.
-- If @v@ does not occur free in @e@, a constant 'K' scope is returned.
-- 'mempty' serves as the dummy value for the occurrence check.
abstract :: (Monoid b, Ord v) => String -> v -> Term' v b -> Scope v b
abstract s v e
  | occurs    = s :. \b -> replace b e
  | otherwise = K e
 where
  occurs = v `S.member` freeVariables' mempty e
  -- replace every free occurrence of v with the bound variable b
  replace b t = case t of
    Free v' | v == v'      -> Bound b
    Ap f x                 -> Ap (replace b f) (replace b x)
    Bind q k (K body)      -> Bind q (replace b k) (K (replace b body))
    Bind q k (n :. body)   -> Bind q (replace b k) (n :. \b' -> replace b (body b'))
    _                      -> t

-- | Convert a term whose scopes are instantiated with de Bruijn levels back
-- into a HOAS term over an arbitrary bound-variable type.
-- * @depth@ is the next level to assign.
-- * @env@ holds the HOAS variables introduced so far, innermost first, so
--   level @n@ is found at position @n@ of @reverse env@.
dedeBruijn :: Index -> [b] -> Term' v Index -> Term' v b
dedeBruijn depth env = go
  where
    go Type        = Type
    go Kind        = Kind
    go (Free v)    = Free v
    go (Bound lvl) = Bound (reverse env !! fromIntegral lvl)
    go (Ap f x)    = Ap (go f) (go x)
    go (Bind q k (K body))    = Bind q (go k) (K (go body))
    go (Bind q k (s :. body)) =
      Bind q (go k) (s :. \b -> dedeBruijn (depth + 1) (b : env) (body depth))

-- Normalization

-- | Bound-variable values used during normalization.
-- * 'NZ' wraps a genuine bound variable ('cut' opens scopes with it).
-- * 'NS' wraps a term substituted for a variable. Beta reduction in
--   'reduce' introduces it, as does opening a scope in the checker.
data Norm v b where
  NZ :: b -> Norm v b
  NS :: Normalizer v b -> Norm v b

-- | Terms whose bound variables are 'Norm' values.
type Normalizer v b = Term' v (Norm v b)

-- | Normalize a term.
-- * With @cbv = True@, arguments are reduced before substitution and
--   lambda bodies are reduced under their binder.
-- * With @cbv = False@, arguments and lambda bodies are left as they are,
--   although 'Pi' types and lambda annotations are still reduced.
-- Variables bound to a substituted term ('NS') are replaced by the
-- reduction of that term.
reduce :: Bool -> Normalizer v b -> Normalizer v b
reduce _   (Bound (NZ b))       = Bound $ NZ b
-- a substituted variable: continue with the term it stands for
reduce cbv (Bound (NS t))       = reduce cbv t
reduce cbv (Bind Pi k (K e))    = Bind Pi (reduce cbv k) (K $ reduce cbv e)
reduce cbv (Bind Pi k (s :. e)) = Bind Pi (reduce cbv k) (s :. \v -> reduce cbv $ e v)
-- beta reduction: when the head reduces to a lambda, substitute the
-- argument (wrapped in NS) into its body and keep reducing
reduce cbv (Ap f x) = case reduce cbv f of
  Bind Lam _ (K e)    -> reduce cbv e
  Bind Lam _ (_ :. e) -> reduce cbv . e . NS $ x'
  -- the head is not a lambda: rebuild the application with the reduced head
  f                   -> Ap f x'
 where x' | cbv       = reduce cbv x
          | otherwise = x
reduce cbv (Bind Lam k sc)
  | cbv       = Bind Lam (reduce cbv k) $
                  case sc of
                    s :. e -> s :. \v -> reduce cbv $ e v
                    K    e -> K $ reduce cbv e
  | otherwise = Bind Lam (reduce cbv k) sc
-- Type, Kind and free variables are already normal
reduce _ t = t

-- | Read a normalizer back as an ordinary term.
-- * A variable standing for a substituted term ('NS') is replaced by that
--   term, recursively.
-- * A genuine bound variable ('NZ') is unwrapped.
-- Scopes are instantiated with 'NZ'.
cut :: Normalizer v b -> Term' v b
cut term = case term of
  Type              -> Type
  Kind              -> Kind
  Free v            -> Free v
  Bound (NZ b)      -> Bound b
  Bound (NS t)      -> cut t
  Ap f x            -> Ap (cut f) (cut x)
  Bind q k (K e)    -> Bind q (cut k) (K (cut e))
  Bind q k (s :. e) -> Bind q (cut k) (s :. (cut . e . NZ))

-- | Normalize with 'reduce' in full mode, which also reduces arguments and
-- lambda bodies.
nf :: Term v -> Term v
nf (TM body) = TM (cut (reduce True body))

-- | Normalize with 'reduce' in lazy mode: arguments and lambda bodies are
-- left unreduced.
whnf :: Term v -> Term v
whnf (TM body) = TM (cut (reduce False body))

-- Type inference/checking

-- | Terms as seen by the type checker: free variables are 'Ref's carrying
-- their type, and bound variables are 'Norm' values over 'Index'.
type Checker  = Normalizer (Ref Index) Index
-- | Scopes of checker terms.
type CheckerS = Scope (Ref Index) (Norm (Ref Index) Index)

-- | Infer the type of a checker term, failing with 'mzero' if it is
-- ill-typed.
-- * Free variables carry their own types ('Ref').
-- * Substituted variables ('NS') are typed by typing the substituted term.
--
-- Type errors are reported with 'mzero' rather than 'fail'. The local
-- 'MaybeT' does not override 'fail', so 'fail' would fall back to 'error'
-- and crash instead of failing. It would also need a 'MonadFail'
-- constraint on newer GHCs.
infer' :: (Supplier m, MonadPlus m) => Checker -> m Checker
infer' Type             = return Kind
-- Kind has no sort of its own.
infer' Kind             = mzero
-- Checker scopes are only ever opened with 'NS' (see 'sortScope', 'inferS'
-- and the application rule below), so a raw 'NZ' variable is unreachable.
infer' (Bound (NZ _))   = error "PHOASInf.infer': impossible! unexpected NZ-bound variable"
infer' (Bound (NS t))   = Bound . NS <$> infer' t
infer' (Free (_ ::: k)) = return k
infer' (Ap f x)         = infer' f >>= \tf ->
  case reduce False tf of
    Bind Pi k sc -> do
      tx <- infer' x
      -- the argument's type must match the domain up to normalization
      guard $ alphaEq 0 (cut . reduce True $ k) (cut . reduce True $ tx)
      case sc of
        _ :. e -> return $ e (NS x)
        K    e -> return e
    -- the function's type is not a Pi type: a type error, not a crash
    _ -> mzero
infer' (Bind Pi k sc)   = do
  tk <- infer' k
  te <- sortScope k sc
  -- both the domain and the codomain must be sorts; the result is the
  -- codomain's sort
  case (cut $ reduce True tk, cut $ reduce True te) of
    (Type, Type) -> return Type
    (Kind, Type) -> return Type
    (Type, Kind) -> return Kind
    (Kind, Kind) -> return Kind
    (_   , _   ) -> mzero  -- non-sort function space
infer' (Bind Lam k sc)  = do
  tsc <- inferS k sc
  let ty = Bind Pi k tsc
  -- check that the inferred Pi type is itself well-sorted
  _ <- infer' ty
  return ty

-- | Infer the type (expected by the caller to be a sort) of a binder's
-- body, where @k@ is the type of the bound variable. A constant 'K' body is
-- typed directly. Otherwise the body is opened with a fresh free reference
-- of type @k@.
sortScope :: (Supplier m, MonadPlus m) => Checker -> CheckerS -> m Checker
sortScope k sc = case sc of
  K body    -> infer' body
  s :. body -> freshRef s k (infer' . body . NS . Free)

-- | 'abstract' for normalizer terms: abstract the free variable @r@ out of
-- @e@, producing a scope named @s@. When @r@ does not occur, a constant 'K'
-- scope is returned. The occurrence check runs on @cut e@, so occurrences
-- inside substituted ('NS') terms are seen too; 'mempty' is only a dummy
-- value for opening scopes during that check.
abstractN :: (Ord v, Monoid b) => String -> v -> Normalizer v b -> Scope v (Norm v b)
abstractN s r e
  | r `S.notMember` freeVariables' mempty (cut e) = K e
  | otherwise                                     = s :. \b -> pull b e
 where
  -- replace every occurrence of r with the bound variable b
  pull b (Free r') | r == r'    = Bound b
  pull _ (Bound (NZ z))         = Bound (NZ z)
  pull b (Bound (NS t))         = Bound . NS $ pull b t
  pull b (Bind q k (K body))    = Bind q (pull b k) $ K (pull b body)
  pull b (Bind q k (n :. body)) = Bind q (pull b k) $ n :. \b' -> pull b $ body b'
  pull b (Ap f x)               = Ap (pull b f) (pull b x)
  pull _ t                      = t

-- | Infer the type of a binder's body as a scope. A constant 'K' body is
-- typed directly. Otherwise the body is opened with a fresh reference of
-- type @k@ and typed, and the reference is abstracted back out of the type.
inferS :: (Supplier m, MonadPlus m) => Checker -> CheckerS -> m CheckerS
inferS k sc = case sc of
  K body    -> fmap K (infer' body)
  s :. body -> freshRef s k $ \r ->
    abstractN s r <$> infer' (body (NS (Free r)))

-- | Infer the type of a term. @env@ gives the type of every free variable.
-- Checking runs in a sub-namespace forked under "inf". The resulting type is
-- converted back to a 'Term Name', dropping the types attached to free
-- references.
infer :: (Supplier m, MonadPlus m) => (Name -> Checker) -> Term Name -> m (Term Name)
infer env (TM t) = do
  ty <- forkNames "inf" (infer' (mapVar' (\n -> n ::: env n) t)) return
  return $ TM (dedeBruijn 0 []
               . mapVar' (\(n ::: _) -> n)
               . cut $ ty)

-- | Initial name supply whose root namespace is tagged "inf".
-- NOTE(review): not referenced in this module. It is presumably meant to be
-- passed when running 'infer' over the @(->) NameSupply@ monad.
isp :: NameSupply
isp = Sup [("inf",0)] 0

-- printing

-- | Wrap a printer's output in parentheses.
parens :: ShowS -> ShowS
parens = showParen True

-- Separators and keywords used by 'pretty'.
space, colon, arrow, lam :: ShowS
space = showString " "
colon = showString " : "
arrow = showString " -> "
lam   = showString "\\"

-- | Render a term whose bound variables are 'Name's. Every binder with a
-- named scope gets a fresh name from the supply, hence the 'Supplier'
-- constraint. Applications are fully parenthesized. Binders print as:
-- * @(x : A) -> B@ for a dependent 'Pi'
-- * @A -> B@ for a non-dependent 'Pi'
-- * @\\(x : A) -> e@ or @\\(_ : A) -> e@ for a lambda.
pretty :: (Supplier m) => Term' Name Name -> m ShowS
pretty (Free n)  = return $ shows n
pretty (Bound n) = return $ shows n
pretty (Ap f x)  = do sf <- pretty f
                      sx <- pretty x
                      return $ parens sf . space . parens sx
pretty Type      = return $ showString "Type"
pretty Kind      = return $ showString "Kind"
pretty (Bind Pi k (K e)) = do
  sk <- pretty k
  se <- pretty e
  -- '->' associates to the right, so a domain that is itself a binder needs
  -- parentheses: (A -> B) -> C must not print as A -> B -> C.
  let sdom = case k of
               Bind _ _ _ -> parens sk
               _          -> sk
  return $ sdom . arrow . se
pretty (Bind Pi k (s :. e)) = freshName s $ \n -> do
  sk <- pretty k
  se <- pretty $ e n
  return $ parens (shows n . colon . sk) . arrow . se
pretty (Bind Lam k (K e)) = do
  sk <- pretty k
  se <- pretty e
  return $ lam . parens (showString "_" . colon . sk) . arrow . se
pretty (Bind Lam k (s :. e)) = freshName s $ \n -> do
  sk <- pretty k
  se <- pretty $ e n
  return $ lam . parens (shows n . colon . sk) . arrow . se

-- | Name supply used by 'pp'. Generated binder names live under the "pp"
-- namespace.
psp :: NameSupply
psp = Sup [("pp", 0)] 0

-- | Render a term to a string by running 'pretty' in the @(->) NameSupply@
-- monad, starting from 'psp'.
pp :: Term' Name Name -> String
pp t = pretty t psp ""
