{- |
  Parametrisation of orthogonal filter banks.
-}
module DiscreteWavelet.Lattice where

import qualified Signal        as Sig
import qualified ShiftedSignal as SSig
import Useful (sliceHoriz, toMaybe)
import Data.Maybe (catMaybes)

{- |
  A lattice filter bank: the list of rotation parameters @(rc,rs)@,
  one pair per lattice stage, together with an amplification factor
  that 'analyse' and 'synthesise' apply to their input channels.
-}
type T a = ([(a,a)],a)

{- |
  Peel lattice stages off a pair of coefficient lists
  (even and odd part of a filter), one stage per step.

  In every step the leading coefficients of both lists serve as the
  rotation parameters @(rc,rs)@.
  After rotating both lists by them,
  the first coefficient of the second list is dropped,
  and the last coefficient of the first list is reported
  as the residue of this step.
  Each result element is @((rc,rs), residue)@.

  Both operands must have the same length;
  'error' is raised when exactly one of them is empty.
-}
decomposeWithResidue :: Num a => [a] -> [a] -> [((a,a),a)]
decomposeWithResidue evens odds =
   case (evens, odds) of
      ([], []) -> []
      ([], _)  -> error "first filter too short"
      (_, [])  -> error "second filter too short"
      (c:_, s:_) ->
         let params = (c, s)
             (rotEvens, rotOdds) = rotate params (evens, odds)
             -- lazy pattern: the leading odd coefficient is discarded
             (_:restOdds) = rotOdds
         in  (params, last rotEvens) :
                decomposeWithResidue (init rotEvens) restOdds

{- |
  Decompose a filter into lattice coefficients.

  The filter is split into two phases with @sliceHoriz 2@
  and decomposed by 'decomposeWithResidue'.
  The rotation parameters of all stages form the lattice.
  The residue of the last stage becomes the amplification factor;
  the residues of the earlier stages are discarded.
-}
decomposePlain :: Num a => [a] -> T a
decomposePlain h =
   (rotations, last residues)
   where
      [he, ho] = sliceHoriz 2 h
      (rotations, residues) = unzip (decomposeWithResidue he ho)


{- |
  The analysis transform with respect to the lattice coefficients.

  The input is split into two phases with @SSig.sliceHoriz 2@,
  both are amplified by the amplification factor,
  and the lattice stages are applied in list order
  as @ssLatticeGate 1 (-rc,rs)@.
  Every stage translates the second channel by one,
  so the second output band is translated back by @length lat@.
-}
analyse :: Fractional a => T a -> SSig.T a -> (SSig.T a, SSig.T a)
analyse (lat, amp) x =
   (yl, SSig.translate (negate (length lat)) yh)
   where
      [xe, xo] = SSig.sliceHoriz 2 x
      initial = (SSig.amplify amp xe, SSig.amplify amp xo)
      stage chans (rc, rs) = ssLatticeGate 1 (-rc, rs) chans
      (yl, yh) = foldl stage initial lat

{- |
  The synthesis transform with respect to the lattice coefficients.

  Both bands are amplified by the amplification factor.
  The second band is first translated by @length lat - 1@.
  Then the lattice stages are applied in reverse list order
  as @ssLatticeGate (-1)@.
  Finally the two phases are upsampled and superposed.
  This interleaving is a bit inefficient because a lot of the summands are zero.
-}
synthesise :: Fractional a => T a -> (SSig.T a, SSig.T a) -> SSig.T a
synthesise (lat, amp) (yl,yh) =
   SSig.superpose (SSig.upSample 2 xe) oddPhase
   where
      lowIn  = SSig.amplify amp yl
      highIn = SSig.amplify amp (SSig.translate (length lat - 1) yh)
      (xe, xo) = foldr (ssLatticeGate (-1)) (lowIn, highIn) lat
      -- compensate the additional shift of the last application of ssLatticeGate
      oddPhase = SSig.translate (-1) (SSig.upSample 2 xo)

{- |
  Check whether the polyphase matrix induced by a lattice filter bank
  has the special structure,
  that is, whether it contains the half filters
  both in normal and in adjoint (reversed) form.

  Coefficient pairs @(0,0)@ are removed first,
  because 'ssLatticeGate' divides by @rc*rc+rs*rs@.

  The 'Eq' constraint is needed for the comparisons below;
  'Fractional' does not imply 'Eq' in current GHC versions.
-}
propAnalyseOrth :: (Eq a, Fractional a) => [(a,a)] -> Bool
propAnalyseOrth lat =
   let latC = (filter ((0,0) /=) lat, 1)
       (he,hoAdj) = analyse latC (SSig.Cons 0 [1])
       (ho,heAdj) = analyse latC (SSig.Cons 1 [1])
   in  he == SSig.reverse heAdj  &&
       ho == SSig.neg (SSig.reverse hoAdj)

{- |
  'propAnalyseOrth' on integer coefficient pairs,
  converted to exact 'Rational' arithmetic.
-}
propAnalyseOrthInt :: [(Integer,Integer)] -> Bool
propAnalyseOrthInt =
   propAnalyseOrth . map ratFromIntPair

{- |
  Construct an orthogonal filter pair by the lattice analysis transform
  and check whether it decomposes to the same lattice coefficients.

  Every lattice stage rescales the signals,
  so the decomposed pairs are only required to be proportional
  to the original pairs (@a0*b1 == a1*b0@).
  Additionally, the product of all proportionality factors
  must equal the residue of the last decomposition stage.

  The 'Eq' constraint is needed for the comparisons below;
  'Fractional' does not imply 'Eq' in current GHC versions.
-}
propAnalyseDecompose :: (Eq a, Fractional a) => [(a,a)] -> Bool
propAnalyseDecompose lat =
   let -- prepend a (1,0) in order to avoid empty lists;
       -- (0,0) pairs are dropped because 'ssLatticeGate' divides by rc*rc+rs*rs
       latNoZero = (1,0) : filter ((0,0) /=) lat
       latC = (latNoZero, 1)
       he = fst (analyse latC (SSig.Cons 0 [1]))
       ho = fst (analyse latC (SSig.Cons 1 [1]))
       latNew = decomposeWithResidue (SSig.signal he) (SSig.signal ho)
       -- ratio of two proportional pairs;
       -- b0 and b1 are never both zero since (0,0) pairs were removed
       sim (a0,a1) (b0,b1) =
          toMaybe (a0*b1==a1*b0) (if b0/=0 then a0/b0 else a1/b1)
       simFacs = zipWith sim (map fst latNew) latNoZero
   in  not (any (Nothing ==) simFacs) &&
       product (catMaybes simFacs) == snd (last latNew)
       -- (map fst latNew, snd (last latNew)) == latC

{- |
  'propAnalyseDecompose' on integer coefficient pairs,
  converted to exact 'Rational' arithmetic.
-}
propAnalyseDecomposeInt :: [(Integer,Integer)] -> Bool
propAnalyseDecomposeInt =
   propAnalyseDecompose . map ratFromIntPair


{- |
  Check whether the synthesis transform inverts the analysis transform:
  analysing @SSig.fromScalar 1@ and synthesising the result
  must yield a signal accepted by 'SSig.isImpulse'.

  Coefficient pairs @(0,0)@ are removed first, as in the other properties,
  because 'ssLatticeGate' computes @recip (rc*rc+rs*rs)@,
  which fails for such a pair with exact 'Rational' arithmetic.
-}
propAnalyseSynthesise :: (Eq a, Fractional a) => [(a,a)] -> Bool
propAnalyseSynthesise lat =
   let lat1 = (filter ((0,0) /=) lat, 1)
   in  SSig.isImpulse (synthesise lat1
          (analyse lat1 (SSig.fromScalar 1)))

{- |
  'propAnalyseSynthesise' on integer coefficient pairs,
  converted to exact 'Rational' arithmetic.
-}
propAnalyseSynthesiseInt :: [(Integer,Integer)] -> Bool
propAnalyseSynthesiseInt =
   propAnalyseSynthesise . map ratFromIntPair


-- | Convert both components of an integer pair to exact rationals.
ratFromIntPair :: (Integer,Integer) -> (Rational,Rational)
ratFromIntPair (c, s) = (toRational c, toRational s)

{- |
  Apply the matrix @[[rc, rs], [-rs, rc]]@ to a pair of coefficient lists:
  the result is @(rc*xe + rs*xo, -rs*xe + rc*xo)@.
-}
rotate :: Num a => (a, a) -> ([a], [a]) -> ([a], [a])
rotate (rc,rs) (xe,xo) = (upper, lower)
   where
      upper = linearComb2 rc xe rs xo
      lower = linearComb2 (negate rs) xe rc xo

{- |
  Linear combination @a*av + b*bv@ of two coefficient lists.
  How lists of different length are combined is determined by 'Sig.superpose'.
-}
linearComb2 :: Num a => a -> [a] -> a -> [a] -> [a]
linearComb2 a av b bv = Sig.superpose scaledA scaledB
   where
      scaledA = Sig.amplify a av
      scaledB = Sig.amplify b bv

{- |
  One lattice stage on a pair of shifted signals.
  The pair is rotated by @ssRotate (-rc,rs)@,
  both results are divided by @rc*rc+rs*rs@,
  and the second result is translated by @shift@.
  The pair @(0,0)@ makes the division singular (@recip 0@).
-}
ssLatticeGate :: Fractional a =>
   Int -> (a,a) -> (SSig.T a, SSig.T a) -> (SSig.T a, SSig.T a)
ssLatticeGate shift (rc,rs) x =
   (normalise rotE, SSig.translate shift (normalise rotO))
   where
      (rotE, rotO) = ssRotate (-rc,rs) x
      normalise = SSig.amplify (recip (rc*rc + rs*rs))

{- |
  Apply the matrix @[[rc, rs], [-rs, rc]]@ to a pair of shifted signals:
  the result is @(rc*xe + rs*xo, -rs*xe + rc*xo)@.
-}
ssRotate :: Num a => (a, a) -> (SSig.T a, SSig.T a) -> (SSig.T a, SSig.T a)
ssRotate (rc,rs) (xe,xo) = (row1, row2)
   where
      row1 = ssLinearComb2 rc xe rs xo
      row2 = ssLinearComb2 (negate rs) xe rc xo

-- | Linear combination @a*av + b*bv@ of two shifted signals, via 'SSig.superpose'.
ssLinearComb2 :: Num a => a -> SSig.T a -> a -> SSig.T a -> SSig.T a
ssLinearComb2 a av b bv =
   SSig.amplify a av `SSig.superpose` SSig.amplify b bv
