-----------------------------------------------------------------------------
-- |
-- Module      :  Test.QuickCheck.Parallel
-- Copyright   :  (c) Don Stewart 2006
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  dons@cse.unsw.edu.au
-- Stability   :  experimental
-- Portability :  non-portable (uses Control.Exception, Control.Concurrent)
--
-- A parallel batch driver for running QuickCheck on threaded or SMP systems.
-- See the /Example.hs/ file for a complete overview.
--

module Test.QuickCheck.Parallel (
                                 module Test.QuickCheck,
                                 pRun,
                                 pDet,
                                 pNon ) where

import Control.Concurrent
import qualified Control.Exception as E
import Data.List
import System.IO          (hFlush,stdout)
import Text.Printf

import System.Random
import Test.QuickCheck

-- | Label printed next to a property's result in the 'pRun' output.
type Name   = String
-- | Depth handed unchanged to every property runner by 'pRun'. For runners
-- built with 'pDet' / 'pNon' it becomes 'configMaxTest', i.e. the number of
-- tests to run.
type Depth  = Int
-- | A named property runner: given the depth, it checks the property and
-- returns the text of its report.
type Test   = (Name, Depth -> IO String)

-- | Run a list of QuickCheck properties in parallel chunks, using
-- @n@ Haskell threads (first argument), and test each property to a
-- depth of @depth@ (second argument). Compile your application with
-- @-threaded@ and run it with the SMP runtime's @+RTS -N4@ (or however many
-- OS threads you want to donate), for best results.
--
-- > import Test.QuickCheck.Parallel
-- >
-- > do n <- getArgs >>= readIO . head
-- >    pRun n 1000 [ ("sort1", pDet prop_sort1) ]
--
-- Will run @n@ threads over the property list, to depth 1000.
--
-- Every result is printed as soon as it arrives, prefixed with the number
-- of the worker that produced it and the property name. If a property
-- throws, the exception is printed in place of its report and the worker
-- moves on to the next property, so one bad property cannot hang the run.
--
pRun :: Int -> Int -> [Test] -> IO ()
pRun n depth tests = do
    chan <- newChan
    ps   <- getChanContents chan
    work <- newMVar tests

    forM_ [1..n] $ forkIO . thread work chan

    -- Print results until every one of the n workers has sent its
    -- 'Nothing' completion marker.
    let wait xs i
            | i >= n    = return () -- done
            | otherwise = case xs of
                    Nothing : ys -> wait ys $! i+1
                    Just s  : ys -> putStr s >> hFlush stdout >> wait ys i
                    []           -> return () -- unreachable: channel stream never ends
    wait ps 0

  where
    -- One worker: repeatedly take the next job from the shared queue until
    -- it is empty. The completion marker is sent under 'finally' so that
    -- 'wait' is released even if the worker dies unexpectedly.
    thread :: MVar [Test] -> Chan (Maybe String) -> Int -> IO ()
    thread work chan me = loop `E.finally` writeChan chan Nothing
      where
        loop = do
            job <- modifyMVar work $ \jobs -> return $ case jobs of
                        []     -> ([], Nothing)
                        (j:js) -> (js, Just j)
            case job of
                Nothing          -> return () -- queue drained
                Just (name,prop) -> do
                    v <- report (prop depth)
                    writeChan chan . Just $ printf "%d: %-25s: %s" me name v
                    loop

    -- Run one property and fully force its report inside the worker, so any
    -- exception it raises is caught here (and shown as the result) instead
    -- of escaping later or killing the worker.
    report :: IO String -> IO String
    report action = do
        r <- E.try (action >>= \s -> E.evaluate (foldr seq () s) >> return s)
        return $ case r of
            Right s -> s
            Left e  -> "Exception: " ++ show (e :: E.SomeException) ++ "\n"

-- | Wrap a property so that it is checked against a reproducible stream of
-- test data: the random generator is seeded identically on every call.
-- The 'Int' argument is the number of passing tests to aim for
-- (stored as 'configMaxTest').
pDet :: Testable a => a -> Int -> IO String
pDet a n = mycheck Det cfg a
  where
    cfg = defaultConfig { configMaxTest = n
                        , configEvery   = \_ args -> unlines args }

-- | Wrap a property so that it is checked against fresh random test data on
-- every call: the generator is obtained from 'newStdGen'.
-- The 'Int' argument is the number of passing tests to aim for
-- (stored as 'configMaxTest').
pNon :: Testable a => a -> Int -> IO String
pNon a n = mycheck NonDet cfg a
  where
    cfg = defaultConfig
            { configMaxTest = n
            , configEvery   = \_ args -> unlines args
            }

-- | How 'mycheck' seeds its random generator: 'Det' always uses the same
-- fixed seed, 'NonDet' takes a fresh generator from 'newStdGen'.
data Mode = Det | NonDet

------------------------------------------------------------------------

-- | Check a property under the given configuration and return the textual
-- report. The 'Mode' only decides where the initial generator comes from.
mycheck :: Testable a => Mode -> Config -> a -> IO String
mycheck mode config a = do
    rnd <- seed mode
    mytests config (evaluate a) rnd 0 0 []
  where
    seed Det    = return (mkStdGen 99)  -- fixed seed: same data every run
    seed NonDet = newStdGen             -- fresh generator on every call

-- | Core test loop. Each step splits the generator, generates one case
-- from @gen@ at size @configSize config ntest@ and inspects its outcome:
--
--   * @Just True@: the test passed. @ntest@ is incremented and its
--     classification labels are pushed onto @stamps@.
--   * @Nothing@: the case was discarded (presumably a failed @==>@
--     precondition). @nfail@ is incremented.
--   * @Just False@: the property failed. The offending arguments are
--     reported immediately.
--
-- The loop stops with a summary (see 'done') once @ntest@ reaches
-- @configMaxTest@ or @nfail@ reaches @configMaxFail@.
mytests :: Config -> Gen Result -> StdGen -> Int -> Int -> [[String]] -> IO String
mytests config gen rnd0 ntest nfail stamps
  -- '>=' rather than '==': a limit at or below the current count (e.g. a
  -- negative depth given to 'pDet'/'pNon') must end the run immediately
  -- instead of letting the counter run past it forever.
  | ntest >= configMaxTest config = done "OK," ntest stamps
  | nfail >= configMaxFail config = done "Arguments exhausted after" ntest stamps
  | otherwise =
      case ok result of
        Nothing    ->
          mytests config gen rnd1 ntest (nfail + 1) stamps
        Just True  ->
          mytests config gen rnd1 (ntest + 1) nfail (stamp result : stamps)
        Just False ->
          return ( "Falsifiable after "
                ++ show ntest
                ++ " tests:\n"
                ++ unlines (arguments result)
                 )
  where
    result      = generate (configSize config ntest) rnd2 gen
    (rnd1,rnd2) = split rnd0

-- | Build the summary for a run that ended without a counterexample:
-- @mesg@, the number of passing tests, and a table giving, for each
-- distinct non-empty list of classification labels, the percentage of
-- tests that carried it. The most frequent label list comes first.
done :: String -> Int -> [[String]] -> IO String
done mesg ntest stamps = return (header ++ render rows)
  where
    header = mesg ++ " " ++ show ntest ++ " tests"

    -- (occurrences, labels), sorted in descending order.
    counted :: [(Int, [String])]
    counted = sortBy (flip compare)
                [ (length grp, labels)
                | grp@(labels : _) <- group (sort (filter (not . null) stamps)) ]

    rows = [ pct k ++ " " ++ intercalate ", " labels | (k, labels) <- counted ]

    pct k = show ((100 * k) `div` ntest) ++ "%"

    -- No table, a single inline entry, or one entry per line.
    render []    = ".\n"
    render [row] = " (" ++ row ++ ").\n"
    render many  = ".\n" ++ concatMap (++ ".\n") many

-- | 'mapM_' with the list given first, for loop-style use. Defined locally;
-- the module does not import "Control.Monad".
forM_ :: (Monad m) => [a] -> (a -> m b) -> m ()
forM_ xs action = mapM_ action xs
