module Typechecker (typecheck, TypeDict) where
import Preprocessor
import Data.Generics
import Data.Map ((!), Map)
import qualified Data.Map as Map
import Control.Monad.Error
import Control.Monad.State
import Control.Monad.Reader
import Data.Maybe (fromJust)

-- | Types recorded by the checker, keyed by the expression they were
-- inferred for.
type TypeDict = Map Exp Type

-- | The checker's monad stack: a read-only 'Context', a mutable 'TypeDict',
-- and 'String' error messages carried by 'Either'.
type Typechecker = ReaderT Context (StateT TypeDict (Either String))

-- | Read-only environment threaded through type inference.
data Context = Context
  { fns  :: Program           -- ^ the program whose top-level definitions are looked up by name
  , vars :: [(String, Type)]  -- ^ local variables in scope, innermost binding first
  }

-- | Type-check a whole program.
--
-- Returns the dictionary of types learned while checking, or an error
-- message.  The checks run in this order:
--
--   1. the entry point @_main@ must exist;
--   2. the program must contain no recursive references;
--   3. every top-level definition must be typeable and @_main@ must have
--      type 'TInt'.
--
-- The entry point is checked first so that a missing @_main@ is reported as
-- a 'Left' value before any analysis tries to look it up.
typecheck :: Program -> Either String TypeDict
typecheck p
  | not (Map.member "_main" p) = Left "missing 'main' label"
  | not (terminates p)         = Left "Recursive reference detected" -- FIXME: better message
  | otherwise =
      flip execStateT Map.empty
        $ flip runReaderT Context { fns = p, vars = [] }
        -- Infer every definition so that each one ends up in the dictionary,
        -- then pin down the type of the entry point.
        $ mapM_ (inferType . EGlobal) (Map.keys p) >> checkType TInt (EGlobal "_main")

-- | Look an expression up in the type dictionary; on a miss, run the given
-- inference rule and record its result under that expression.
--
-- Recording is a deliberate side effect of type checking: the types of
-- top-level definitions and of lambdas end up in the dictionary so that the
-- Compiler module can use them to generate LLVM types.
memoType :: Exp -> Typechecker Type -> Typechecker Type
memoType e inferRule =
  gets (Map.lookup e) >>= maybe (inferRule >>= learnType e) return
                  
-- | Record the type of an expression in the dictionary (replacing any
-- previous entry for it) and hand the type back to the caller.
learnType :: Exp -> Type -> Typechecker Type
learnType e t = modify (Map.insert e t) >> return t

-- | Infer the type of an expression in the current 'Context'.
--
-- Failures are raised with 'throwError' so that they surface as a 'Left'
-- message from the underlying 'Either String', rather than depending on what
-- 'fail' happens to do for 'Either'.
--
-- Globals, lambdas, closures and conditionals are memoised in the type
-- dictionary via 'memoType'.
--
-- NOTE(review): memoisation is keyed on the expression alone, while the type
-- of a free local variable comes from the surrounding context.  This is only
-- sound if structurally equal expressions never occur under different local
-- bindings -- presumably guaranteed by the preprocessor; confirm.
inferType :: Exp -> Typechecker Type
inferType e@(EGlobal fname) = memoType e $ do
  ctx <- ask
  case Map.lookup fname (fns ctx) of
    Just fbody -> inferType fbody
    Nothing    -> throwError $ "no such function " ++ fname
inferType (ELocal var) = do
  ctx <- ask
  -- An unbound local is a type error, not a crash.
  case lookup var (vars ctx) of
    Just t  -> return t
    Nothing -> throwError $ "unbound variable " ++ var
inferType (EInt _) = return TInt
inferType (EBool _) = return TBool
inferType (EApp f arg) = do
  funcType <- inferType f
  case funcType of
    TFun t1 t2 -> checkType t1 arg >> return t2
    badType -> throwError $ show f ++ " is used as a function but " ++
                            show badType ++ " is not a function type"
inferType (ENeg e) = checkType TInt e
inferType (EArith _ e1 e2) = checkType TInt e1 >> checkType TInt e2
inferType (EComp _ e1 e2) = do
  -- Both operands must agree on a type; the comparison itself is boolean.
  t <- inferType e1
  _ <- checkType t e2
  return TBool
inferType e@(ELambda arg body) = inferAbstraction e arg body
inferType e@(EClosure _ arg body) = inferAbstraction e arg body
inferType e@(EIf cond e1 e2) = memoType e $ do
  _ <- checkType TBool cond
  -- The first branch fixes the type the second branch has to match.
  t <- inferType e1
  checkType t e2

-- | Shared rule for lambdas and closures: infer the body with the argument
-- pushed onto the local scope, and memoise the resulting function type under
-- the whole abstraction expression.
inferAbstraction :: Exp -> (String, Type) -> Exp -> Typechecker Type
inferAbstraction e arg@(_, argType) body = memoType e $
  local (\ctx -> ctx { vars = arg : vars ctx }) $ do
    bodyType <- inferType body
    return $ TFun argType bodyType

-- | Check that an expression has the expected type.
--
-- Infers the type of the expression and compares it with the expected one.
-- Returns the expected type on success; on a mismatch, raises an error
-- message naming the expression and both types.  The error is raised with
-- 'throwError' so that it is delivered as a 'Left' value by the underlying
-- 'Either String'.
checkType :: Type -> Exp -> Typechecker Type
checkType t e = do
  t' <- inferType e
  unless (t' == t) $
    throwError $ show e ++ " should have type " ++ show t ++ " but has type " ++ show t'
  return t

-- | Disallow recursive references.
--
-- 'True' when no top-level definition can reach itself through global
-- references.  Every definition is used as a starting point, not just
-- @_main@: the type checker infers the type of every definition, so a
-- recursive definition that @_main@ never reaches would otherwise escape
-- this check.
terminates :: Program -> Bool
terminates p = all (null . recursiveLoops p) (Map.keys p)

-- | Enumerate the reference cycles reachable from the given top-level name.
--
-- Each result is a chain of names in call order, starting at the given name
-- and ending at the first name that repeats an earlier element of the chain.
-- An empty result means no cycle is reachable.
--
-- A name with no definition in the program contributes no references
-- instead of aborting the analysis; reporting undefined names is left to
-- type inference ("no such function").
recursiveLoops :: Program -> String -> [[String]]
recursiveLoops p fname = recLoops fname [fname] where
  -- Depth-first walk; @stack@ holds the current chain, most recent first.
  recLoops f stack = concatMap (follow stack) (callees f)
  follow stack g
    | g `elem` stack = [ reverse (g : stack) ]
    | otherwise      = recLoops g (g : stack)
  -- Total lookup: an undefined name simply references nothing.
  callees f = maybe [] calls (Map.lookup f p)
  -- Count any reference to `f' as a call, since in general the analysis would
  -- be undecidable.
  calls :: Exp -> [String]
  calls (EGlobal f) = [f]
  calls e = concat $ gmapQ (mkQ [] calls) e