{-# OPTIONS -O2 -optc-O -fglasgow-exts -optc-march=pentium4 #-}
{-# LANGUAGE BangPatterns #-}

import Text.Printf
import Control.Exception
import System.CPUTime
import System.IO

import Control.Monad.ST
import System
import Control.Monad
import Data.Bits
import Text.Printf
import Data.Array.Vector.ST

import Data.Array.Base
import GHC.Exts
import GHC.ST

------------------------------------------------------------------------

-- | Measure the CPU time, in seconds, spent running an action.
--
-- The action's result is forced to weak head normal form before the clock
-- is read a second time, so that evaluation is counted in the measurement.
time :: IO t -> IO Double
time act = do
    t0     <- getCPUTime
    result <- act
    t1     <- result `seq` getCPUTime
    -- getCPUTime counts picoseconds, hence the 10^12 divisor.
    return (fromIntegral (t1 - t0) / (10^12))

-- | Run every benchmark in order, bracketed by start/finish banners.
-- A benchmark that fails terminates the program early (see 'run').
main = do
    putStrLn "Starting..."
    forM_ [ ("nsieve-bits", time_nsieve 12) ] run
    putStrLn "Done."

-- | Execute one labelled benchmark check.
--
-- The label is printed (and flushed) first. The action's Bool result then
-- decides the outcome: True prints "Ok.", and False prints a failure
-- message and exits the process with status 1.
run (label, check) = do
    putStr (label ++ ": ")
    hFlush stdout
    passed <- check
    case passed of
        True  -> putStrLn "Ok."
        False -> do
            putStrLn "Fail! New code was slower."
            exitWith (ExitFailure 1)

------------------------------------------------------------------------
-- bitwise prime sieve

time_nsieve n = do
    -- Time both implementations on identical inputs. The check passes
    -- only when program 1 is strictly faster than program 2.
    !x <- (time (nsieve1 n))
    !y <- (time (nsieve2 n))
    return (x < y)

  where

    ------------------------------------------------------------------------
    -- PROGRAM 1

    -- Sieve for the limits 10000*2^n, 10000*2^(n-1) and 10000*2^(n-2).
    nsieve1 n = mapM_ (\i -> sieve1 (10000 `shiftL` (n-i))) [0, 1, 2]

    -- Count the primes below n using the uvector-based mutable array.
    -- NOTE(review): `new n True` presumably allocates n elements (indices
    -- 0..n-1) initialised to True -- confirm against the uvector API.
    -- The count is forced with `seq` so the work happens inside the timed
    -- action.
    sieve1 n = do
       let r = runST (do t <- new n True
                         go t n 2 0)
       n `seq` r `seq` return ()

    -- go a m n c: scan candidates n..m-1, counting (in c) the entries
    -- that are still True, i.e. the primes. For each such n, every
    -- multiple 2n, 3n, ... below m is cleared.
    go !a !m !n !c
        | n == m    = return c
        | otherwise = do
              e <- get a n
              if e then let loop j
                              | j < m    = do
                                  x <- get a j
                                  when x $ set a j False
                                  loop (j+n)
                              | otherwise = go a m (n+1) (c+1)
                        in loop (n `shiftL` 1)
                   else go a m (n+1) c

{-
    {-# INLINE newArrayT #-}
    newArrayT n@(I# n#) t = ST $ \s1# ->
        case newByteArray# (bOOL_SCALE n#) s1# of { (# s2#, marr# #) ->
        case bOOL_WORD_SCALE n#         of { n'# ->
        let loop i# s3# | i# ==# n'# = s3#
                        | otherwise  =
                case writeWordArray# marr# i# e# s3# of { s4# ->
                loop (i# +# 1#) s4# } in
        case loop 0# s2#                of { s3# ->
        (# s3#, STUVector n marr# #) }}}
      where
        W# e# = if t then maxBound else 0 -- True
-}

    ------------------------------------------------------------------------
    -- PROGRAM 2

    -- Same workload as nsieve1, using a bit-packed STUArray instead.
    nsieve2 n = mapM_ (\i -> sieve2 (10000 `shiftL` (n-i))) [0, 1, 2]

    -- Count the primes below n with an unboxed Bool STUArray.
    sieve2 n = do
       -- go2 uses unsafeRead/unsafeWrite, which take raw 0-based offsets
       -- rather than Ix indices, and it touches offsets 2..n-1. Bounds
       -- (0, n-1) give exactly n elements whose offsets equal their
       -- indices, the same size as sieve1's array. (The former bounds
       -- (2, n) allocated only n-1 elements, leaving offset n-1 out of
       -- range.)
       let r = runST (do a <- newArray (0,n-1) True :: ST s (STUArray s Int Bool)
                         go2 a n 2 0)
       n `seq` r `seq` return ()

    -- go2 a m n c: the same scan as go, using unchecked array accesses.
    -- Every offset it uses is in 2..m-1, which sieve2's bounds cover.
    go2 !a !m !n !c
        | n == m    = return c
        | otherwise = do
              e <- unsafeRead a n
              if e then let loop j
                              | j < m     = do
                                  x <- unsafeRead a j
                                  when x $ unsafeWrite a j False
                                  loop (j+n)

                              | otherwise = go2 a m (n+1) (c+1)
                        in loop (n `shiftL` 1)
                   else go2 a m (n+1) c


------------------------------------------------------------------------
