{-# OPTIONS -fglasgow-exts #-}
{-# LANGUAGE UndecidableInstances, ScopedTypeVariables #-}

-- Terminology:

-- The terms "reify" and "reflect" may seem to be used in the reverse of the
-- sense one might intuit. Here, the representation of a number in
-- type-level binary is considered "real", so turning a value into a
-- type-level term is "reifying" it, and going the other way (recovering the
-- value from its type-level representation) is "reflecting".

module Fu.Prepose where

import System.IO.Unsafe       (unsafePerformIO)
import Control.OldException      (handle, handleJust, errorCalls)
import Foreign.Marshal.Utils  (with, new)
import Foreign.Marshal.Array  (peekArray, pokeArray)
import Foreign.Marshal.Alloc  (alloca)
import Foreign.Ptr            (Ptr, castPtr)
import Foreign.Storable       (Storable(sizeOf, peek))
import Foreign.C.Types        (CChar)
import Foreign.StablePtr      (StablePtr, newStablePtr,
                               deRefStablePtr, freeStablePtr)
import Data.Bits              (Bits(..))
import Prelude hiding         (getLine)

import Debug.Trace

-- | A value of type @a@ tagged with a phantom type @s@.
--   NOTE(review): not referenced elsewhere in the visible code.
newtype Modulus s a  =  Modulus a  deriving (Eq, Show)
-- | A residue whose representative has type @a@; the modulus is carried by
--   the phantom type @s@ (the 'Num' instance reduces via 'normalize').
newtype M s a        =  M a        deriving (Eq, Show)

-- | Drop the modulus tag and return the underlying representative.
unM :: M s a -> a
unM m = case m of
  M x -> x

-- | A bottom value usable at any type. It serves only as a type proxy:
--   code writes @(__ :: T)@ to select an instance by type. The value is
--   never meant to be evaluated; evaluating it would not terminate.
__ = __

-- | @Modular s a@: the type @s@ determines a modulus of type @a@
--   (functional dependency @s -> a@). The argument of 'modulus' is only a
--   type proxy; the instances in this file ignore it.
class Modular s a | s -> a where modulus :: s -> a

-- | Reduce a raw value modulo the modulus that @s@ determines, and tag the
--   result with @s@.
normalize :: forall a s. (Modular s a, Integral a) => a -> M s a
normalize x = M (x `mod` m)
  where m = modulus (__ :: s)

-- | Arithmetic on residues: each operation works on the representatives and
--   re-reduces the result with 'normalize'. 'abs' and 'signum' are
--   deliberately left as errors.
instance (Modular s a, Integral a) => Num (M s a) where
  (+) (M x) (M y)  =  normalize (x + y)
  (-) (M x) (M y)  =  normalize (x - y)
  (*) (M x) (M y)  =  normalize (x * y)
  negate (M x)     =  normalize (negate x)
  fromInteger n    =  normalize (fromInteger n)
  abs              =  error "Modular numbers are not signed"
  signum           =  error "Modular numbers are not signed"

-- | Add two residues, forcing both arguments and the result to share one
--   modulus type.
aaa p@(M m) q@(M n) = asTypeOf (normalize (m + n)) p
    where _ = [p, q]  -- putting both in one list unifies their types

-- | Run a computation under an arbitrary (not necessarily integral) modulus
--   value. The continuation receives a proxy type @s@ whose 'Modular'
--   instance yields that value. The definition appears further down, after
--   'reify', which it uses.
withModulus :: a -> (forall s. Modular s a => s -> w) -> w

-- Type-level binary numerals, as produced by 'reifyIntegral' and read back by
-- 'ReflectNum': Zero is 0, Twice n is 2n, Succ n is n+1, Pred n is n-1.
data Zero; data Twice s; data Succ s; data Pred s

-- An alternative signed encoding: Positive/Negative wrap an unsigned numeral
-- built from Zero, Twice (2n) and TwiceSucc (2n+1); see 'ReflectUnsigned'.
-- NOTE(review): nothing in the visible code reifies values into this encoding.
data Positive s; data Negative s; data TwiceSucc s

-- | Unsigned type-level numerals built from 'Zero', 'Twice' (2n) and
--   'TwiceSucc' (2n+1), reflected into any 'Num' type.
class ReflectUnsigned s where
  reflectUnsigned :: Num a => s -> a

instance ReflectUnsigned Zero where
  reflectUnsigned = const 0
instance ReflectUnsigned s => ReflectUnsigned (Twice s) where
  reflectUnsigned _ = (* 2) (reflectUnsigned (__ :: s))
instance ReflectUnsigned s => ReflectUnsigned (TwiceSucc s) where
  reflectUnsigned _ = (+ 1) . (* 2) $ reflectUnsigned (__ :: s)

-- | Signed numerals over the unsigned encoding: @Positive n@ denotes n and
--   @Negative n@ denotes -1 - n.
instance ReflectUnsigned s  =>  ReflectNum (Positive s) where
  reflectNum = const (reflectUnsigned (__ :: s))
instance ReflectUnsigned s  =>  ReflectNum (Negative s) where
  reflectNum = const (subtract (reflectUnsigned (__ :: s)) (-1))


-- | Type-level numerals (built by 'reifyIntegral') that can be reflected back
--   into any 'Num' type.
class ReflectNum s where
  reflectNum :: Num a => s -> a

instance ReflectNum Zero where
  reflectNum = const 0
instance ReflectNum s => ReflectNum (Twice s) where
  reflectNum _ = (* 2) (reflectNum (__ :: s))
instance ReflectNum s => ReflectNum (Succ s) where
  reflectNum _ = (+ 1) (reflectNum (__ :: s))
instance ReflectNum s => ReflectNum (Pred s) where
  reflectNum _ = subtract 1 (reflectNum (__ :: s))


-- | Reify an integer as a type-level binary numeral and pass a proxy for it
--   to the continuation.
--
--   Each step takes 'quotRem' by 2 and wraps the numeral for the quotient:
--   remainder 0 adds 'Twice', 1 adds 'Succ' . 'Twice', -1 adds
--   'Pred' . 'Twice'. Recursion stops at 0 ('Zero'). Because 'quotRem'
--   truncates toward zero, negative inputs are handled too.
--
--   (No debug tracing here: 'show' on the argument would need a 'Show'
--   constraint that 'Integral' does not provide.)
reifyIntegral  ::  Integral a =>
                     a -> (forall s. ReflectNum s => s -> w) -> w
reifyIntegral i k = case quotRem i 2 of
  (0,  0) -> k (__ :: Zero)
  (j,  0) -> reifyIntegral j (\(_ :: s) -> k (__ :: Twice s))
  (j,  1) -> reifyIntegral j (\(_ :: s) -> k (__ :: Succ (Twice s)))
  (j, -1) -> reifyIntegral j (\(_ :: s) -> k (__ :: Pred (Twice s)))
  _       -> error "reifyIntegral: quotRem by 2 gave a remainder outside {-1, 0, 1}"


-- | Proxy type whose 'Modular' instance reflects the modulus from the
--   type-level numeral @s@; @a@ fixes the type of the modulus value.
data ModulusNum s a

-- | The modulus of @ModulusNum s a@ is the number encoded by the numeral @s@.
instance  (ReflectNum s, Num a) =>
            Modular (ModulusNum s a) a where
  modulus = const (reflectNum (__ :: s))


-- | Run a computation under the integral modulus @n@. The modulus is reified
--   as a type-level numeral @t@, and the continuation receives a proxy of
--   type @ModulusNum t a@ whose 'modulus' reflects back to @n@.
withIntegralModulus  ::  forall a w. Integral a =>
                         a -> (forall s. Modular s a => s -> w) -> w
withIntegralModulus n body = reifyIntegral n withNumeral
  where
    withNumeral :: forall t. ReflectNum t => t -> w
    withNumeral _ = body (__ :: ModulusNum t a)


-- Type-level lists of numerals: Nil, and Cons s ss (head s, tail ss).
data Nil; data Cons s ss

-- | Type-level lists of numerals that reflect to a list of values.
class ReflectNums ss where
  reflectNums :: Num a => ss -> [a]

instance ReflectNums Nil where
  reflectNums = const []
instance (ReflectNum s, ReflectNums ss) => ReflectNums (Cons s ss) where
  reflectNums _ = headVal : tailVals
    where
      headVal  = reflectNum (__ :: s)
      tailVals = reflectNums (__ :: ss)

-- | Reify a list of integers as a type-level list of numerals
--   (@Cons n1 (Cons n2 ... Nil)@) and pass a proxy for it to the
--   continuation.
reifyIntegrals  ::  Integral a =>
                      [a] -> (forall ss. ReflectNums ss => ss -> w) -> w
reifyIntegrals xs k = case xs of
  []         -> k (__ :: Nil)
  (x : rest) ->
    reifyIntegral x (\(_ :: h) ->
      reifyIntegrals rest (\(_ :: t) ->
        k (__ :: Cons h t)))


-- | Element type of the byte lists used to encode 'Storable' values as types.
type Byte   = CChar

-- | @Store s a@: proxy for a value of type @a@ whose bytes are encoded by the
--   type-level list @s@ (see 'reifyStorable' and 'reflectStorable').
data Store s a

-- | Proxies from which a 'Storable' value can be reconstructed.
class ReflectStorable s where
  reflectStorable :: Storable a => s a -> a
-- | Rebuild the value: reflect the byte list from @s@, poke it into a
--   temporary buffer obtained from 'alloca', and peek it back as an @a@.
--   NOTE(review): relies on 'unsafePerformIO', and assumes the byte list came
--   from 'reifyStorable' on a value of the same type (a longer list would
--   overrun the buffer) — confirm this holds for all callers.
instance ReflectNums s => ReflectStorable (Store s) where
  reflectStorable _  =  unsafePerformIO
                     $  alloca
                     $  \p -> do  pokeArray (castPtr p) bytes
                                  peek p
    where bytes  =  reflectNums (__ :: s) :: [Byte]

-- | Reify a 'Storable' value as a type. The value's @sizeOf a@ bytes are read
--   (via 'with' and 'peekArray') and reified as a type-level list @s@. The
--   continuation then receives a proxy of type @Store s a@.
reifyStorable  ::  Storable a =>
                      a -> (forall s. ReflectStorable s => s a -> w) -> w
reifyStorable a k =
  reifyIntegrals (bytes :: [Byte]) (\(_ :: s) -> k (__ :: Store s a))
    where bytes  =  unsafePerformIO
                 $  with a (peekArray (sizeOf a) . castPtr)


-- | @Reflect s a@: recover a value of type @a@ from the type @s@ (functional
--   dependency @s -> a@).
class Reflect s a | s -> a where reflect :: s -> a

-- | Proxy for a value held behind a 'StablePtr', where the pointer's bytes
--   are encoded by the type constructor @s@ (built by 'reify').
data Stable (s :: * -> *) a

-- | Reflect a 'Stable' proxy: rebuild the 'StablePtr' from its byte encoding,
--   then dereference it to get the stored value.
--   The type variable named @p@ in the 'where' clause is unrelated to the
--   value @p@; its type is fixed to a 'StablePtr' by the 'deRefStablePtr' call.
--   NOTE(review): the stable pointer is not freed here.
instance ReflectStorable s => Reflect (Stable s a) a where
  reflect  =   unsafePerformIO
           $   do  a <- deRefStablePtr p
                   return (const a)
    where p = reflectStorable (__ :: s p)

-- | Reify an arbitrary value as a type. The value is pinned with a 'StablePtr',
--   the pointer's bytes are reified with 'reifyStorable', and the continuation
--   receives a @Stable s a@ proxy whose 'Reflect' instance recovers the value.
--   NOTE(review): 'freeStablePtr' is imported but never called, so each call
--   appears to leak one stable pointer. Freeing it safely would require that
--   no lazily evaluated part of the result still needs to reflect the value;
--   confirm before changing.
reify :: a -> (forall s. Reflect s a => s -> w) -> w
reify (a :: a) k  =  unsafePerformIO
                  $  do  p <- newStablePtr a
                         reifyStorable p   (\(_ :: s (StablePtr a)) ->
                                              k' (__ :: Stable s a))
    where k' s = return (k s)


-- | Proxy type whose modulus is the value reflected from @s@ via 'Reflect'.
data ModulusAny s

-- | The modulus of @ModulusAny s@ is whatever value @s@ reflects to.
instance Reflect s a => Modular (ModulusAny s) a where
  modulus _ = reflect (__ :: s)

-- Reify the modulus value and hand the continuation a ModulusAny proxy over
-- the resulting type.
withModulus x body = reify x (\(_ :: t) -> body (__ :: ModulusAny t))


-- | Like 'withIntegralModulus', except that the computation is a residue
--   @M s w@ rather than a function of a proxy. The result is unwrapped
--   with 'unM'.
withIntegralModulus'  :: forall w a.  Integral a =>
                          a -> (forall s. Modular s a => M s w) -> w
withIntegralModulus' n computation = reifyIntegral n runAt
  where
    runAt :: forall t. ReflectNum t => t -> w
    runAt _ = unM (computation :: M (ModulusNum t a) w)

-- | Sample computation 3*3 + 5*5 under whatever modulus @s@ supplies.
test4'  ::  (Modular s a, Integral a) => M s a
test4'  =   square 3 + square 5
  where square x = x * x

-- | 'test4'' evaluated modulo 4.
test4   =   withIntegralModulus' 4 test4'

-- | Residue modulo an even number 2^p * q, stored as two components. The
--   first is kept modulo 2^p (by bit masking) and the second modulo q (see
--   'normalizeEven'). @u@ and @v@ are type-level coefficients that 'unEven'
--   uses to recombine the two components into one value.
data Even p q u v a = E a a deriving (Eq, Show)

-- | Build an 'Even' residue. The first component is reduced modulo 2^p by
--   masking with 2^p - 1; the second is reduced modulo q.
normalizeEven :: forall a p q u v. (  ReflectNum p, ReflectNum q, Integral a,
                      Bits a) => a -> a -> Even p q u v a
normalizeEven a b = E (a .&. lowBits) (b `mod` oddModulus)
  where
    lowBits    = shiftL 1 (reflectNum (__ :: p)) - 1   -- 2^p - 1
    oddModulus = reflectNum (__ :: q)

-- | Componentwise arithmetic on 'Even' residues, re-normalised after every
--   operation. 'abs' and 'signum' are deliberately left as errors.
instance (  ReflectNum p, ReflectNum q,
            ReflectNum u, ReflectNum v,
            Integral a, Bits a) => Num (Even p q u v a) where
    (+) (E x1 y1) (E x2 y2)  =  normalizeEven (x1 + x2) (y1 + y2)
    (-) (E x1 y1) (E x2 y2)  =  normalizeEven (x1 - x2) (y1 - y2)
    (*) (E x1 y1) (E x2 y2)  =  normalizeEven (x1 * x2) (y1 * y2)
    negate (E x y)           =  normalizeEven (negate x) (negate y)
    fromInteger n            =  normalizeEven m m
      where m = fromInteger n
    abs                      =  error "Modular numbers are not signed"
    signum                   =  error "Modular numbers are not signed"

-- | Successive quotients of the Euclidean algorithm on @x@ and @y@. For
--   positive inputs, this is the continued-fraction expansion of @x / y@.
gcd' :: Integral a => a -> a -> [a]
gcd' _ 0 = []
gcd' x y = q : gcd' y r
  where (q, r) = x `divMod` y

-- chinese remainder theorem
-- | Chinese-remainder coefficients for coprime positive @a@ and @b@.
--
--   Returns @(u, v)@ with @u + v == 1@, where
--   u = 1 (mod a), u = 0 (mod b), v = 0 (mod a), v = 1 (mod b).
--   So @(x*u + y*v) `mod` (a*b)@ is the number that is x modulo a and
--   y modulo b.
--
--   Computed from the extended Euclidean algorithm: with s*a + t*b = 1,
--   u = t*b and v = s*a = 1 - u. (The previous sign-alternating fold got the
--   sign of v wrong when the continued fraction of a/b had an odd number of
--   terms; e.g. it gave v = 2 for a = 2, b = 3.)
crt :: Integral a => a -> a -> (a, a)
crt a b = (u, 1 - u)
  where
    u = t * b
    (_, t) = egcd a b
    -- egcd x y = (s, t) such that s*x + t*y == gcd x y
    egcd x 0 = (1, 0)
    egcd x y = let (q, r)   = x `divMod` y
                   (s', t') = egcd y r
               in  (t', s' - q * t')

-- | Run a modular computation under the modulus @i@, supporting even moduli.
--
--   'factor' splits @i@ as @2^p * q@ with @q@ odd.
--
--   * When @p@ is 0, the computation runs in 'M' via
--     @withIntegralModulus'@.
--   * Otherwise it runs in 'Even', which keeps one component modulo @2^p@
--     (by bit masking) and one modulo @q@. The coefficients @(u, v)@ from
--     'crt' are reified together with @p@ and @q@, and 'unEven' uses them
--     to recombine the components.
--
--   The debug 'trace' that used to sit here was removed: @show (u, v)@
--   needed a 'Show' constraint the signature does not provide.
withIntegralModulus'' ::  (Integral a, Bits a) =>
                            a -> (forall s. Num (s a) => s a) -> a
withIntegralModulus'' (i :: a) k = case factor 0 i of
    (0, odd') ->                                  -- odd modulus
      withIntegralModulus' odd' k
    (p, q)    ->                                  -- even modulus: i = 2^p * q
      let (u, v) = crt (2^p) q
      in  reifyIntegral p (\(_ :: p) ->
          reifyIntegral q (\(_ :: q) ->
          reifyIntegral u (\(_ :: u) ->
          reifyIntegral v (\(_ :: v) ->
          unEven (k :: Even p q u v a)))))

-- | Split off powers of two. @factor p i@ adds the number of factors of 2 in
--   @i@ to the counter @p@ and returns it with the remaining odd part. An
--   input of 0 yields @(0, 0)@.
factor :: (Num p, Integral q) => p -> q -> (p, q)
factor p i
    | i == 0    = (0, 0)                         -- just zero
    | even i    = factor (p + 1) (i `quot` 2)    -- peel off one factor of 2
    | otherwise = (p, i)                         -- odd part reached

-- | Recombine an 'Even' residue into one value modulo @q * 2^p@, using the
--   reflected coefficients @u@ and @v@: @(a*u + b*v) `mod` (q `shiftL` p)@.
--   (A debug 'trace' that printed on every call was removed.)
unEven ::(  ReflectNum p, ReflectNum q, ReflectNum u,
  ReflectNum v, Integral a, Bits a) => Even p q u v a -> a
unEven (E a b :: Even p q u v a) =
  mod  (a * reflectNum (__ :: u) + b * reflectNum (__ :: v))
       (shiftL (reflectNum (__ :: q)) (reflectNum (__ :: p)))


-- | Sample computation 1000*1000 + 513*513 in any numeric representation of
--   the form @s a@ (such as 'M' or 'Even').
test5  ::  Num (s a) => s a
test5  =   square 1000 + square 513
  where square x = x * x

-- | 'test5' evaluated modulo 1279 (odd) and modulo 1280 = 2^8 * 5 (even).
test5', test5'' :: Integer
test5'   =  withIntegralModulus'' 1279 test5
test5''  =  withIntegralModulus'' 1280 test5

