{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- | A monad that handles the state management needed for LLVM code generation.
--   It supports nested function definitions. Helpers for building LLVM types
--   are also defined here.
module CodeGenMonad (
    Function, Value, Module, BasicBlock, 
    CodeGen(..), runCodeGen,
    newFunction, newNamedFunction, defineFunction, call,
    getFunction, withCurrentBuilder,
    newBasicBlock, defineBasicBlock,
    makeType, constFunType, makeStruct, closureType, opaquePtr, showType, genericPrimitiveType,
    getEnv,
    bitcast
) where
    
import Control.Applicative (Applicative, (<$>))
import Control.Monad (liftM, liftM2)
import Control.Monad.State
import Control.Monad.Trans (liftIO)
import Data.List (intercalate)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromJust)
import Foreign.Marshal.Array (withArray, allocaArray, peekArray, newArray)

import LLVM.Core (Linkage(..))
import qualified LLVM.FFI.Core as FFI

import LLVMUtils (Module)
import qualified LLVMUtils as U
import Preprocessor

-- | Handle to an LLVM basic block.
type BasicBlock = FFI.BasicBlockRef

-- | Handle to an LLVM function. LLVM functions are values, so this is the
--   same underlying reference type as 'Value'.
type Function = FFI.ValueRef

-- | Opaque reference to an LLVM value.
type Value = FFI.ValueRef

-- | Per-function code generation state. One entry exists for every function
--   whose body is currently being emitted.
data CGFState = CGF
    { cg_builder  :: U.Builder                  -- ^ builder emitting into 'cg_function'
    , cg_function :: Function                   -- ^ function whose body is being emitted
    , cg_env      :: Map String (CodeGen Value) -- ^ variables in scope, each as deferred
                                                --   code that produces its value
    }

-- | Global code generation state, shared by all functions in the module.
data CodeGenState = CGState
    { cg_module         :: Module      -- ^ module that all functions are added to
    , cg_next_sym       :: !Int        -- ^ counter used to generate fresh symbol names
    , cg_function_state :: [CGFState]  -- ^ stack of per-function states, innermost
                                       --   first, for nested function definitions
                                       --   (lambdas)
    }

-- | The code generation monad: 'CodeGenState' threaded over 'IO'.
--
--   'Applicative' must be derived alongside 'Monad': since GHC 7.10 (the
--   Applicative-Monad Proposal) every 'Monad' instance requires an
--   'Applicative' superclass instance, so deriving 'Monad' alone no longer
--   compiles.
newtype CodeGen a = CodeGen (StateT CodeGenState IO a)
    deriving (Functor, Applicative, Monad, MonadState CodeGenState, MonadIO)

-- | Run a code generator against a freshly created LLVM module called
--   @moduleName@. Returns that module together with the generator's result.
runCodeGen :: String -> CodeGen a -> IO (Module, a)
runCodeGen moduleName (CodeGen body) = do
  llvmModule <- U.createModule moduleName
  (,) llvmModule <$> evalStateT body (initialState llvmModule)
  where
    -- symbol numbering starts at 1 and no function is being defined yet
    initialState m = CGState { cg_module         = m
                             , cg_next_sym       = 1
                             , cg_function_state = []
                             }

-- | Render an LLVM type as a short human-readable string, for debugging.
--   Mostly taken from the equivalent helper in the LLVM Haskell bindings.
--
--   Struct types are summarised by their field count only. A type kind with
--   no dedicated case below renders as @<unknown type kind>@ instead of
--   aborting with a pattern-match failure, since this is only a diagnostic aid.
showType :: FFI.TypeRef -> IO String
showType p = do
    pk <- FFI.getTypeKind p
    case pk of
        FFI.VoidTypeKind -> return "()"
        FFI.FloatTypeKind -> return "Float"
        FFI.DoubleTypeKind -> return "Double"
        FFI.X86_FP80TypeKind -> return "X86_FP80"
        FFI.FP128TypeKind -> return "FP128"
        FFI.PPC_FP128TypeKind -> return "PPC_FP128"
        FFI.LabelTypeKind -> return "Label"
        FFI.IntegerTypeKind -> do
            w <- FFI.getIntTypeWidth p
            return $ "(IntN " ++ show w ++ ")"
        FFI.FunctionTypeKind -> do
            r <- FFI.getReturnType p
            c <- FFI.countParamTypes p
            let n = fromIntegral c
            params <- allocaArray n $ \args -> do
                FFI.getParamTypes p args
                peekArray n args
            -- parameters first, then the result: "(a -> b -> r)"
            rendered <- mapM showType (params ++ [r])
            return $ "(" ++ intercalate " -> " rendered ++ ")"
        FFI.StructTypeKind ->
            return $ "{ " ++ show (FFI.countStructElementTypes p) ++ " fields }"
        FFI.ArrayTypeKind -> do
            n <- FFI.getArrayLength p
            t <- FFI.getElementType p >>= showType
            return $ "(Array " ++ show n ++ " " ++ t ++ ")"
        FFI.PointerTypeKind -> do
            t <- FFI.getElementType p >>= showType
            return $ "(Ptr " ++ t ++ ")"
        FFI.OpaqueTypeKind -> return "Opaque"
        FFI.VectorTypeKind -> do
            n <- FFI.getVectorSize p
            t <- FFI.getElementType p >>= showType
            return $ "(Vector " ++ show n ++ " " ++ t ++ ")"
        _ -> return "<unknown type kind>"

-- | Type of a "constant function": one that takes no parameters and returns
--   a value of type @result@.
constFunType :: FFI.TypeRef -> FFI.TypeRef
constFunType = flip (U.functionType False) []

-- | Translate a source-language 'Type' into the LLVM type that represents it:
--
--   * 'TGeneric' is an opaque pointer ('opaquePtr');
--   * 'TNat' is a 32-bit integer;
--   * a function type @TFun a b@ is a pointer to the closure struct built by
--     'closureType' from the translated argument and result types.
makeType :: Type -> IO FFI.TypeRef
makeType TGeneric = return opaquePtr
makeType TNat = return (FFI.integerType 32)
makeType (TFun a b) = do
    argType    <- makeType a
    resultType <- makeType b
    closure    <- closureType argType resultType
    return (FFI.pointerType closure 0)

-- | Type of a closure object: a two-field struct. This is the struct type
--   itself, not a pointer to it ('makeType' adds the pointer).
--
--   * field 0: pointer to the code, a function taking an opaque pointer and
--     an @argT@ and returning @resultT@;
--   * field 1: an opaque pointer (presumably the captured environment, which
--     may point to an empty struct — confirm against the closure builder).
closureType :: FFI.TypeRef -> FFI.TypeRef -> IO FFI.TypeRef
closureType argT resultT = makeStruct [codePtr, envPtr]
  where
    codeType = U.functionType False resultT [opaquePtr, argT]
    codePtr  = FFI.pointerType codeType 0
    envPtr   = opaquePtr

-- | Type of an untyped pointer, the LLVM counterpart of C's @void*@: a
--   pointer to an 8-bit integer (@i8*@) in address space 0.
opaquePtr :: FFI.TypeRef
opaquePtr = FFI.pointerType byteType 0
  where
    byteType = FFI.integerType 8

-- | Type of a (non-packed) struct whose fields have the given types, in order.
--
--   The field types are marshalled into a temporary C array that is freed
--   as soon as the struct type has been created. LLVM copies the element
--   types into its own storage, so the array need not outlive the call;
--   allocating it with 'newArray' would leak it on every call.
makeStruct :: [FFI.TypeRef] -> IO FFI.TypeRef
makeStruct types =
  withArray types $ \typeArray ->
    -- 'FFI.structType' is a pure binding, so force it with '$!': the
    -- foreign call must run while 'typeArray' is still allocated.
    return $! FFI.structType typeArray (fromIntegral (length types)) notPacked
  where
    notPacked = fromIntegral (fromEnum False)
                           
-- | Type of a primitive over generic values: a function taking @arity@
--   opaque pointers and returning an opaque pointer.
genericPrimitiveType :: Int -> FFI.TypeRef
genericPrimitiveType arity = U.functionType False opaquePtr params
  where
    params = [ opaquePtr | _ <- [1 .. arity] ]

-- | Declare a function of the given type under a fresh, counter-based name
--   of the form @_lambdaN@.
newFunction :: FFI.TypeRef -> CodeGen Function
newFunction typ = do
    freshName <- genSym "lambda"
    newNamedFunction typ freshName

-- | Declare a function with the given type and name, with external linkage,
--   in the current module. Only the declaration is created; 'defineFunction'
--   emits its body.
newNamedFunction :: FFI.TypeRef -> String -> CodeGen Function
newNamedFunction typ name =
    getModule >>= \m -> liftIO (U.addFunction m ExternalLinkage name typ)

-- | Emit the body of a previously declared function @fn@.
--
--   The steps are:
--
--   1. Push a fresh function state (new builder, @fn@, @env@) as the
--      innermost entry, so nested calls (e.g. for lambdas) work.
--   2. Create an entry block and position the builder there.
--   3. Run @body@.
--   4. Bitcast its result to 'opaquePtr' and return it with @ret@.
--   5. Pop the function state.
--
--   @body@ and the entries of @env@ are 'CodeGen' actions, not values: they
--   are only run once the new function's state is current (environment
--   entries are run on lookup).
defineFunction :: Function -> CodeGen Value -> Map String (CodeGen Value) -> CodeGen ()
defineFunction fn body env = do
    bld <- liftIO U.createBuilder
    -- push new function state
    modify $ \st -> st { cg_function_state = CGF { cg_builder  = bld
                                                 , cg_function = fn
                                                 , cg_env      = env
                                                 } : cg_function_state st }
    defineBasicBlock =<< newBasicBlock
    result  <- body
    result' <- bitcast result opaquePtr
    -- the ret instruction itself is not needed afterwards
    _ <- withCurrentBuilder $ \b -> FFI.buildRet b result'
    -- pop function state ('drop 1' is the total form of 'tail')
    modify $ \st -> st { cg_function_state = drop 1 (cg_function_state st) }

-- | Emit a call to @fn@ with the given arguments at the current builder
--   position and return the resulting value.
call :: Function -> [Value] -> CodeGen Value
call fn args = withCurrentBuilder emitCall
  where
    emitCall builder = U.makeCall fn builder args
  
-- | Append a new basic block, named @_LN@ from the symbol counter, to the
--   function currently being defined. The builder is not moved; use
--   'defineBasicBlock' for that.
newBasicBlock :: CodeGen BasicBlock
newBasicBlock =
  genSym "L" >>= \blockName ->
  getFunction >>= \owner ->
  liftIO (U.appendBasicBlock owner blockName)

-- | Position the current function's builder at the end of @l@, so that
--   subsequently emitted instructions are appended to that block.
defineBasicBlock :: BasicBlock -> CodeGen ()
defineBasicBlock l = getBuilder >>= \builder -> liftIO (U.positionAtEnd builder l)

{- Accessors for monad state -}

-- | The function currently being defined (the innermost one when
--   definitions are nested).
getFunction :: CodeGen Function
getFunction =
    gets $ cg_function . innermostFunctionState "getFunction" . cg_function_state

-- | The builder of the function currently being defined.
getBuilder :: CodeGen U.Builder
getBuilder =
    gets $ cg_builder . innermostFunctionState "getBuilder" . cg_function_state

-- | Top of the function-state stack. Outside of any function definition
--   this fails with a message naming the accessor, instead of the
--   uninformative @Prelude.head: empty list@. The error is only raised when
--   the result is actually demanded.
innermostFunctionState :: String -> [CGFState] -> CGFState
innermostFunctionState _ (top:_) = top
innermostFunctionState caller [] =
    error $ "CodeGenMonad." ++ caller ++ ": no function is currently being defined"

-- | Run an IO action with the raw LLVM builder of the function currently
--   being defined, lifting the result into 'CodeGen'. Goes through
--   'getBuilder' rather than duplicating its state accessor.
withCurrentBuilder :: (FFI.BuilderRef -> IO a) -> CodeGen a
withCurrentBuilder body = do
  bld <- getBuilder
  liftIO $ U.withBuilder bld body

-- | The LLVM module code is being generated into.
getModule :: CodeGen Module
getModule = do
    st <- get
    return (cg_module st)

-- | Produce a fresh symbol @_<prefix><n>@. The counter is shared by every
--   prefix, so generated names are unique module-wide.
genSym :: String -> CodeGen String
genSym prefix = do
    counter <- gets cg_next_sym
    modify $ \st -> st { cg_next_sym = counter + 1 }
    return ("_" ++ prefix ++ show counter)

-- | Generate the value of variable @var@ by running its deferred code from
--   the innermost function's environment. The code runs on every lookup.
--
--   Throws an 'IOError' (a 'userError') if @var@ is unknown or no function
--   is being defined. The error is raised with 'ioError' rather than 'fail'
--   because 'CodeGen' has no 'MonadFail' instance, which GHC >= 8.8 requires
--   for 'fail'. The exception is the same one 'IO''s 'fail' produced.
getEnv :: String -> CodeGen Value
getEnv var = do
    stack <- gets cg_function_state
    env <- case stack of
      (top:_) -> return (cg_env top)
      []      -> internalError $ "getEnv: no function is being defined (looking up "
                                 ++ var ++ ")"
    case Map.lookup var env of
      Just code -> code
      Nothing   -> internalError $ "internal error: unknown var " ++ var
                                   ++ ". Environment contains " ++ show (Map.keys env)
  where
    internalError :: String -> CodeGen a
    internalError msg = liftIO (ioError (userError msg))

-- | Emit a @bitcast@ of @val@ to @targetType@ at the current builder
--   position, passing an empty name for the resulting value.
bitcast :: Value -> FFI.TypeRef -> CodeGen Value
bitcast val targetType = withCurrentBuilder emitCast
  where
    emitCast builder = U.withEmptyCString (FFI.buildBitCast builder val targetType)