module Main where
import Control.Applicative
import Control.Arrow
import Control.Exception (finally)
import Control.Monad
import Data.Function
import Data.List
import Data.Ord
import System.Directory
import System.Environment
import System.Exit
import System.FilePath
import System.IO
import System.Posix.Process
import System.Posix.Types
import System.Process

-- | Run an effectful function on the second half of a pair and rebuild the
-- pair inside the functor, leaving the first half untouched.
secondM :: (Functor f) => (a -> f b) -> (d, a) -> f (d, b)
secondM act ~(key, val) = fmap (\res -> (key, res)) (act val)

-- | Like 'zipWith', but instead of truncating to the shorter list, the
-- unmatched tail of the longer list is appended unchanged.
zipWith' :: (a -> a -> a) -> [a] -> [a] -> [a]
zipWith' combine = go
  where
    go (l:ls) (r:rs) = combine l r : go ls rs
    go []     rest   = rest
    go rest   []     = rest

-- | How many times 'testFile' runs each successfully compiled candidate.
-- 'testFile' throws away the first two runs (presumably as warm-up) and sums
-- the remaining ones, so this should be at least 3 for any run to count.
numTests :: Int
numTests = 5

-- | Split a string at the first occurrence of a separator.
--
-- Returns the text before the separator and the text after it, with the
-- separator itself removed. When the separator never occurs, the whole input
-- is returned as the first component and the second is empty.
splitOn :: String -> String -> (String,String)
splitOn sep = go
  where
    go [] = ([], [])
    go str@(c : rest) =
      case stripPrefix sep str of
        Just after -> ([], after)
        Nothing    -> let (before, remainder) = go rest
                      in (c : before, remainder)

-- | Cut the source text into the segments that lie between @{- ! -}@ markers.
-- The markers themselves are dropped.
--
-- Normally @n@ markers give @n + 1@ segments. The recursion stops as soon as
-- the remaining text is empty, however. As a result, empty input yields @[]@,
-- and a marker at the very end of the input produces no trailing empty
-- segment.
readSrc :: String -> [String]
readSrc [] = []
readSrc s = chunk : readSrc rest
  where
    (chunk, rest) = splitOn "{- ! -}" s

-- | Reassemble source text from its segments, inserting bangs where requested.
--
-- Flag @i@ decides whether a @!@ is emitted right after segment @i@, i.e. in
-- place of the i-th marker. Segments past the end of the flag list get no
-- bang, and flags past the last segment are ignored.
writeSrc :: [String] -> [Bool] -> String
writeSrc segments flags = concat (zipWith (++) segments separators)
  where
    separators = map (\b -> if b then "!" else "") flags ++ repeat ""

-- | Build the round-one candidates for a segmented source.
--
-- The first entry is the baseline with no bangs (flags @[False]@). It is
-- followed by one entry per marker, where entry @i@ enables only the bang at
-- marker @i@. Each flag list is paired with the source text that 'writeSrc'
-- renders from it.
makeInits :: [String] -> [([Bool], String)]
makeInits src = [ (flags, writeSrc src flags) | flags <- baseline : singles ]
  where
    baseline = [False]
    singles  = [ replicate i False ++ [True] | i <- [0 .. length src - 2] ]

-- | Compile one candidate source and measure how long it runs.
--
-- @testFile tmpPath args s@ writes @s@ to @tmpPath/tmp.hs@ and builds it with
-- @ghc --make -O2@, discarding the compiler's output. It then runs the
-- resulting executable 'numTests' times. Each run receives @args@ on its
-- command line, reads @tmpPath/input@ on stdin, and writes stdout to
-- @/dev/null@.
--
-- The first two runs are discarded and the child user CPU time of the rest
-- is summed. If compilation fails, the result is the failure penalty. A run
-- that does not exit successfully contributes the penalty for that run. Either
-- way, a broken candidate never beats a working one.
testFile :: FilePath -> [String] -> String -> IO ClockTick
testFile tmpPath args s = do
  writeFile tmpFile s
  hPutStrLn stderr ""
  dn <- Just <$> openFile "/dev/null" WriteMode
  -- runProcess takes ownership of (and closes) the handles passed to it.
  ecd <- waitForProcess =<< runProcess "ghc" ghcFlags Nothing Nothing Nothing dn dn
  case ecd of
    -- drop (not tail . tail) so the discard step is total.
    ExitSuccess -> sum . drop 2 <$> replicateM numTests timeCmd
    _           -> return failurePenalty
  where
    tmpFile = tmpPath </> "tmp.hs"

    -- -XBangPatterns is the replacement for the deprecated -fbang-patterns.
    -- -fforce-recomp makes sure each candidate is really rebuilt, even though
    -- tmp.hs is rewritten in quick succession under the same name.
    ghcFlags = ["--make", "-O2", "-fforce-recomp", "-XBangPatterns", tmpFile]

    -- Time charged for a candidate that fails to build or to run.
    failurePenalty :: ClockTick
    failurePenalty = 100000000

    -- Single-quote a word for /bin/sh, so that paths and arguments containing
    -- spaces or shell metacharacters reach the program verbatim.
    shellQuote w = "'" ++ concatMap escape w ++ "'"
      where
        escape '\'' = "'\\''"
        escape c    = [c]

    runCmd = unwords (map shellQuote (tmpPath </> "tmp" : args))
          ++ " < " ++ shellQuote (tmpPath </> "input") ++ " > /dev/null"

    -- One timed run: the child user time accumulated while waiting for it.
    timeCmd = do
      t <- childUserTime <$> getProcessTimes
      pid <- forkProcess $ executeFile "/bin/sh" False ["-c", runCmd] Nothing
      hPutStr stderr "."
      ps <- getProcessStatus True True pid
      case ps of
        Just (Exited ExitSuccess) -> (\z -> z - t) . childUserTime <$> getProcessTimes
        _ -> hPutStr stderr " Test failed." >> return failurePenalty

-- | One round-two step: try merging two candidate bang sets.
--
-- @best@ is the combination chosen so far and @next@ is another candidate.
-- Their flag lists are OR-ed element-wise, and the merged source is compiled
-- and timed with 'testFile'. The fastest of the merged set, @best@ and
-- @next@ is returned.
addFiles :: FilePath -> [String] -> [String] -> ([Bool],ClockTick) -> ([Bool],ClockTick) -> IO ([Bool],ClockTick)
addFiles tmpPath args src best@(flagsA, _) next@(flagsB, _) =
  let merged = zipWith' (||) flagsA flagsB
  in do
    elapsed <- testFile tmpPath args (writeSrc src merged)
    return (minimumBy (comparing snd) [(merged, elapsed), best, next])

-- | Help text shown when strictify is started without a source file.
usage :: String
usage = unlines
  [ ""
  , "usage: strictify _file_ [Arguments]"
  , ""
  , "Example: strictify foo.hs 400 < infile > fooStrict.hs"
  , ""
  , "The above command determines a locally optimal combination of"
  , "strictness annotations (hinted to strictify by {- ! -} comments)"
  , "such that the executable generated by GHC --make -O2 foo.hs runs"
  , "in minimal time when presented with 400 as an argument and input"
  , "as derived from infile. The command as presented above then pipes"
  , "the result to fooStrict.hs"
  ]

-- | Entry point: @strictify file [args...] < input > output@.
--
-- The source is split at its @{- ! -}@ markers and searched in two rounds:
--
--   * Round one compiles and times the unannotated baseline, plus one variant
--     per marker with only that single bang enabled.
--   * Every single-bang variant faster than the baseline enters round two.
--     There 'addFiles' folds them together, keeping a merge only when it is
--     the fastest option.
--
-- The winning source is printed on stdout and progress goes to stderr. stdin
-- is saved to a working file, so every timed run sees the same input.
main :: IO ()
main = do
  argums <- getArgs
  -- Validate the command line before touching the file system. A bad
  -- invocation then leaves nothing behind and exits with status 1.
  (infile, args) <- case argums of
    (a:bs) -> return (a, bs)
    _      -> hPutStr stderr usage >> exitFailure
  tmpPath <- (</> "strictify-working") <$> getTemporaryDirectory
  createDirectoryIfMissing False tmpPath
  -- Remove the working directory even if a step below throws (a missing
  -- input file, an I/O error, an interrupt), not only on success.
  flip finally (removeDirectoryRecursive tmpPath) $ do
    src <- readSrc <$> readFile infile
    writeFile (tmpPath </> "input") =<< getContents
    hSetBuffering stderr NoBuffering
    hPutStr stderr $ show (length src - 1) ++ " candidate bangs."
    -- makeInits always puts the baseline first, so this match cannot fail.
    (b:bs) <- mapM (secondM (testFile tmpPath args)) . makeInits $ src
    let filtered = filter ((< snd b) . snd) bs
    hPutStrLn stderr $ "\n" ++ show (length filtered) ++ " bang[s] in round two."
    case filtered of
      (f:fs) -> do
        newSrc <- fst <$> foldM (addFiles tmpPath args src) f fs
        hPutStrLn stderr $ show (length . filter id $ newSrc) ++ " bangs in the winner."
        putStrLn $ writeSrc src newSrc
      _ -> putStrLn . writeSrc src . fst $ b
    hSetBuffering stderr (BlockBuffering Nothing)
  hPutStrLn stderr $ "\n...finished."