-- C parser

module Parser where

import Text.ParserCombinators.Parsec
import Text.ParserCombinators.Parsec.Expr
import Text.ParserCombinators.Parsec.Token
import Text.ParserCombinators.Parsec.Language

import AST

-- Parser for a complete expression.  It is built from the whole operator table, so
-- the comma operator (the table's last row) is accepted at the top level.
expr :: Parser ExprAt
expr = withCommaOperator <?> "expression"
   where withCommaOperator = buildExpressionParser table factor

-- Expression parser that leaves the comma operator out of the operator table (it drops
-- the table's last row), so a top-level comma is left for the caller.  This is needed
-- wherever the comma is a separator rather than an operator, such as in function call
-- arguments or in variable initialisation.  A brace-enclosed, comma-separated list is
-- accepted as well and becomes an ArrayInit.
--
-- The comma operator is still available inside parentheses, because factor parses a
-- parenthesised operand with the full expr parser.  There must be no separate
-- `parens lexer expr` alternative here: it would return immediately after the closing
-- parenthesis, so `(a + b) * c` would stop before `* c`, and a cast such as `(int) x`
-- would fail after consuming the opening parenthesis.
noCommaExpr =
   (do{ pos <- getPosition; es <- braces lexer (commaSep lexer noCommaExpr); return (ExprAt pos (ArrayInit es))})
   <|> buildExpressionParser (init table) factor <?> "noCommaExpr"

-- Operator table for buildExpressionParser, listed from the tightest-binding row
-- (postfix operators) to the loosest (the comma operator).  The comma operator has to
-- stay in the last row because noCommaExpr builds its parser from `init table`.
--
-- buildExpressionParser applies at most one prefix and one postfix operator per term
-- at each level, so the postfix row and the prefix row each hold ONE operator that
-- parses a whole run of operators and composes them.  This is what lets a[i][j],
-- p->x.y, f(x)[0] and (int)-x parse.
--
-- NOTE(review): typecast accepts any parenthesised identifier as a type name (typename'
-- falls back to CustomType), so `(a) - b` is read as a cast applied to `-b`.  Telling
-- the two apart needs knowledge of typedef names, which this parser does not track.
table = [[Postfix (do{ fs <- many1 (choice postfixOps); return (foldr (flip (.)) id fs)})]
        ,[Prefix (do{ fs <- many1 (choice prefixOps); return (foldr (.) id fs)})]
        ,[op "*" (Binary Times) AssocLeft, op "/" (Binary Divide) AssocLeft, op "%" (Binary Mod) AssocLeft]
        ,[op "+" (Binary Plus) AssocLeft, op "-" (Binary Minus) AssocLeft]
        ,[op ">>" (Binary ShiftRight) AssocLeft, op "<<" (Binary ShiftLeft) AssocLeft]
        ,[op "<" (Binary LessThan) AssocLeft, op ">" (Binary GreaterThan) AssocLeft,
          op "<=" (Binary LessThanOrEq) AssocLeft, op ">=" (Binary GreaterThanOrEq) AssocLeft]
        ,[op "==" (Binary EqualTo) AssocLeft, op "!=" (Binary NotEqualTo) AssocLeft]
        ,[op "&" (Binary BitwiseAnd) AssocLeft]
        ,[op "^" (Binary BitwiseXOr) AssocLeft]
        ,[op "|" (Binary BitwiseOr) AssocLeft]
        ,[op "&&" (Binary LogicalAnd) AssocLeft]
        ,[op "||" (Binary LogicalOr) AssocLeft]
        ,[conditional]
        ,[op "=" (Assignment Assign) AssocRight, op "+=" (Assignment IncreaseBy) AssocRight,
          op "-=" (Assignment DecreaseBy) AssocRight, op "*=" (Assignment MultiplyBy) AssocRight,
          op "/=" (Assignment DivideBy) AssocRight, op "%=" (Assignment ModBy) AssocRight,
          op ">>=" (Assignment ShiftRightBy) AssocRight, op "<<=" (Assignment ShiftLeftBy) AssocRight,
          op "&=" (Assignment BitwiseAndBy) AssocRight, op "|=" (Assignment BitwiseOrBy) AssocRight]
        ,[op "," Comma AssocLeft]
        ]
   where
         -- Postfix operators in the order they are tried.  In a run, the operator written
         -- first (nearest the operand) is applied first.
         postfixOps = [functionCall, arraySubscript, select "." SelectByReference, select "->" SelectThroughPointer
                      ,postOp "++" PostIncrement, postOp "--" PostDecrement]
         -- Prefix operators in the order they are tried.  In a run, the operator written
         -- first is the outermost one.
         prefixOps = [preOp "*" (Unary Dereference), preOp "&" (Unary AddressOf), preOp "+" (Unary UnaryPlus)
                     ,preOp "-" (Unary UnaryMinus), preOp "!" (Unary LogicalNot), preOp "~" (Unary BitwiseComplement)
                     ,preOp "++" (Unary PreIncrement), preOp "--" (Unary PreDecrement), typecast, sizeofOp]
         -- Operator token.  reservedOp refuses an operator that is directly followed by ANY
         -- operator character, which rejects ordinary C such as `x=-1`, `a&&!b` or `p=&x`.
         -- cOp only refuses the match when the next character would extend it to a longer
         -- C operator (so "<" still does not match the start of "<=" or "<<").
         cOp s = lexeme lexer (try (do{ _ <- string s; notFollowedBy (oneOf (continuations s))}))
         -- Characters that turn operator s into an operator one character longer.
         continuations s = [c | o <- multiCharOps, (pre, [c]) <- [splitAt (length s) o], pre == s]
         multiCharOps = ["->","++","--","<<",">>","<=",">=","==","!=","&&","||"
                        ,"+=","-=","*=","/=","%=","&=","|=","^=","<<=",">>="]
         op s f assoc = Infix (do{ pos <- getPosition; cOp s; return (\x y -> ExprAt pos (f x y))} <?> "operator") assoc
         preOp s f = do{ pos <- getPosition; cOp s; return ((ExprAt pos).f)} <?> "prefix operator"
         -- sizeof is a keyword, not punctuation: `reserved` makes sure it is not followed by
         -- an identifier character, so a name that merely starts with "sizeof" is not split.
         sizeofOp = do{ pos <- getPosition; reserved lexer "sizeof"; return ((ExprAt pos).(Unary Sizeof))} <?> "prefix operator"
         postOp s f = do{ pos <- getPosition; cOp s; return ((ExprAt pos).f)} <?> "postfix operator"
         -- `(type stars)`; wrapped in try so that a parenthesised expression can still be
         -- parsed by factor when the contents are not a type.
         typecast = try (do{ pos <- getPosition; (t,stars) <- parens lexer (do{ t<-typename; stars<-many (symbol lexer "*"); return (t,stars)}); return ((ExprAt pos).(Unary (Cast t (length stars))))} <?> "typecast")
         -- Arguments use noCommaExpr because the comma separates them.
         functionCall = do{ pos <- getPosition; args <- parens lexer (commaSep lexer noCommaExpr);
                            return ((ExprAt pos).(`FunctionCall` args))} <?> "functionCall"
         arraySubscript = do{ pos <- getPosition; e <- squares lexer expr; return ((ExprAt pos).(`ArraySubscript` e))} <?> "arraySubscript"
         select s f = do{ pos <- getPosition; cOp s; field <- identifier lexer;
                          return ((ExprAt pos).((flip f) field))} <?> "select"
         -- `c ? a : b` is modelled as a right-associative infix operator whose operator
         -- token is the whole `? a :` part.
         conditional = Infix (do{ pos <- getPosition; cOp "?"; a <- expr; cOp ":";
                                  return (\c b -> ExprAt pos (Conditional c a b))} <?> "conditional") AssocRight

-- One of the qualifier keywords that may precede a type name:
-- const, volatile, signed or unsigned.
typeQualifier =
   choice [try (reserved lexer kw >> return q) | (kw, q) <- qualifiers]
   <?> "typeQualifier"
   where qualifiers = [("const", Const), ("volatile", Volatile)
                      ,("signed", Signed), ("unsigned", Unsigned)]

-- A type: zero or more qualifiers followed by a base type name.
typename =
   many typeQualifier >>= \quals ->
   fmap (Type quals) typename'

-- Base type name: one of the built-in type keywords, otherwise any identifier,
-- which is returned as a CustomType.
typename' =
   choice [try (reserved lexer kw >> return ty) | (kw, ty) <- builtins]
   <|> try (fmap CustomType (identifier lexer))
   <?> "typename"
   where builtins = [("char", CharType), ("short", ShortType), ("long", LongType)
                    ,("int", IntType), ("float", FloatType), ("double", DoubleType)]

-- Operand of the expression parsers: a parenthesised expression, a string, character,
-- floating point or integer literal, or an identifier.  The float literal is tried
-- before the integer literal.  Each result carries the position where it started.
factor =
   parens lexer expr
   <|> try (located (ConstExpr . StringConst) (stringLiteral lexer))
   <|> try (located (ConstExpr . CharConst) (charLiteral lexer))
   <|> try (located (ConstExpr . FloatConst . doubleToFloat) (float lexer))
   <|> try (located (ConstExpr . ShortConst . fromInteger) (integer lexer))
   <|> try (located Identifier (identifier lexer))
   <?> "simple expression"
   where -- Run p, remembering the position it started at, and wrap its result.
         located build p = do{ pos <- getPosition; v <- p; return (ExprAt pos (build v))}

-- Convert between real number types by going through Rational.
doubleToFloat x = fromRational (toRational x)

------------

-- A single statement.  Alternatives are tried in this order: braced block, declaration,
-- expression statement, if, while, do-while, for, return, and finally the empty
-- statement `;`.  All but the first and last are wrapped in try, so a failed
-- alternative does not stop the later ones from being attempted.
statement =
   fmap Block (braces lexer (many statement))
   <|> try (fmap DeclStatement declaration)
   <|> try (expr >>= \e -> semi lexer >> return (ExpStatement e))
   <|> try ifStatement
   <|> try whileStatement
   <|> try doWhileStatement
   <|> try forStatement
   <|> try returnStatement
   <|> (semi lexer >> return EmptyStatement)
   <?> "statement"

-- Variable declaration: optional modifiers, a type, one or more comma-separated
-- declarators and a terminating semicolon.
--
-- At least one declarator is required (commaSep1).  typename' accepts any identifier as
-- a type, so with zero declarators allowed a bare `name;` would be read as a
-- declaration of nothing instead of being left for the other alternatives.
declaration =
   do { mods <- many declModifier
      ; tn <- typename
      ; terms <- commaSep1 lexer declarationTerm
      ; semi lexer
      ; return (Declaration mods tn terms)
      }
   -- The label is attached to the whole parser; on the final `return` alone it has no effect.
   <?> "declaration"

-- One declarator inside a declaration: leading `*`s, the name, an optional empty `[]`,
-- any number of `[size]` suffixes and an optional `= initialiser`.
--
-- The first field of the resulting DeclTerm is the number of `*`s, plus one when the
-- empty `[]` is present.  The initialiser is parsed with noCommaExpr because a comma
-- separates declarators here.
declarationTerm =
   do { stars <- many (symbol lexer "*")
      ; pos <- getPosition
      ; name <- identifier lexer
        -- try: a `[` that is not immediately closed belongs to a sized suffix below
      ; arr <- option False (try (do{ symbol lexer "["; symbol lexer "]"; return True}))
      ; arraySizesList <- many (squares lexer expr)
        -- symbol rather than reservedOp: reservedOp refuses `=` when an operator
        -- character follows it directly, which rejects `int x=-1;` and `int *p=&x;`
      ; initialiser <- option Nothing (do{ symbol lexer "="; e <- noCommaExpr; return (Just e)})
      ; return (DeclTerm (length stars + (if arr then 1 else 0)) name arraySizesList initialiser pos)
      }
   -- The label is attached to the whole parser; on the final `return` alone it has no effect.
   <?> "declarationTerm"

-- Modifier keyword in front of a declaration: static or extern.
declModifier =
   do{ reserved lexer "static"; return Static}
   <|> try (do{ reserved lexer "extern"; return Extern})
   <?> "declModifier"

-- if statement: `if (cond) body`, optionally followed by `else body`.
-- The else branch is Nothing when there is no `else`.
ifStatement =
   do { reserved lexer "if"
      ; cond <- parens lexer expr
      ; thenBody <- codeBlock
      ; elseBody <- option Nothing (do{ reserved lexer "else"; x <- codeBlock; return (Just x)})
      ; return (IfStatement cond thenBody elseBody)
      }
   -- The label is attached to the whole parser; on the final `return` alone it has no effect.
   <?> "ifStatement"

-- while loop: `while (cond) body`.
whileStatement =
   do { reserved lexer "while"
      ; cond <- parens lexer expr
      ; body <- codeBlock
      ; return (While cond body)
      }
   -- The label is attached to the whole parser; on the final `return` alone it has no effect.
   <?> "whileStatement"

-- do-while loop: `do body while (cond);`.
doWhileStatement =
   do { reserved lexer "do"
      ; body <- codeBlock
      ; reserved lexer "while"
      ; cond <- parens lexer expr
      ; semi lexer
      ; return (DoWhile body cond)
      }
   -- The label is attached to the whole parser; on the final `return` alone it has no effect.
   <?> "doWhileStatement"

-- for loop: `for (start; cond; step) body`.
-- All three clauses are parsed with expr, so each one is required: an empty clause
-- (as in `for (;;)`) or a declaration in the first clause is not accepted.
forStatement =
   do { reserved lexer "for"
      ; (start,cond,step) <- parens lexer (do{ a <- expr; semi lexer; b <- expr; semi lexer; c <- expr; return (a,b,c)})
      ; body <- codeBlock
      ; return (For start cond step body)
      }
   -- The label is attached to the whole parser; on the final `return` alone it has no effect.
   <?> "forStatement"

-- return statement: `return;` or `return expr;`.  The value is Nothing for a bare return.
returnStatement =
   do { reserved lexer "return"
        -- try: if no expression can be parsed here, fall back to a bare `return;`
      ; val <- option Nothing (fmap Just (try expr))
      ; semi lexer
      ; return (Return val)
      }
   -- The label is attached to the whole parser; on the final `return` alone it has no effect.
   <?> "returnStatement"

-- Body of a control statement, returned as a list of statements.
-- A single statement is tried first and comes back as a one-element list.  Since
-- statement itself accepts a braced block, a `{ ... }` body is normally returned as a
-- one-element list holding that Block; the braces alternative is only the fallback.
codeBlock =
   try (fmap (\s -> [s]) statement)
   <|> braces lexer (many statement)
   <?> "codeBlock"

------------

-- One top-level item.  Alternatives are tried in this order, each under try: variable
-- declaration, function declaration, function definition, preprocessor line, typedef.
global =
   choice (map try [fmap GlobalDecl declaration
                   ,functionDeclaration
                   ,functionDefinition
                   ,preprocessor
                   ,typedef])
   <?> "global"

-- Function prototype: return type, name, parenthesised parameter list and a semicolon.
-- The recorded position is that of the function name.
-- NOTE(review): a parameter list written as `(void)` is not accepted, because parameter
-- requires a typename and typename has no alternative for `void`.
functionDeclaration =
   do { t <- functionReturnType
      ; pos <- getPosition
      ; name <- identifier lexer
      ; params <- parens lexer (commaSep lexer parameter)
      ; semi lexer
      ; return (FunctionDecl t name params pos)
      }
   -- The label is attached to the whole parser; on the final `return` alone it has no effect.
   <?> "functionDeclaration"

-- One function parameter: a type, any number of `*`s and the parameter name.
-- Returns (type, number of stars, name, position of the name).
parameter =
   do { t1 <- typename
      ; stars <- many (symbol lexer "*")
      ; pos <- getPosition
      ; name <- identifier lexer
      --arraySizesList <- many (squares lexer expr)
      ; return (t1,length stars,name,pos)
      }
   -- The label is attached to the whole parser; on the final `return` alone it has no effect.
   <?> "parameter"

-- Function definition: return type, name, parenthesised parameter list and a braced
-- body of statements.  The recorded position is that of the function name.
-- NOTE(review): a parameter list written as `(void)` is not accepted, because parameter
-- requires a typename and typename has no alternative for `void`.
functionDefinition =
   do { t <- functionReturnType
      ; pos <- getPosition
      ; name <- identifier lexer
      ; params <- parens lexer (commaSep lexer parameter)
      ; body <- braces lexer (many statement)
      ; return (FunctionDef t name params body pos)
      }
   -- The label is attached to the whole parser; on the final `return` alone it has no effect.
   <?> "functionDefinition"

-- Return type of a function: Nothing for `void`, otherwise Just the parsed type.
functionReturnType =
   try (do{ reserved lexer "void"; return Nothing})
   <|> try (fmap Just typename)
   <?> "functionReturnType"

-- typedef: `typedef type stars name [size]... ;`.
-- The recorded position is that of the new type's name.
typedef =
   do { reserved lexer "typedef"
      ; t <- typename
      ; stars <- many (symbol lexer "*")
      ; pos <- getPosition
      ; i <- identifier lexer
      ; arraySizesList <- many (squares lexer expr)
      ; semi lexer
      ; return (Typedef t (length stars) i arraySizesList pos)
      }
   -- The label is attached to the whole parser; on the final `return` alone it has no effect.
   <?> "typedef"

-- Preprocessor line: a `#` and everything up to the end of the line, kept verbatim
-- (including the `#`, excluding the newline).  The line is not interpreted.
-- Whitespace after the line is not skipped here.
preprocessor =
   do { pos <- getPosition
      ; char '#'
        -- The line ends at a newline or at end of input, so a directive on the last
        -- line of a file that has no trailing newline is still accepted.
      ; s <- manyTill anyChar ((newline >> return ()) <|> eof)
      ; return (Preprocessor ('#':s) pos)
      }
   -- The label is attached to the whole parser; on the final `return` alone it has no effect.
   <?> "preprocessor"

------------

-- A whole source file: a sequence of top-level items.
--
-- Whitespace AND comments are skipped before the first item and after every item
-- (whiteSpace lexer), so a file may begin with a comment and a comment may follow a
-- preprocessor line.  Skipping after each item rather than before the next also means
-- that trailing blank lines do not make the parser fail after consuming input.
-- End of input is not checked here.
file =
   do { whiteSpace lexer
      ; many (do{ g <- global; whiteSpace lexer; return g})
      }

------------

-- Lexical conventions for the lexer, derived from javaStyle.
--
-- javaStyle is case-insensitive, lets block comments nest, requires identifiers to start
-- with a letter and allows `'` inside them.  C differs on every one of these points, so
-- those fields are overridden: keywords and identifiers are case-sensitive, `/* */`
-- comments do not nest, and identifiers are letters, digits and `_` with a non-digit
-- first character.
--
-- reservedNames lists every keyword the parsers match with `reserved`, so none of them
-- is accepted as an identifier.  reservedOpNames lists the operators used by the
-- expression parser.
cStyle = javaStyle
   { caseSensitive = True
   , nestedComments = False
   , identStart = letter <|> char '_'
   , identLetter = alphaNum <|> char '_'
   , reservedOpNames = [".","->","++","--","*","&","+","-","!","~","sizeof","/","%","<<",">>","<",">","<=",">=","==","!=",
                        "^","|","&&","||","?",":",
                        "=","+=","-=","*=","/=","%=",">>=","<<=","&=","|=",","]
   , reservedNames = ["char","short","long","int","float","double","void","const","volatile","static","extern","if"
                     ,"else","do","while","for","break","continue","return","typedef","unsigned","signed","sizeof"] }

-- Token parser built from cStyle.  The parsers in this module take their token-level
-- pieces (identifier, reserved, reservedOp, symbol, parens, braces, ...) from it.
lexer :: TokenParser ()
lexer  = makeTokenParser cStyle
