-- GRIN-like backend for Yhc Core.
-- Global variable name substitutions.

module Yhc.Core.GRIN.SubstVars where

import Yhc.Core.GRIN.Type

import Data.List
import qualified Data.Set as S
import qualified Data.Map as M

-- Naming conventions: regular nodes start with 'n', shared nodes start with 's',
-- data constructor tags start with 'C@', thunk tags start with 'F@', partial
-- application tags start with 'P@'.

-- Build a data constructor tag name by prepending the "C@" prefix
-- to the given name.
mkTagName :: String -> String
mkTagName s = "C@" ++ s

-- True when the name starts with 'n' (the prefix of regular nodes);
-- False for every other name, including the empty one.
isRegNode :: String -> Bool
isRegNode ('n':_) = True
isRegNode _ = False

-- True when the name starts with 's' (the prefix of shared nodes);
-- False for every other name, including the empty one.
isShrNode :: String -> Bool
isShrNode ('s':_) = True
isShrNode _ = False

-- True when the name denotes a node of either kind, i.e. it is accepted
-- by isRegNode or by isShrNode.
isNode :: String -> Bool
isNode x = isRegNode x || isShrNode x

-- True when the name starts with "C@" (the prefix of data constructor tags).
isCtorTag :: String -> Bool
isCtorTag ('C':'@':_) = True
isCtorTag _ = False

-- True when the name starts with "F@" (the prefix of thunk tags).
isThunkTag :: String -> Bool
isThunkTag ('F':'@':_) = True
isThunkTag _ = False

-- True when the name starts with "P@" (the prefix of partial application tags).
isPapTag :: String -> Bool
isPapTag ('P':'@':_) = True
isPapTag _ = False


-- Operations over variable names

-- Wrap a variable name into a value: the name goes into GVar, which is
-- then wrapped in GVal.
gVar name = GVal (GVar name)

-- Extract the variable name from a value of the form GVal (GVar v).
-- Any other value yields the placeholder name "#".
unGVar gv = case gv of
  GVal (GVar name) -> name
  _                -> "#"

-- Extract the variable name from a simple value built with GVar.
-- Any other simple value yields the placeholder name "%".
unVar sv = case sv of
  GVar name -> name
  _         -> "%"

-- Unwrap the simple value held by a GVal.
-- Calls error (showing the offending value) for anything that is not a GVal.
val2sval gv = case gv of
  GVal sv -> sv
  other   -> error ("not a sval: " ++ show other)


-- Substitute names throughout a GRIN program. The map sends an old name to
-- its replacement; names without an entry are left alone (see substGN).
-- Every function listed in gFuncs is rewritten with substFunc; the other
-- fields of the program are kept as they are.
substGRIN :: M.Map GName GName -> GRIN -> GRIN
substGRIN sm prog@(GRIN {gFuncs = funcs}) = prog {gFuncs = renamed}
  where
    renamed = [substFunc sm fn | fn <- funcs]

-- Apply the substitution to one function: each argument name goes through
-- substGN and each expression of the body through substGX. Other fields of
-- the function record are untouched.
substFunc :: M.Map GName GName -> GFunc -> GFunc
substFunc sm fn@(GFunc {gFuncArgs = args, gFuncBody = body}) =
  fn {gFuncArgs = args', gFuncBody = body'}
  where
    args' = map (substGN sm) args
    body' = map (substGX sm) body

-- Apply the substitution to one expression.
--   * GBind:   both the simple expression and the bound value are rewritten.
--   * GSimple: the simple expression is rewritten.
--   * GCase:   the scrutinee, every branch pattern and every branch block
--              are rewritten.
-- Any other expression is returned unchanged.
substGX :: M.Map GName GName -> GExpr -> GExpr
substGX sm expr = case expr of
  GBind sx bv   -> GBind (substSX sm sx) (substGV sm bv)
  GSimple sx    -> GSimple (substSX sm sx)
  GCase cv alts -> GCase (substGV sm cv) (map substAlt alts)
  _             -> expr
  where
    -- fst/snd (rather than a tuple pattern) keep the branch pair lazy.
    substAlt alt = (substPT sm (fst alt), map (substGX sm) (snd alt))

-- Apply the substitution to a case pattern. Only the name carried by a
-- GPatTag pattern is rewritten; every other pattern is returned as is.
substPT :: M.Map GName GName -> GPat -> GPat
substPT sm pat = case pat of
  GPatTag tag -> GPatTag (substGN sm tag)
  _           -> pat

-- Apply the substitution to a simple expression: every name it carries goes
-- through substGN, every value through substGV, and the block of a GInline
-- through substGX. The optional index of GFetch is passed along unchanged.
substSX :: M.Map GName GName -> GSexpr -> GSexpr
substSX sm sx = case sx of
  GUnit v          -> GUnit (value v)
  GCall fn args    -> GCall (name fn) (map value args)
  GStore v         -> GStore (value v)
  GEval n          -> GEval (name n)
  GApply fn arg    -> GApply (name fn) (value arg)
  GFetch n tag mbi -> GFetch (name n) (name tag) mbi
  GUpdate n v      -> GUpdate (name n) (value v)
  GInline blk      -> GInline (map (substGX sm) blk)
  where
    name  = substGN sm
    value = substGV sm

-- Apply the substitution to a value: the tag name and the fields of a
-- GTagged, the name of a GTag, and the simple value inside a GVal are
-- rewritten. Any other value is returned unchanged.
substGV :: M.Map GName GName -> GVal -> GVal
substGV sm val = case val of
  GTagged tag fields -> GTagged (substGN sm tag) [substSV sm fld | fld <- fields]
  GTag tag           -> GTag (substGN sm tag)
  GVal sv            -> GVal (substSV sm sv)
  _                  -> val

-- Apply the substitution to a simple value: the name inside a GVar is
-- rewritten, anything else is returned unchanged.
substSV :: M.Map GName GName -> GSval -> GSval
substSV sm sval = case sval of
  GVar name -> GVar (substGN sm name)
  _         -> sval

-- Look a name up in the substitution map: return its replacement when the
-- map has an entry for it, otherwise return the name itself.
substGN :: M.Map GName GName -> GName -> GName
substGN sm gn = case M.lookup gn sm of
  Just replacement -> replacement
  Nothing          -> gn

-- Collect the names referenced by a simple expression.
--
-- Values that are not plain variables are still passed through unGVar
-- (for GCall, GApply, GUpdate arguments) or unVar (for GTagged fields), so
-- the placeholder names those helpers return can show up in the result.
-- NOTE(review): for GUnit the tag name of a GTagged value is not reported
-- and a bare GTag yields the empty set, whereas GStore reports the tag in
-- both cases -- confirm that this asymmetry is intended.
useVars :: GSexpr -> S.Set GName
useVars sx = case sx of
  GUnit (GVal (GVar v))       -> S.singleton v
  GUnit (GTagged _ fields)    -> S.fromList (map unVar fields)
  GUnit _                     -> S.empty
  GCall fn args               -> S.fromList (fn : map unGVar args)
  GStore (GVal (GVar v))      -> S.singleton v
  GStore (GTagged tag fields) -> S.fromList (tag : map unVar fields)
  GStore (GTag tag)           -> S.singleton tag
  GStore _                    -> S.empty
  GEval n                     -> S.singleton n
  GApply fn arg               -> S.fromList [fn, unGVar arg]
  GFetch n _ _                -> S.singleton n
  GUpdate n v                 -> S.fromList [n, unGVar v]
  GInline blk                 -> useVarBlk S.empty blk

-- Collect the names referenced anywhere in a block, starting from the
-- empty set.
varUseBlock :: GBlock -> S.Set GName
varUseBlock = useVarBlk S.empty

-- Worker for varUseBlock: the result is the set given as first argument
-- extended with the names found in every expression of the block.
--   * GCase:   the scrutinee (via unGVar) plus the names of every branch
--              block; branch patterns are not inspected.
--   * GSimple: the names reported by useVars.
--   * GBind:   the names reported by useVars plus the name in the bound
--              value (via unGVar).
-- NOTE(review): only these three forms are matched; any other GExpr
-- constructor would be a pattern-match failure -- confirm none exists.
useVarBlk :: S.Set GName -> GBlock -> S.Set GName
useVarBlk seen blk = S.unions (seen : map exprVars blk)
  where
    exprVars (GCase cv alts) =
      S.unions (S.singleton (unGVar cv) : map (varUseBlock . snd) alts)
    exprVars (GSimple sx) = useVars sx
    exprVars (GBind sx bv) = S.insert (unGVar bv) (useVars sx)

-- Rewrite every expression of a block with the given function, descending
-- into case branches and into inlined blocks (a GInline directly under
-- GBind or GSimple). For such an expression the nested blocks are mapped
-- and the function then receives the expression rebuilt around them; any
-- other expression is handed to the function as it is. The result of the
-- function replaces the original expression in the block.
mapOverBlock :: (GExpr -> GExpr) -> GBlock -> GBlock
mapOverBlock f blk = map rewrite blk
  where
    rewrite expr = case expr of
      GCase scrut alts         -> f (GCase scrut (map inBranch alts))
      GBind (GInline inner) bv -> f (GBind (GInline (mapOverBlock f inner)) bv)
      GSimple (GInline inner)  -> f (GSimple (GInline (mapOverBlock f inner)))
      _                        -> f expr
    -- The pattern of a branch is kept; only its block is mapped.
    inBranch alt = (fst alt, mapOverBlock f (snd alt))

