--  HasCas -- A symbolic manipulation package in Haskell.

--  Symbolic differentiation of expressions.  This is based on a set of
--  differentiation rules that can be asked to work on an expression.
--
--  Each rule has the interface:
--  Differentiator -> SymbolicName -> SymbolicExpression
--    -> Maybe SymbolicExpression
--
--  The first argument is there to allow recursive differentiation of
--  subexpressions using the given configuration; the second names the
--  variable to differentiate with respect to.
--  A return of Nothing means the rule did not match, while Just e means that
--  e is the final result and no further search is necessary.

module HasCas.Differentiation (DiffRule(..), FunctionDerivatives(..),
                               Differentiator,
                               createDifferentiator,
                               addDiffRules, addFunctionDerivatives,
                               differentiate,
                               diffLibrary) where

import HasCas.Expression
import HasCas.Evaluation


--  DiffRule and Differentiator.
--  ----------------------------------------------------------------------------


--  A rule to try for differentiating an expression.
--
--  The wrapped function receives the Differentiator in use (so that it can
--  differentiate subexpressions recursively), the name of the variable to
--  differentiate with respect to, and the expression itself.  It returns
--  Nothing if the rule does not apply and Just the finished derivative
--  otherwise.
--
--  Declared as a newtype because there is exactly one constructor with
--  exactly one field; construction and pattern matching look the same to
--  callers.

newtype DiffRule =
  DiffRule (Differentiator -> SymbolicName -> SymbolicExpression
            -> Maybe SymbolicExpression)


--  Define the derivatives of a function.
--
--  The fields are, in order:
--    * the name of the function,
--    * the list of its formal argument names,
--    * the list of partial derivatives, one per formal argument and in the
--      same order; each is an expression written in terms of the formal
--      argument names.
--
--  Example: FunctionDerivatives "%sin" ["x"] [cos (Variable "x")]

data FunctionDerivatives =
  FunctionDerivatives SymbolicName [SymbolicName] [SymbolicExpression]


--  A differentiation configuration: the rules to try (in list order, the
--  first rule that yields a result wins) together with the known function
--  derivatives (the first entry for a given function name is the one used).

data Differentiator = Differentiator [DiffRule] [FunctionDerivatives]


--  Build a Differentiator from the given rules; no function derivatives are
--  registered yet.

createDifferentiator :: [DiffRule] -> Differentiator
createDifferentiator = flip Differentiator []

--  Add further rules to a Differentiator.  The new rules are put in front
--  of the existing ones; the function derivatives are left untouched.

addDiffRules :: Differentiator -> [DiffRule] -> Differentiator
addDiffRules differ extra =
  case differ of
    Differentiator existing derivs -> Differentiator (extra ++ existing) derivs

--  Register further function derivatives with a Differentiator.  The new
--  entries are put in front of the existing ones; the rules are left
--  untouched.

addFunctionDerivatives :: Differentiator -> [FunctionDerivatives]
                          -> Differentiator
addFunctionDerivatives (Differentiator rs known) extra =
  Differentiator rs combined
  where
    combined = extra ++ known


--  Differentiate an expression with respect to the named variable.
--
--  The rules of the Differentiator are tried in list order and the result of
--  the first one that matches is returned.  If no rule matches, an error
--  naming the offending expression is raised.

differentiate :: Differentiator -> SymbolicName -> SymbolicExpression
                 -> SymbolicExpression
differentiate differ@(Differentiator ruleList _) name target =
  foldr attempt noMatch ruleList
  where
    --  Use this rule's answer if it has one, else fall through to the
    --  remaining rules.
    attempt (DiffRule rule) fallback =
      maybe fallback id (rule differ name target)
    noMatch = error $ "No diff-rule for expression:\n" ++ (show target)


--  Some differentiation rules themselves.
--  ----------------------------------------------------------------------------


--  Basic differentiation rules.

atomRule, sumRule, productRule, quotientRule, powerRule ::
  Differentiator -> SymbolicName -> SymbolicExpression
  -> Maybe SymbolicExpression

--  Atoms: the variable we differentiate by has derivative 1, every other
--  variable and every Exact or Inexact literal has derivative 0.  Anything
--  else is not an atom and is left to the other rules.
--
--  The cases are matched in the function head with a plain wildcard as the
--  fall-through, so that no Prelude name (otherwise, exp) gets shadowed.
atomRule _ dvar (Variable xvar)
  | dvar == xvar = Just (Exact 1)
  | otherwise = Just (Exact 0)
atomRule _ _ (Exact _) = Just (Exact 0)
atomRule _ _ (Inexact _) = Just (Exact 0)
atomRule _ _ _ = Nothing

--  Sums and differences are differentiated term by term; the operator is
--  kept as it is.  Anything else does not match.
sumRule d dvar expr =
  case expr of
    Binary Add a b -> Just $ termwise Add a b
    Binary Subtract a b -> Just $ termwise Subtract a b
    _ -> Nothing
  where
    --  Rebuild the binary node from the derivatives of both operands.
    termwise op l r = Binary op (differentiate d dvar l)
                                (differentiate d dvar r)

--  Product rule: (f * g)' = f * g' + f' * g.
productRule d dvar expr =
  case expr of
    Binary Multiply f g ->
      let f' = differentiate d dvar f
          g' = differentiate d dvar g
      in Just (f * g' + f' * g)
    _ -> Nothing

--  Quotient rule: (n / m)' = (m * n' - m' * n) / m^2.
quotientRule d dvar expr =
  case expr of
    Binary Divide num den ->
      let num' = differentiate d dvar num
          den' = differentiate d dvar den
          top = den * num' - den' * num
          bottom = den ** 2
      in Just (top / bottom)
    _ -> Nothing

--  For powers, things get a little complicated.
--
--  If the exponent is a literal constant (Exact or Inexact), the ordinary
--  power rule is used:
--    (a^n)' = n * a^(n-1) * a'
--  This keeps the derivative of something like x^2 meaningful for zero and
--  negative x, where a rewrite in terms of the logarithm would not be.
--
--  In every other case (for things like x^x, where both base and exponent
--  may depend on x) we rewrite the power as exp (ln a * b) and differentiate
--  that instead.
--
--  Note: To make sure we don't enter infinite recursion, we explicitly
--  rewrite the exponential as "Apply %exp ..." instead of exp.
powerRule d dvar (Binary Power a b)
  | isLiteral b =
      --  Built from the constructors directly so that the shape of the
      --  result does not depend on how the numeric instances are defined.
      Just $ Binary Multiply
               (Binary Multiply b (Binary Power a (Binary Subtract b (Exact 1))))
               (differentiate d dvar a)
  | otherwise =
      Just $ differentiate d dvar (Apply "%exp" [log a * b])
  where
    isLiteral (Exact _) = True
    isLiteral (Inexact _) = True
    isLiteral _ = False
powerRule _ _ _ = Nothing


--  Differentiate function applications via function derivative objects and
--  the chain rule.
--
--  For an application f(a1, ..., an) the result is the sum over i of
--    (i-th registered partial derivative of f, with the formal arguments
--     replaced by the actual ones) * (derivative of ai).
--
--  Returns Nothing for anything that is not a function application.
--  Raises an error if no derivatives are registered for the function, or if
--  the registered entry does not have exactly one partial derivative per
--  formal argument (zipWith would otherwise silently drop terms).  A mismatch
--  between the number of actual and formal arguments is reported as an error
--  once the argument bindings are needed.

functionRule :: Differentiator -> SymbolicName -> SymbolicExpression
                -> Maybe SymbolicExpression

functionRule d@(Differentiator _ fd) dvar (Apply func args) =
  Just $ sum $ zipWith (*) partDerivs innerDerivs
  where
    FunctionDerivatives _ formal derivs = findFunctionDeriv fd func
    argBinding = createBindings $ bindArguments formal args []
    --  Partial derivatives with the formal arguments substituted.
    partDerivs = map (substVars argBinding) derivs
    --  Derivatives of the actual arguments (the "inner" derivatives).
    innerDerivs = map (differentiate d dvar) args

    --  Look up the first entry registered under the given name and check
    --  that it is well-formed.
    findFunctionDeriv :: [FunctionDerivatives] -> SymbolicName
                         -> FunctionDerivatives
    findFunctionDeriv lst name =
      case dropWhile (\(FunctionDerivatives f _ _) -> f /= name) lst of
        [] -> error $ "No function derivation rule for " ++ name
        (entry@(FunctionDerivatives _ fs ds) : _)
          | length fs == length ds -> entry
          | otherwise ->
              error $ "Number of partial derivatives does not match the "
                      ++ "formal arguments of function " ++ name

    --  Pair formal with actual arguments, accumulating onto the given list
    --  (so the bindings end up in reverse argument order).
    bindArguments :: [SymbolicName] -> [SymbolicExpression] -> [VariableBinding]
                     -> [VariableBinding]
    bindArguments [] [] result = result
    bindArguments [] _ _ =
      error $ "Too many actual arguments for function " ++ func
    bindArguments _ [] _ =
      error $ "Too few actual arguments for function " ++ func
    bindArguments (f1:t1) (f2:t2) result =
      bindArguments t1 t2 ((VariableBinding f1 f2) : result)

functionRule _ _ _ = Nothing


--  Derivatives of standard functions, each written in terms of the single
--  formal argument "x".

expDeriv, logDeriv, sinDeriv, cosDeriv, tanDeriv :: FunctionDerivatives

--  (exp x)' = exp x
expDeriv = FunctionDerivatives "%exp" ["x"] [exp x]
  where x = Variable "x"

--  (log x)' = 1 / x
logDeriv = FunctionDerivatives "%log" ["x"] [recip x]
  where x = Variable "x"

--  (sin x)' = cos x
sinDeriv = FunctionDerivatives "%sin" ["x"] [cos x]
  where x = Variable "x"

--  (cos x)' = - sin x
cosDeriv = FunctionDerivatives "%cos" ["x"] [negate (sin x)]
  where x = Variable "x"

--  (tan x)' = 1 / (cos x)^2
tanDeriv = FunctionDerivatives "%tan" ["x"] [recip (cos x ** 2)]
  where x = Variable "x"


--  The standard Differentiator: the atom, sum, product, quotient, power and
--  function rules (tried in that order) together with the derivatives of
--  exp, log, sin, cos and tan.

diffLibrary :: Differentiator
diffLibrary = addFunctionDerivatives base standardDerivs
  where
    base = createDifferentiator $ map DiffRule
             [atomRule, sumRule, productRule, quotientRule, powerRule,
              functionRule]
    standardDerivs = [expDeriv, logDeriv, sinDeriv, cosDeriv, tanDeriv]
