--  HasCas -- A symbolic manipulation package in Haskell.

--  Datatype and basic operations for symbolic expressions as well as
--  pretty printing routines.

module HasCas.Expression (SymbolicName,
                          ExactNumber, InexactNumber,
                          BinaryOperator(..),
                          SymbolicExpression(..)) where

import qualified Data.Ratio


--  Expression definitions.
--  ----------------------------------------------------------------------------


--  A symbolic name representing a variable or function.  Names introduced by
--  the numeric instances in this module (e.g. "%pi", "%sin") start with '%'.

type SymbolicName = String


--  Data types used for "exact" numbers (integers) and "inexact" floating
--  point numbers.  ExactNumber is Haskell's unbounded Integer, so exact
--  values do not overflow.

type ExactNumber = Integer
type InexactNumber = Double


--  A symbolic expression consisting of sub-parts put together via operations.

--  The operators that can join two sub-expressions in a 'Binary' node.
data BinaryOperator
  = Add
  | Subtract
  | Multiply
  | Divide
  | Power
  deriving (Eq)

--  Variable: a named symbol; Exact: an integer; Inexact: a floating point
--  number; Binary: an operator applied to two sub-expressions; Apply: a named
--  function applied to a list of argument expressions.
data SymbolicExpression
  = Variable SymbolicName
  | Exact ExactNumber
  | Inexact InexactNumber
  | Binary BinaryOperator SymbolicExpression SymbolicExpression
  | Apply SymbolicName [SymbolicExpression]
  deriving (Eq)


--  Instance declarations for SymbolicExpressions as numbers.

instance Show SymbolicExpression where
  -- Showing an expression yields its multi-line pretty-printed rendering.
  show expr = prettyPrint expr

instance Num SymbolicExpression where
  -- Arithmetic builds expression trees instead of computing values.
  -- negate is left to the class default (0 - x), i.e. a Subtract node.
  x + y = Binary Add x y
  x - y = Binary Subtract x y
  x * y = Binary Multiply x y
  abs = Apply "%abs" . pure
  signum = Apply "%sign" . pure
  fromInteger n = Exact n

instance Fractional SymbolicExpression where
  -- Rational literals are kept exact as a quotient of two integers, so a
  -- literal such as 0.5 becomes 1 / 2.
  num / den = Binary Divide num den
  fromRational q = Binary Divide (Exact top) (Exact bottom)
    where
      top = Data.Ratio.numerator q
      bottom = Data.Ratio.denominator q

instance Floating SymbolicExpression where
  -- Each elementary function becomes a one-argument Apply node with a
  -- '%'-prefixed name.  Methods not listed (sqrt, logBase, ...) use the
  -- class defaults, which are expressed through the methods below.
  base ** power = Binary Power base power
  pi = Variable "%pi"
  exp = Apply "%exp" . pure
  log = Apply "%log" . pure
  sin = Apply "%sin" . pure
  cos = Apply "%cos" . pure
  tan = Apply "%tan" . pure
  asin = Apply "%arcsin" . pure
  acos = Apply "%arccos" . pure
  atan = Apply "%arctan" . pure
  sinh = Apply "%sinh" . pure
  cosh = Apply "%cosh" . pure
  tanh = Apply "%tanh" . pure
  asinh = Apply "%arsinh" . pure
  acosh = Apply "%arcosh" . pure
  atanh = Apply "%artanh" . pure


--  Pretty printing expressions.
--  ----------------------------------------------------------------------------


--  Repeat a character a given number of times.  A count of zero or less
--  yields the empty string (the previous hand-rolled countdown never
--  terminated for negative counts).

repeatChar :: Char -> Int -> String
repeatChar c cnt = replicate cnt c


--  Handling of "blocks" consisting of multiple lines of text, used to
--  represent superscripts or fractions.
--
--  A block's "base row" is numbered 0; -RowsFrom and RowsTo specify the number
--  of rows above / below it, respectively.

--  TextBlock RowsFrom RowsTo Columns [Rows]
--
--  Row indices are relative to the base row 0.  blockConcat pads with blank
--  rows that are Columns characters wide, so every row is assumed to have
--  exactly that width.
data TextBlock
  = TextBlock
      Int       -- RowsFrom: index of the top row
      Int       -- RowsTo: index of the bottom row
      Int       -- Columns: width of every row
      [String]  -- Rows, listed from top to bottom

instance Show TextBlock where
  -- Every row is emitted followed by a newline; the row bounds and the
  -- width are not part of the output.
  show (TextBlock _ _ _ rows) = unlines rows

--  A run of the given number of spaces.

blanks :: Int -> String
blanks count = repeatChar ' ' count

--  A block made of a single row (row 0) holding the given text.

lineBlock :: String -> TextBlock
lineBlock text = TextBlock 0 0 width [text]
  where
    width = length text


--  Put blocks together horizontally.  Each block is first padded with blank
--  rows (of its own width) above and below so that both span the same row
--  range; the corresponding rows are then joined left to right.

blockConcat :: TextBlock -> TextBlock -> TextBlock
blockConcat l@(TextBlock fromL toL _ _) r@(TextBlock fromR toR _ _) =
  sideBySide (stretch l) (stretch r)
  where
    firstRow = min fromL fromR
    lastRow  = max toL toR

    -- Extend a block with blank rows up to firstRow and down to lastRow.
    stretch :: TextBlock -> TextBlock
    stretch (TextBlock top bottom width rows) =
      let filler n = replicate n (blanks width)
      in TextBlock firstRow lastRow width
                   (filler (top - firstRow) ++ rows ++ filler (lastRow - bottom))

    -- Join two blocks that already cover exactly the same rows.
    sideBySide :: TextBlock -> TextBlock -> TextBlock
    sideBySide (TextBlock top bottom w1 rows1) (TextBlock _ _ w2 rows2) =
      TextBlock top bottom (w1 + w2) (zipWith (++) rows1 rows2)


--  Put blocks together vertically, centering lines with blanks as needed.

--  Pad a line whose length is len with spaces on both sides up to targetlen.
--  When the padding is odd, the extra space goes on the right.  Lines longer
--  than the target are rejected.

centerLine :: Int -> String -> Int -> String
centerLine len str targetlen =
  case compare len targetlen of
    EQ -> str
    LT -> blanks padLeft ++ str ++ blanks padRight
    GT -> error "centerLine can't make lines shorter..."
  where
    padding  = targetlen - len
    padLeft  = padding `div` 2
    padRight = padding - padLeft

--  Stack the first block above the second, centring every row within the
--  wider of the two widths.  The result's top row is numbered 0; callers that
--  need a different base row shift the result afterwards.

blockVConcat :: TextBlock -> TextBlock -> TextBlock
blockVConcat (TextBlock fromA toA widthA rowsA)
             (TextBlock fromB toB widthB rowsB) =
  TextBlock 0 (heightA + heightB - 1) width
            (centred widthA rowsA ++ centred widthB rowsB)
  where
    width = max widthA widthB
    heightA = toA - fromA + 1
    heightB = toB - fromB + 1
    centred partWidth = map (\row -> centerLine partWidth row width)


--  Put parentheses around a text-block, adding two columns.  A one-row block
--  gets plain "(" and ")".  Taller blocks get parentheses drawn with '/' and
--  '\' on the first and last rows and '|' in between.  A block without rows
--  keeps no rows; the old helper had no case for it and crashed with a
--  non-exhaustive pattern.

parenthesize :: TextBlock -> TextBlock
parenthesize (TextBlock f t c rows) = TextBlock f t (c + 2) (wrap rows)
  where
    wrap :: [String] -> [String]
    wrap []           = []
    wrap [only]       = ["(" ++ only ++ ")"]
    wrap (top : rest) = ("/" ++ top ++ "\\") : wrapBelow rest

    -- Rows after the first: '|' sides, closed off on the last row.
    wrapBelow :: [String] -> [String]
    wrapBelow []              = []
    wrapBelow [bottom]        = ["\\" ++ bottom ++ "/"]
    wrapBelow (middle : rest) = ("|" ++ middle ++ "|") : wrapBelow rest

--  Does expression a (child) need parentheses when it is an operand of
--  expression b (parent)?  Atoms never do.  Nothing does inside a function
--  application or a fraction, whose parts are already visually delimited.
--  Sums and differences need them except inside another sum or difference.
--  Any other compound expression needs them only inside a power.  Only the
--  parent is consulted, not which side of it the child sits on.
needsParens :: SymbolicExpression -> SymbolicExpression -> Bool
needsParens child parent
  | isAtom child         = False
  | delimitsParts parent = False
  | isAdditive child     = not (isAdditive parent)
  | otherwise            = isPower parent
  where
    isAtom (Variable _) = True
    isAtom (Exact _)    = True
    isAtom (Inexact _)  = True
    isAtom _            = False

    delimitsParts (Apply _ _)         = True
    delimitsParts (Binary Divide _ _) = True
    delimitsParts _                   = False

    isAdditive (Binary Add _ _)      = True
    isAdditive (Binary Subtract _ _) = True
    isAdditive _                     = False

    isPower (Binary Power _ _) = True
    isPower _                  = False

--  Render the first expression, wrapped in parentheses when needsParens says
--  it requires them as part of the second (enclosing) expression.

parenthesizedIfNeeded :: SymbolicExpression -> SymbolicExpression -> TextBlock
parenthesizedIfNeeded child parent
  | needsParens child parent = parenthesize rendered
  | otherwise                = rendered
  where
    rendered = prettyBlock child


--  Pretty printing routines themselves.

--  Render an expression as multi-line text, one newline after every row.

prettyPrint :: SymbolicExpression -> String
prettyPrint = show . prettyBlock

--  Render an inline binary operation "a op b" for the expression full,
--  parenthesizing operands where needed.  needsParens only sees the operand
--  and its parent, not which side the operand is on.  The right operand of a
--  subtraction is therefore checked here: a - (b + c) and a - (b - c) must
--  keep their parentheses, or they would print as a - b + c and a - b - c.

blockBasicOp :: String -> SymbolicExpression
                -> SymbolicExpression -> SymbolicExpression
                -> TextBlock
blockBasicOp op full a b =
  blockConcat (parenthesizedIfNeeded a full)
              (blockConcat (lineBlock $ " " ++ op ++ " ") rightBlock)
  where
    rightBlock :: TextBlock
    rightBlock
      | groupsRight full b = parenthesize (prettyBlock b)
      | otherwise          = parenthesizedIfNeeded b full

    -- Is the right operand a sum/difference under a subtraction?
    groupsRight :: SymbolicExpression -> SymbolicExpression -> Bool
    groupsRight (Binary Subtract _ _) (Binary Add _ _)      = True
    groupsRight (Binary Subtract _ _) (Binary Subtract _ _) = True
    groupsRight _ _                                         = False

--  Render an expression as a text block.  Atoms are single lines.  Sums,
--  differences and products are written inline via blockBasicOp.  Function
--  applications put a parenthesized argument list after the name.  Powers
--  draw the exponent raised above the base, and quotients are stacked
--  fractions.

prettyBlock :: SymbolicExpression -> TextBlock

prettyBlock (Variable var) = lineBlock var
prettyBlock (Exact i) = lineBlock (show i)
prettyBlock (Inexact f) = lineBlock (show f)
prettyBlock f@(Binary Add a b) = blockBasicOp "+" f a b
prettyBlock f@(Binary Subtract a b) = blockBasicOp "-" f a b
prettyBlock f@(Binary Multiply a b) = blockBasicOp "*" f a b

prettyBlock f@(Apply func args) =
  let argBlock = parenthesize (blockArgs args) in
    blockConcat (lineBlock $ func ++ " ") argBlock
  where
    -- Arguments separated by ", "; no arguments render as an empty line.
    blockArgs :: [SymbolicExpression] -> TextBlock
    blockArgs [] = lineBlock ""
    blockArgs [a] = parenthesizedIfNeeded a f
    blockArgs (a:b:c) = blockConcat (parenthesizedIfNeeded a f)
                                    (blockConcat (lineBlock ", ")
                                                 (blockArgs $ b:c))

--  The exponent is shifted up so that its bottom row lands one row above the
--  base's top row.  Shifting only by the base's height assumed the exponent's
--  bottom row was its base row; an exponent such as a fraction has rows below
--  its base row and then overlapped the base.
prettyBlock f@(Binary Power a b) =
  let TextBlock fe te ce rowse = prettyBlock b
      base@(TextBlock fb _ _ _) = parenthesizedIfNeeded a f
      moveup = te - fb + 1
      movedexp = TextBlock (fe - moveup) (te - moveup) ce rowse in
    blockConcat base movedexp

prettyBlock f@(Binary Divide a b) =
  let numer@(TextBlock fn tn cn _) = parenthesizedIfNeeded a f
      denom@(TextBlock _ _ cd _) = parenthesizedIfNeeded b f
      bar = lineBlock (repeatChar '-' (max cn cd + 2))
      TextBlock ff tf cf rowsf = blockVConcat numer (blockVConcat bar denom)
      -- blockVConcat stacks numerator, bar and denominator top to bottom;
      -- shifting up by the numerator's height puts the bar on the base row.
      offset = ff - (tn - fn + 1) in
    TextBlock (ff + offset) (tf + offset) cf rowsf
