module SymbolicDifferentiation where

import Control.Exception (ErrorCall (..), evaluate, try)

import Data.Function.HT (nest, )

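{- Disabled rules: their right-hand sides need a Num constraint on the result
   type (for the literals 0 and 1, and for .*) that their left-hand sides do
   not provide. -- How can such a constraint be stated in a rule?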
     "derive/const"  forall (y ::  (Num a) => a).   derive (const y) = const 0 ;
     "derive/const"  forall y.   derive (const y) = const 0 ;
     "derive/id"                 derive id        = const 1 ;
     "derive/compos" forall f g. derive (f .  g)  = derive f . g .* derive g ;
-}

-- Rewrite rules that replace 'derive' applied to a recognised expression by
-- the derivative of that expression.  GHC applies rewrite rules only when
-- rule rewriting is enabled (for instance with -O); a call to 'derive' that
-- is not rewritten evaluates its 'error' placeholder at run time.
--
-- derive/power: the exponent n is integral while the base may be any Num
-- type, so n is converted with 'fromIntegral' before it is multiplied in.
-- This lets the rule fire when base and exponent types differ (e.g. a
-- Double base with an Integer exponent).  The exponent 0 is handled on its
-- own: the general formula would compute x ^ pred 0 = x ^ (-1), which '^'
-- rejects as a negative exponent, whereas the derivative of x^0 is 0.
{-# RULES
     "derive/plus"   forall f g. derive (f .+ g) = derive f .+ derive g ;
     "derive/minus"  forall f g. derive (f .- g) = derive f .- derive g ;
     "derive/times"  forall f g. derive (f .* g) = derive f .* g .+ f .* derive g ;
     "derive/divide" forall f g. derive (f ./ g) = (derive f .* g .- f .* derive g) ./ ((^ (2 :: Integer)) . g) ;
     "derive/power"  forall n.   derive (flip (^) n) = \x -> if n == 0 then 0 else fromIntegral n * x ^ pred n ;
     "derive/negate" forall f.   derive (negate . f) = negate . derive f ;
     "derive/sin"                derive sin  = cos ;
     "derive/cos"                derive cos  = negate . sin ;
     "derive/exp"                derive exp  = exp ;
     "derive/log"                derive log  = recip ;
     "derive/abs"                derive abs  = signum ;
  #-}

-- | Lift a binary operator pointwise to two functions of the same argument:
-- @fop2 op f g@ is the function that feeds its input to both @f@ and @g@
-- and combines the two results with @op@.
fop2 :: (c -> d -> e) -> (a -> c) -> (a -> d) -> (a -> e)
fop2 combine left right = \arg -> left arg `combine` right arg

-- Pointwise arithmetic on functions, so that @(cos .+ sin) x@ equals
-- @cos x + sin x@.  The fixities match those of '+', '-', '*' and '/'.
infixl 6 .+, .-
infixl 7 .*, ./

(.+), (.-), (.*) :: Num a        => (t -> a) -> (t -> a) -> (t -> a)
(./)             :: Fractional a => (t -> a) -> (t -> a) -> (t -> a)
(.+) = fop2 (+)
(.-) = fop2 (-)
(.*) = fop2 (*)
(./) = fop2 (/)

-- These operators occur on the left-hand sides of the derive/* rewrite
-- rules.  If GHC inlined them into their 'fop2' definitions first, those
-- rules could no longer match, so inlining is switched off.
{-# NOINLINE (.+) #-}
{-# NOINLINE (.-) #-}
{-# NOINLINE (.*) #-}
{-# NOINLINE (./) #-}

-- | Placeholder for symbolic differentiation.  It has no real
-- implementation: uses of it are meant to be replaced at compile time by
-- the derive/* rewrite rules of this module.  If no rule rewrites a call,
-- applying 'derive' raises an error.
derive :: (t -> a) -> (t -> a)
derive = error "Could not derive expression symbolically."
-- Keep 'derive' an opaque call so the rewrite rules can match it; if it
-- were inlined, every use would turn into the 'error' call above before a
-- rule had a chance to fire.
{-# NOINLINE derive #-}

-- | Print the derivative of a few sample expressions at sample points.
--
-- Each value comes from 'derive', which yields a number only when a rewrite
-- rule has replaced it and otherwise raises an 'ErrorCall'.  Every case is
-- forced separately and a failure is reported for that case alone.  One
-- expression without a matching rule therefore does not abort the rest:
-- for example @exp . sin@, whose composition rule is disabled.
test :: IO ()
test = mapM_ report cases
  where
    report :: (String, Double) -> IO ()
    report (msg, val) = do
      -- Force the value before printing, so an error cannot cut an output
      -- line in half and stop the remaining cases.
      result <- try (evaluate val)
      putStrLn $ msg ++ " = " ++ either failure show result

    failure :: ErrorCall -> String
    failure (ErrorCall err) = "failed: " ++ err

    cases :: [(String, Double)]
    cases =
      [ ("log' 2"               , derive log 2)
      , ("abs' pi"              , derive abs pi)
      , ("sin' 0"               , derive sin 0)
      , ("cos' 0.1"             , derive cos 0.1)
      , ("cos'' 0"              , derive (derive cos) 0)
      , ("(cos .+ sin)' (pi/4)" , derive (cos .+ sin) (pi/4))
      , ("exp' 0"               , derive (\x -> exp x) 0)
      , ("(exp . sin)' 0"       , derive (exp . sin) 0)
      , ("(\\x -> x^2 + x)' 0"  , derive (\x -> x^(2::Integer) + x) 0)
      , ("(^3)' 2"              , derive (^(3::Integer)) 2)
      , ("(flip (^) 3)' 2"      , derive (flip (^) (3::Integer)) 2)
      , ("cos''' 0"             , nest 3 derive cos 0)
      ]
