""" reports modules
"""

import os
import platform
import re
import time
from sys import maxint

from geniconfig import Statistics
from MiscFunctions import which, systemOrDie, systemOrWarn, natsort
from octavePlots import *
from gtesterrc import *

# Reusable LaTeX fragments written by the report generators; each one is a
# complete line (it ends with a newline).
texHline = "\\hline\n"
texVspaceHalf = "\\vspace{0.5cm}\n"
texBeginCenter = "\\begin{center}\n"
texEndCenter = "\\end{center}\n"

def hasKeyPair(dict, t, k):
  """Return True when dict[t] exists and itself contains the key k.

  dict -- a two-level mapping (mapping of mappings); the parameter keeps its
          historical name for keyword callers even though it shadows the
          builtin.
  t    -- outer key (e.g. a test case name)
  k    -- inner key (e.g. a statistic name)
  """
  # `in` replaces the deprecated dict.has_key (gone in Python 3); the second
  # test only runs when the first succeeded, so dict[t] cannot raise here.
  return t in dict and k in dict[t]

def keyPairValOrDefault(dict, t, k, default="?"):
  """Look up dict[t][k], falling back to `default` ("?" unless given).

  The presence of both keys is checked with hasKeyPair, so a missing outer
  or inner key yields the default instead of raising.
  """
  if not hasKeyPair(dict, t, k):
    return default
  return dict[t][k]

def divOrNa(x, y, places=1):
    """Format float(x) / float(y) with `places` decimals.

    Returns the string "n/a" instead of dividing when y is zero.  Both
    arguments may be anything float() accepts (numbers or numeric strings).
    """
    denominator = float(y)
    if denominator != 0:
        pattern = "%%.%sf" % places   # e.g. places=2 -> "%.2f"
        return pattern % (float(x) / denominator)
    return "n/a"

def percentageStr(x, y):
    """Return 100 * (x - y) / x formatted with two decimals.

    This is the percentage by which y is smaller than x (negative when y is
    larger).  Returns "n/a" when x is zero.
    """
    base = float(x)
    if base == 0:
        return "n/a"
    reduction = 100 * (base - float(y)) / base
    return "%.2f" % reduction

class Report:
    def __init__(self, params):
        """Remember the test parameters and locate the external tools.

        params -- test configuration object; the methods below read testId,
                  suiteName, variants, baseline, batches and globalStats
                  from it.
        The octave/latex/dvipdf locations come from which(); canCompile()
        checks the first two.  The timestamp is taken once here so every
        section of the report shows the same time.
        """
        self.params = params
        self.__latexSections = []   # (section name, useInclude) pairs
        self.__octaveScripts = []
        (self.__octaveBin, self.__latexBin, self.__dvipdfBin) = \
            [which(tool) for tool in ('octave', 'latex', 'dvipdf')]
        self.__timestamp = time.strftime("%Y-%m-%d %H:%M")

    def generateFullReport(self):
        """Write every report section, then the master LaTeX document.

        The sections are generated first because generateLatexDocument pulls
        them in, in the order they were registered.
        """
        sectionWriters = [self.generateHeaderFooters,
                          self.generateTestDescription,
                          self.generateResponsesSummary,
                          self.generateResponsesAnalysis]
        for writeSection in sectionWriters:
            writeSection()
        self.generateLatexDocument()

    def canCompile(self):
        return self.__octaveBin and self.__latexBin

    def compile(self, latexPasses=3):
        """Run the generated Octave scripts, then build '<testId>_results.pdf'.

        Every registered Octave script is run through systemOrWarn (which,
        judging by its name, only warns on failure).  LaTeX is then run
        `latexPasses` times on '<testId>_results' through systemOrDie
        (presumably aborting on failure), so multi-pass constructs such as
        longtable column widths settle.  Finally dvipdf converts the last
        DVI once.

        latexPasses -- number of LaTeX runs (default 3, as before).
        """
        for octaveScript in self.__octaveScripts:
            systemOrWarn(self.__octaveBin, [octaveScript], stdout = None)

        docPrefix = '%s_results' % self.params.testId
        for _ in range(latexPasses):
            systemOrDie(self.__latexBin, ['-interaction=nonstopmode', docPrefix], stdout = None)
        # Only the DVI from the final pass matters; converting the
        # intermediate ones (as the old loop did) was wasted work.
        systemOrWarn(self.__dvipdfBin, [docPrefix], stdout = None)

    def generateLatexDocument(self):
        r"""Write '<testId>_results.tex', the master LaTeX document.

        The preamble loads the packages the section files use; the title is
        built from the escaped test id and suite name.  Each entry of
        self.__latexSections (a (name, useInclude) pair, presumably
        registered via __addSection) is pulled in with \include when
        useInclude is true and with \input otherwise.  The file is closed
        even if a write fails.
        """
        testId = self.params.testId
        suiteName = self.params.suiteName
        f = open('%s_results.tex' % testId, 'w')
        try:
            f.write(r'\documentclass{article}' '\n\n')
            f.write(r'\usepackage{a4wide}' '\n')
            f.write(r'\usepackage{graphicx}' '\n')
            f.write(r'\usepackage{color}' '\n')
            f.write(r'\usepackage{fancyhdr}' '\n')
            f.write(r'\usepackage{longtable}' '\n')
            f.write(r'\usepackage[latin1]{inputenc}' '\n\n')
            # the doubled backslash is the LaTeX line break inside \title
            f.write(r'\title{%s %s\\Test results}' '\n\n' % (self.__latexEsc(testId), self.__latexEsc(suiteName)))
            f.write(r'\begin{document}' '\n')
            f.write(r'\maketitle' '\n')
            for (section, useInclude) in self.__latexSections:
                if useInclude:
                    f.write(r'\include{%s}' '\n' % section)
                else:
                    f.write(r'\input{%s}' '\n' % section)
            f.write(r'\end{document}' '\n')
        finally:
            f.close()

    def generateHeaderFooters(self):
        r"""Write '<testId>_headers.tex' and register it as an \input section.

        The file selects the fancyplain page style, centres the page number
        in the footer, and puts the test id and the report timestamp in the
        right-hand footer.  Any error is reported and re-raised; the file is
        closed either way.
        """
        try:
            testId = self.params.testId
            name = '%s_headers' % testId
            self.__addSection(name, False)
            f = open('%s.tex' % name, 'w')
            try:
                f.write(r'\pagestyle{fancyplain}' '\n')
                f.write(r'\cfoot{\thepage}' '\n')
                # escape the id exactly like the document title does;
                # presumably __latexEsc neutralises specials such as '_'
                f.write(r'\rfoot{%s %s}' '\n' % (self.__latexEsc(testId), self.__timestamp))
            finally:
                f.close()
        except:
            print('Error generating headers and footers')
            raise

    def generateTestDescription(self):
        """Write '<testId>_description.tex', the "Test Description" section.

        Lists the host platform (from platform.uname()), the parameters
        returned by params.aboutThisTest(), each variant's configuration, and
        the suite name with the number of test cases in the first batch.
        Free text goes through __latexEsc.  Any error is reported and
        re-raised; the file is closed either way.
        """
        try:
            name = '%s_description' % self.params.testId
            self.__addSection(name, False)
            f = open('%s.tex' % name, 'w')
            try:
                f.write(r'\section{Test Description}' '\n\n')
                f.write(r'\subsection{Platform}' '\n\n')
                # this group keeps the whole section in \small; closed at the end
                f.write(r'{\small' '\n')
                f.write(r'\begin{description}' '\n\n')
                (system, node, release, version, machine, processor) = platform.uname()
                f.write(r'  \item[Host name] %s' '\n' % self.__latexEsc(node))
                f.write(r'  \item[Architecture] %s' '\n' % self.__latexEsc(machine))
                f.write(r'  \item[Processor] %s' '\n' % self.__latexEsc(processor))
                f.write(r'  \item[Operating system] %s %s (%s)' '\n' % (self.__latexEsc(system), self.__latexEsc(release), self.__latexEsc(version)))
                f.write(r'\end{description}' '\n\n')

                f.write(r'\subsection{Parameters}' '\n\n')
                f.write(r'\begin{description}' '\n')
                for (k, v) in self.params.aboutThisTest():
                    f.write(r'  \item[%s] %s' '\n' % (self.__latexEsc(k), self.__latexEsc(v)))
                f.write(r'\end{description}' '\n\n')

                f.write(r'\subsection{Variants}' '\n\n')
                for testable in self.params.variants:
                    f.write(r'\noindent\textsc{%s} : ' '\n' % self.__latexEsc(testable.id()))
                    for (key, value) in testable.getConfiguration():
                        f.write(r'\textbf{%s} %s' '\n' % (self.__latexEsc(key), self.__latexEsc(value)))
                    f.write('\n')

                f.write(r'\subsection{Suite}' '\n\n')
                f.write(r'\begin{description}' '\n')
                f.write(r'  \item[Name] %s' '\n' % self.__latexEsc(self.params.suiteName))
                batch1 = self.params.batches[0]
                tests = self.params.globalStats[batch1].tests
                f.write(r'  \item[Cases] %d' '\n' % len(tests))
                f.write(r'\end{description}' '\n\n')
                f.write(r'}' '\n')
            finally:
                f.close()
        except:
            print('Error generating test description')
            raise

    def generateResponsesSummary(self):
        """Write '<testId>_resp_summary.tex', the summary of the responses.

        For every batch, three longtables are written:
          * the pass/fail status and response counts of each variant, for
            each expected result of each test case;
          * totals and per-test-case averages of the statistics listed in
            SUMMARY_SECTIONS, for the baseline (if any) and every variant;
          * the lowest and highest per-test-case values of those statistics.
        As a side effect, each realiser's per-test statsList is annotated
        (test name, timings, pass/fail counts, baseline cross-references).
        Any error is reported and re-raised; the file is closed either way.
        """
        try:
            name = '%s_resp_summary' % self.params.testId
            self.__addSection(name, False)
            f = open('%s.tex' % name, 'w')
            try:
                f.write(r"\section{Summary of the generators' responses}" '\n\n')

                variants = self.params.variants
                baseline = self.params.baseline
                allVariants = variants
                if baseline: allVariants = [baseline] + allVariants

                for batch in self.params.batches:
                    globalStats = self.params.globalStats[batch]
                    # The tables describe the suite's test cases.  Each
                    # realiser's own list is kept separately in the stats
                    # helper; previously it overwrote this one.
                    tests = natsort(globalStats.tests)
                    (totals, lows, highs) = self._summaryCollectStats(batch, variants, baseline, allVariants)
                    self._summaryWriteExpected(f, batch, variants, globalStats, tests)
                    self._summaryWriteTotals(f, allVariants, totals, len(tests))
                    self._summaryWriteHighsLows(f, allVariants, lows, highs)
            finally:
                f.close()
        except:
            print('Error generating summary of responses!')
            raise

    def _summaryCollectStats(self, batch, variants, baseline, allVariants):
        """Annotate every realiser's per-test stats for `batch` and aggregate them.

        Returns (totals, lows, highs).  Each maps a realiser to a dict from
        statistic name to, respectively, the sum over its test cases and the
        lowest / highest per-test value.  Lows and highs keep the sentinels
        maxint / -1 for statistics that no test case reported.
        """
        if baseline:
            baselineStatsList = baseline.statsForBatch(batch).statsList
        if len(variants) > 0:
            tkeys = variants[0].statsForBatch(batch).timeMetrics + ["cpu_time"]
        else:
            tkeys = []

        totals = {}
        lows = {}
        highs = {}
        for surfRealiser in allVariants:
            totalsSf = totals[surfRealiser] = {}
            lowsSf = lows[surfRealiser] = {}
            highsSf = highs[surfRealiser] = {}

            stats = surfRealiser.statsForBatch(batch)
            realiserTests = natsort(stats.tests)

            for t in realiserTests:
                if baseline:
                    # make an estimation of difficulty
                    fval = keyPairValOrDefault(baselineStatsList, t, "lex_foot_nodes", 1)
                    sval = keyPairValOrDefault(baselineStatsList, t, "lex_subst_nodes", 1)
                    baselineStatsList[t][hardness] = int(fval) * int(sval)

                # dump the test case name into the stats to make processing more uniform
                stats.statsList[t]["test_case"] = t

                # add a cpu time stat (user + sys).  A test that died can
                # have empty timing results, which used to raise KeyError.
                if hasKeyPair(stats.timingResults, t, "user") and hasKeyPair(stats.timingResults, t, "sys"):
                    stats.timingResults[t]["cpu_time"] = stats.timingResults[t]["user"] + stats.timingResults[t]["sys"]
                # dump all timing results into the stats to make processing more uniform
                for k in tkeys:
                    if hasKeyPair(stats.timingResults, t, k):
                        stats.statsList[t][k] = stats.timingResults[t][k]
                if t in stats.overgenList:
                    stats.statsList[t]["passes"] = len(stats.passList[t])
                    stats.statsList[t]["fails"] = len(stats.failList[t])
                    stats.statsList[t][ovgenKey] = len(stats.overgenList[t])
                    stats.statsList[t]["responses"] = len(stats.responseList[t])

                # copy some baseline data into the realiser's table to spare
                # the reader some cross-referencing
                if baseline and surfRealiser != baseline:
                    for k in CROSSREF_KEYS:
                        stats.statsList[t][k] = keyPairValOrDefault(baselineStatsList, t, k)

            # sum up the totals for each one of the known statistical keys
            for k in pfoKeys + stats.metrics:
                totalsSf[k] = 0
                lowsSf[k] = maxint
                highsSf[k] = -1
                for t in realiserTests:
                    if hasKeyPair(stats.statsList, t, k):
                        n = int(stats.statsList[t][k])
                        totalsSf[k] += n
                        lowsSf[k] = min(lowsSf[k], n)
                        highsSf[k] = max(highsSf[k], n)

            # sum up total number of passes and fails
            totalFails = totalsSf['fails']
            totalPasses = totalsSf['passes']
            if totalPasses + totalFails != 0:
                totalsSf['percent passed'] = 100 * totalPasses / (totalPasses + totalFails)
            else:
                totalsSf['percent passed'] = "???"

            # sum up the times (can be misleading if lots of stuff dies)
            for k in tkeys:
                totalsSf[k] = 0
                lowsSf[k] = maxint
                highsSf[k] = -1
            totalAlive = 0
            for t in realiserTests:
                if t in stats.timingResults:
                    for k in tkeys:
                        if k in stats.timingResults[t]:
                            n = stats.timingResults[t][k]
                            totalsSf[k] += n
                            lowsSf[k] = min(lowsSf[k], n)
                            highsSf[k] = max(highsSf[k], n)
                    if stats.timingResults[t] != {}: totalAlive += 1
            totalsSf['not dead'] = totalAlive
            totalsSf['responses'] = totalPasses + totalsSf[ovgenKey]
            lowsSf['responses'] = lowsSf['passes'] + lowsSf[ovgenKey]
            highsSf['responses'] = highsSf['passes'] + highsSf[ovgenKey]

        return (totals, lows, highs)

    def _summaryWriteExpected(self, f, batch, variants, globalStats, tests):
        """Write the "Expected results" longtable: one row per expected result of each test case."""
        f.write(r"\subsection{Expected results}" '\n\n')

        # a pass/fail, responses and unique-responses column for each variant
        f.write(r"\begin{longtable}{|r|p{8cm}||" + r"lrr|" * len(variants) + r"}" '\n')
        f.write(r"\hline" '\n')
        f.write(r"\textbf{test}")
        f.write(r"& \textbf{expected}")
        for v in variants:
            f.write(r"& \multicolumn{3}{|c|}{\textbf{%s}}" % self.__latexEsc(v.id()))
        f.write(r" \\" '\n')
        f.write(r"\hline" '\n')
        f.write(r"\endhead" '\n')

        def colorbox(x, color='red'):
            return (r'\colorbox{%s}{%s}' % (color, x))

        def passOrFail(surfRealiser, t, e):
            # e is passed explicitly; it used to be read from the caller's
            # loop variable through the closure
            stats = surfRealiser.statsForBatch(batch)
            if stats.died[t]:
                pf = colorbox('DIED')
            elif e in stats.failList[t]:
                pf = colorbox('fail')
            else:
                pf = 'pass'
            resp = stats.responseList[t]
            uniqueResp = set(resp)
            if len(resp) > 0:
                formatLen = lambda x: str(len(x))
            else:
                # highlight both counts when there were no responses at all
                formatLen = lambda x: colorbox(str(len(x)))
            return [pf, formatLen(resp), formatLen(uniqueResp)]

        for t in tests:
            for e in globalStats.expectedList[t]:
                row = [t, e]
                for v in variants:
                    row = row + passOrFail(v, t, e)
                self.writeRow(f, row)
        f.write(r"\end{longtable}" '\n\n')

    def _summaryWriteTotals(self, f, allVariants, totals, testCount):
        """Write the "Totals and averages" longtable; averages are per test case (testCount)."""
        f.write(r"\subsection{Totals and averages}" '\n\n')

        # a total and an average column for each realiser
        f.write(r"\begin{longtable}{|l|" + r"rl|" * len(allVariants) + r"}" '\n')
        f.write(r"\hline" '\n')
        f.write(r"\textbf{total}")
        for v in allVariants:
            f.write(r"& \multicolumn{2}{|c|}{\textbf{%s}}" % self.__latexEsc(v.id()))
        f.write(r" \\" '\n')
        f.write(r"\hline" '\n')
        f.write(r"\endhead" '\n')

        for section in SUMMARY_SECTIONS:
            f.write(r"\multicolumn{%s}{|c|}{\textbf{%s}} " % (1 + (2 * len(allVariants)), section[0]))
            f.write(r"\\" '\n')
            f.write(r"\hline" '\n')
            for k in section[1]:
                f.write(r"%s" % self.__statsToHeader(k))
                for surfRealiser in allVariants:
                    if k in totals[surfRealiser]:
                        total = totals[surfRealiser][k]
                        f.write(r" & %s" % total)
                        try:
                            average = "%.1f" % (float(total) / testCount)
                        except (TypeError, ValueError, ZeroDivisionError):
                            # no test cases, or a placeholder total such as
                            # the "???" percent passed (used to crash float())
                            average = "???"
                        f.write(r" & av %s" % average)
                    else:
                        f.write(r" & \multicolumn{2}{|c|}{n/a}")
                f.write(r"\\" '\n')
                f.write(r"\hline" '\n')
        f.write(r"\end{longtable}" '\n\n')

    def _summaryWriteHighsLows(self, f, allVariants, lows, highs):
        """Write the "Highs and Lows" longtable with a "low to high" range per realiser."""
        f.write(r"\subsection{Highs and Lows}" '\n\n')

        # one range column for each realiser
        f.write(r"\begin{longtable}{|l|" + r"l|" * len(allVariants) + r"}" '\n')
        f.write(r"\hline" '\n')
        f.write(r"\textbf{lows and highs}")
        for v in allVariants:
            f.write(r"& \textbf{%s}" % self.__latexEsc(v.id()))
        f.write(r" \\" '\n')
        f.write(r"\hline" '\n')
        f.write(r"\endhead" '\n')

        for section in SUMMARY_SECTIONS:
            f.write(r"\multicolumn{%s}{|c|}{\textbf{%s}} " % (1 + len(allVariants), section[0]))
            f.write(r"\\" '\n')
            f.write(r"\hline" '\n')
            for k in section[1]:
                f.write(r"%s" % self.__statsToHeader(k))
                for surfRealiser in allVariants:
                    if k in lows[surfRealiser]:
                        low = lows[surfRealiser][k]
                        high = highs[surfRealiser][k]
                        if high < low:
                            # sentinels untouched (maxint / -1): no test case
                            # reported this statistic
                            f.write(r" & n/a")
                        else:
                            f.write(r" & %s to %s" % (low, high))
                    else:
                        f.write(r" & n/a")
                f.write(r"\\" '\n')
                f.write(r"\hline" '\n')
        f.write(r"\end{longtable}" '\n\n')

    def generateResponsesAnalysis(self):
        try:
            name = '%s_resp_analysis' % self.params.testId
            self.__addSection(name, False)
            f = open('%s.tex' % name, 'w')
            f.write(r"\section{Analysis}" '\n\n')

            variants = self.params.variants
            baseline = self.params.baseline
            allVariants = variants
            if baseline: allVariants = [self.params.baseline] + allVariants

            for batch in self.params.batches:
                if baseline:
                    baselineStatsList = baseline.statsForBatch(batch).statsList
                globalStats = self.params.globalStats[batch]
                tests   = natsort(globalStats.tests)


                # -------------------------------------------------------------------------
                # preparing for analysis
                # -------------------------------------------------------------------------
                # translate the group boundaries into slices
                slices = {}
                if baseline:
                  for xKey in GROUPS_X.keys():
                      prevG = 1
                      slices[xKey] = []
                      for g in GROUPS_X[xKey]:
                          # gather data for the group
                          relevantTests = []
                          for t in tests:
                              x = int(keyPairValOrDefault(baselineStatsList, t, xKey, 0))
                              if x >= prevG and x <= g: relevantTests.append(t)
                          slices[xKey].append((prevG, g, relevantTests))
                          prevG = g+1

                def gCount(low, high, list):
                    return "%d-%d (%d cases)" % (low, high, len(list))

                # -------------------------------------------------------------------------
                # baseline analysis
                # -------------------------------------------------------------------------

                def baselineAnalysis(description, xKey, bKey, aKey):
                  ptotals = [ 0, 0, 0 ]
                  pcells = []
                  for (gLow, gHigh, relevantTests) in slices[xKey]:
                      subtotals = [ 0, 0, 0 ]
                      for t in relevantTests:
                          keyval = keyPairValOrDefault(baselineStatsList, t, xKey, 0)
                          bValue = keyPairValOrDefault(baselineStatsList, t, bKey, 0)
                          aValue = keyPairValOrDefault(baselineStatsList, t, aKey, 0)
                          subtotals[0] += int(keyval)
                          subtotals[1] += int(bValue)
                          subtotals[2] += int(aValue)
                      for n in range(0, len(subtotals)): ptotals[n] += subtotals[n]
                      caseCount = len(relevantTests)
                      before = subtotals[1]
                      after  = subtotals[2]
                      beforeAvg = divOrNa(before, caseCount, 2)
                      afterAvg  = divOrNa(after, caseCount, 2)
                      diff      = before - after
                      diffAvg   = divOrNa(diff, caseCount, 2)
                      reduction = divOrNa(before, after)
                      pcells.append([ gCount(gLow, gHigh, relevantTests)
                                    , beforeAvg
                                    , afterAvg
                                    , diffAvg
                                    , reduction ])
                  xKeyLong = self.__statsToHeader(xKey, long=True)
                  f.write('\n')
                  self.writeTableHeader(f, [ self.__statsToHeader(xKey)
                                           , self.__statsToHeader(bKey, long=True)
                                           , self.__statsToHeader(aKey, long=True)
                                           , r"-"
                                           , r"$\times$" ])
                  for c in pcells: self.writeRow(f, c)
                  #f.write(texHline)
                  #self.writeRow(f, ptotals + [ptotals[1] - ptotals[2], divOrNa(ptotals[1], ptotals[2])])
                  f.write('\end{tabular}' '\n')

                # -------------------------------------------------------------------------
                # analysis and comparisons
                # -------------------------------------------------------------------------
                bestVariant   = variants[0]
                otherVariants = []

                foundBest = False
                for b in BEST_VAR_NAMES:
                    if foundBest: continue
                    for v in variants:
                        if v.id() == b :
                            bestVariant = v
                            foundBest = True
                for v in variants:
                  if v.id() != bestVariant.id() : otherVariants.append(v)
                comparedVariants = [bestVariant] + otherVariants

                _sliceInfo_ = {}

                def bestVariantAnalysis(description, xKey, yKeyList):
                   cells = {}
                   totals = {}
                   totals['key'] = 0
                   #for v in comparedVariants: totals.comparedVariants[v] = 0
                   for t in tests:
                       key   = keyPairValOrDefault(baselineStatsList, t, xKey, 0)
                       totals['key'] += int(key)
                                   # build up table data for all Ys
                   for yKey in yKeyList:
                       cells[yKey] = {}
                       _sliceInfo_[(xKey,yKey)] = {}
                       totals[yKey] = {}
                       subtotal = {}
                       for v in comparedVariants:
                           totals[yKey][v] = 0
                           _sliceInfo_[(xKey,yKey)][v] = {}
                       for (_, gHigh, relevantTests) in slices[xKey]:
                           # initialise data gathering
                           for v in comparedVariants: subtotal[v] = 0
                           usableTestCount = 0
                           for t in relevantTests:
                               usableTest = True
                               for v in comparedVariants:
                                   if not v.statsForBatch(batch).statsList[t].has_key(yKey):
                                       print "discarding %s" % t
                                       usableTest = False
                                       break
                               if not usableTest: continue # discard this if any of them fail
                               usableTestCount += 1
                               for v in comparedVariants:
                                   vStats = v.statsForBatch(batch).statsList
                                   yValue = float(vStats[t][yKey])
                                   subtotal[v] += yValue
                                   totals[yKey][v] += yValue
                           # calculate averages for the current slice
                           caseCount = usableTestCount # len(relevantTests)
                           if caseCount == 0: caseCount = 1
                           best = subtotal[bestVariant]
                           bestAvg = divOrNa(best, caseCount, 2)
                           _sliceInfo_[(xKey,yKey)][bestVariant][gHigh] = bestAvg
                           cells[yKey][gHigh] = []
                           for v in otherVariants:
                               other = subtotal[v]
                               otherAvg = divOrNa(other, caseCount, 2)
                               _sliceInfo_[(xKey,yKey)][v][gHigh] = otherAvg
                               cells[yKey][gHigh].append(otherAvg)
                           cells[yKey][gHigh].append(bestAvg)
                           # print out the reductions
                           for v in otherVariants:
                               diff = subtotal[v] - best
                               diffAvg = divOrNa(diff, caseCount, 2)
                               cells[yKey][gHigh].append(diffAvg)
                               cells[yKey][gHigh].append(divOrNa(subtotal[v], best))

                   # format everything into cells
                   tableHeaderRow = [ self.__statsToHeader(xKey, long=True) ] # will be added to
                   for yKey in yKeyList:
                     for v in otherVariants:
                         tableHeaderRow.append(v.id())
                     tableHeaderRow.append(bestVariant.id())
                     for v in otherVariants:
                         tableHeaderRow.append(r"$-$")
                         tableHeaderRow.append(r"$\times$")
                   tableBodyRows   = []
                   for (gLow, gHigh, relevantTests) in slices[xKey]:
                       tmpRow = [ gCount(gLow, gHigh, relevantTests) ]
                       for yKey in yKeyList: tmpRow += cells[yKey][gHigh]
                       tableBodyRows.append(tmpRow)
                   tableTotalsRow = [ str(totals['key']) ]
                   for yKey in yKeyList:
                       for v in otherVariants:
                           tableTotalsRow.append(totals[yKey][v])
                       best = totals[yKey][bestVariant]
                       tableTotalsRow.append(best)
                       for v in otherVariants:
                           other = totals[yKey][v]
                           times = divOrNa(other, best)
                           tableTotalsRow.append(other-best)
                           tableTotalsRow.append(times)

                   # -- write everything out
                   xKeyLong = self.__statsToHeader(xKey, long=True)
                   # table header
                   f.write('\n')
                   f.write(r"\begin{tabular}{|r|")
                   for y in yKeyList:
                       f.write(r"|p{3cm}|p{3cm}||")
                       for v in otherVariants:
                           f.write("r|r|")
                   f.write(r"}" '\n')
                   f.write(texHline)
                   for yKey in yKeyList:
                       yKeyLong = self.__statsToHeader(yKey, long=True)
                       f.write(r"& \multicolumn{%d}{|c|}{\textbf{%s}}"
                               % (len(comparedVariants), yKeyLong))
                       for v in otherVariants:
                           f.write(r"& \multicolumn{2}{c|}{red. for %s}"  % v.id())
                   f.write(r"\\" '\n')
                   f.write(texHline)
                   self.writeRow(f, tableHeaderRow, header=True)
                   f.write(texHline)
                   # table body
                   for r in tableBodyRows: self.writeRow(f, r)
                   # table end
                   #f.write(texHline)
                   #self.writeRow(f, tableTotalsRow)
                   f.write('\end{tabular}' '\n')

                # -------------------------------------------------------------------------
                # grouped plots / graphs
                # -------------------------------------------------------------------------

                def groupedPlot(description, xKey, yKeyList):
                  for yKey in yKeyList:
                    series = []
                    # xValues: add a point for every test case that is an instance of the x value
                    # tCases : get a sorted list of test cases (grouped by category)
                    xValues = []
                    tCases  = []
                    for (_,h,_) in slices[xKey]:
                        xValues.append(h)
                    # note: here we want the baseline AFTER, so that the colours
                    # of our graph don't change wherever the baseline appears
                    variantsAndBaseline = variants + [baseline]
                    # FIXME: support multiple batches, by averaging or something
                    for v in variants: #variantsAndBaseline:
                        yValues = [ _sliceInfo_[(xKey,yKey)][v][g] for (_,g,_) in slices[xKey] ]
                        series.append((v.id(), yValues))
                    graphicFile = ("plot-slices-%s-vs-%s" % (xKey, yKey))
                    o = self.__newOctaveScript(graphicFile)
                    xKeyLong = self.__statsToHeader(xKey, long=True)
                    yKeyLong = self.__statsToHeader(yKey, long=True)
                    generatePlot(output = o,
                                       xdata  = xValues,
                                       series = series,
                                       title  = "grouped %s for %s" % (yKeyLong, xKeyLong),
                                       xlabel = xKeyLong,
                                       ylabel = yKeyLong,
                                       ylogscale = False,
                                       xlogscale = False,
                                       plotOutput = graphicFile)
                    caption = "grouped %s for %s" % (yKeyLong, xKeyLong),
                    self.__writeLatexFigure(f, caption, graphicFile)

                # -------------------------------------------------------------------------
                # run the analyses / plots / graphs
                # -------------------------------------------------------------------------

                if len(variants) > 1:
                  # run a baseline analysis for each x and y key
                  for (description, bKey, aKey) in BASELINE_YS:
                      f.write('\n'r"\subsection{%s}" '\n' % description)
                      f.write(texBeginCenter)
                      for xKey in COMPARISON_XS:
                          baselineAnalysis(description, xKey, bKey, aKey)
                          f.write('\n')
                          f.write(texVspaceHalf)
                      f.write(texEndCenter)

                  # do a best variant analysis for each x key and set of y keys
                  for (description, yKeyList) in COMPARISON_YS_SETS:
                      f.write('\n'r"\subsection{Comparisons on %s}" '\n' % description)
                      f.write(texBeginCenter)
                      for xKey in COMPARISON_XS:
                          bestVariantAnalysis(description, xKey, yKeyList)
                          f.write('\n')
                          f.write(texVspaceHalf)
                      f.write(texEndCenter)

                  if DO_GROUPED_PLOTS:
                      for (description, yKeyList) in COMPARISON_YS_SETS:
                          f.write(texBeginCenter)
                          for xKey in COMPARISON_XS:
                              groupedPlot(description, xKey, yKeyList)
                              f.write(texVspaceHalf)
                          f.write(texEndCenter)

                # -------------------------------------------------------------------------
                # plots / graphs
                # -------------------------------------------------------------------------
                # categorise the test cases so that we can rank them by difficulty
                # and divide them by number of substitutions, adjunctions etc
                if baseline:
                  cInstances = {}
                  cInstances["test_case"] = {}
                  i = 0
                  for t in tests: # create an artificial 'metric' for test case
                      i+=1
                      cInstances["test_case"][i] = [t]
                  for k in baseline.statsForBatch(batch).metrics: # k would be something like substitutions
                      cInstances[k] = {}
                      for t in tests:
                          # how many foos (e.g. substitutions) does test case t have?
                          if hasKeyPair(baselineStatsList, t, k):
                              c = int(keyPairValOrDefault(baselineStatsList, t, k, 0))
                              # add t into the list of things which have c foos
                              if not cInstances[k].has_key(c): cInstances[k][c] = []
                              cInstances[k][c].append(t)

                  if len(PLOT_BY_YS) > 0: f.write(r"\newpage" '\n')

                  # plot things!
                  for (xMetric,xLog) in PLOT_BY_XS:
                    for (yMetric,yLog) in PLOT_BY_YS:
                      series = []
                      # xValues: add a point for every test case that is an instance of the x value
                      # tCases : get a sorted list of test cases (grouped by category)
                      xValues = []
                      tCases  = []
                      xKeys = cInstances[xMetric].keys()
                      xKeys.sort()
                      for k in xKeys:
                          xValues = xValues + [k for i in cInstances[xMetric][k]]
                          tCases  = tCases  + cInstances[xMetric][k]
                      # note: here we want the baseline AFTER, so that the colours
                      # of our graph don't change wherever the baseline appears
                      variantsAndBaseline = variants + [baseline]
                      # FIXME: support multiple batches, by averaging or something
                      for surfRealiser in variants: #variantsAndBaseline:
                          stats = surfRealiser.statsForBatch(batch)
                          yValues = []
                          if hasKeyPair(stats.statsList, tCases[0], yMetric):
                              for t in tCases:
                                  yValues.append(keyPairValOrDefault(stats.statsList, t, yMetric, -1))
                              series.append((surfRealiser.id(), yValues))
                      graphicFile = ("plot-%s-vs-%s" % (xMetric, yMetric))
                      o = self.__newOctaveScript(graphicFile)
                      xMetricLong = self.__statsToHeader(xMetric, long=True)
                      yMetricLong = self.__statsToHeader(yMetric, long=True)
                      generatePlot(output = o,
                                   xdata  = xValues,
                                   series = series,
                                   title  = "%s for %s" % (yMetricLong, xMetricLong),
                                   xlabel = xMetricLong,
                                   ylabel = yMetricLong,
                                   ylogscale = yLog,
                                   xlogscale = xLog,
                                   plotOutput = graphicFile)
                      caption = "%s for %s" % (yMetricLong, xMetricLong),
                      self.__writeLatexFigure(f, caption, graphicFile)

                # -------------------------------------------------------------------------
                # print out statistics for each surface realiser
                # -------------------------------------------------------------------------
                f.write(r"\newpage" '\n')
                f.write(r"\section{Raw data}" '\n')
                for surfRealiser in allVariants:
                    stats = surfRealiser.statsForBatch(batch)

                    skeys = stats.metrics
                    tkeys = [ 'cpu_time' ]
                    allKeys = tkeys + skeys
                    if surfRealiser != baseline:
                        allKeys = CROSSREF_KEYS + allKeys
                    else:
                        allKeys = allKeys

                    def writeStatTable(keys):
                        self.writeTableHeader(f, [self.__statsToHeader(k) for k in keys], longtable=True)
                        # sort the tests by some criterion
                        sortedTests = tests
                        if baseline and SORT_TESTS:
                            catVal = lambda t: int(keyPairValOrDefault(baselineStatsList, t, SORT_TESTS_BY, 0))
                            sortedTests.sort(lambda t1,t2: cmp(catVal(t1), catVal(t2)))
                        # now print out the column body
                        for t in sortedTests:
                            self.writeRow(f, [ keyPairValOrDefault(stats.statsList, t, k) for k in keys ])
                        f.write('\end{longtable}' '\n')

                    # print out the statistics
                    if len(tests) > 0 and len(skeys) + len(tkeys) > 0:
                        f.write(r"\subsection{Full statistics for %s}" '\n' % surfRealiser.id())
                        writeStatTable(['test_case'] + [ k for k in allKeys if "pol_" not in k ])
                        polKeys = [ k for k in allKeys if "pol_" in k ]
                        if len(polKeys) > 0: writeStatTable(['test_case'] + polKeys)
              # end for surfRealiser
            # end for batch
            f.close()
        except:
            print('Error generating analysis of responses!')
            raise

    def writeRow(self, f, cells, header=False):
        def prettify(first, x):
            if header:
                if first: coltype = "|c|"
                else:     coltype = "c|"
                return (r"\multicolumn{1}{%s}{\textbf{%s}}"'\n' % (coltype,x))
            else:
                return x
        if len(cells) > 0:
            f.write(prettify(True, cells[0]))
            for c in cells[1:]:
                f.write(r"& %s" % prettify(False, c))
        f.write(r"\\ \hline" '\n')

    def writeTableHeader(self, f, cells, longtable=False):
        if longtable:
           t = 'longtable'
        else:
           t = 'tabular'
        f.write(r"\begin{%s}{|" % t)
        for k in cells: f.write(r"r|")
        f.write(r"}" '\n')
        f.write(r"\hline" '\n')
        self.writeRow(f, cells, header=True)
        f.write(r"\hline" '\n')
        if longtable: f.write(r"\endhead" '\n')

    def __statsToHeader(self, stat, long=False):
        replacePairs = [ ("test_case"      , "test", "")
                       , ("substitutions"  , "subst", "")
                       , ("iterations"     , "iters", "")
                       , ("chart_size"     , "chart sz", "")
                       , ("adjunctions"    , "adj", "")
                       , ("sem_literals"   , "lits", "literals")
                       , ("lex_nodes"      , "nodes", "")
                       , ("lex_trees"      , "trees", "")
                       , ("lex_foot_nodes" , "\# *", "foot nodes")
                       , ("lex_subst_nodes", "\# $\downarrow$", "subst nodes")
                       , ("plex_nodes"      , "p nodes", "(pol) nodes")
                       , ("plex_trees"      , "p trees", "(pol) trees")
                       , ("plex_foot_nodes" , "p\# *", "(pol) foot nodes")
                       , ("plex_subst_nodes", "p\# $\downarrow$", "(pol) subst nodes")
                       , ("dlex_nodes"      , r"$\Delta$ nodes", r"$\Delta$ nodes")
                       , ("dlex_trees"      , r"$\Delta$ trees", r"$\Delta$ trees")
                       , ("dlex_foot_nodes" , r"$\Delta$ \# *", r"$\Delta$ foot nodes")
                       , ("dlex_subst_nodes", r"$\Delta$ \# $\downarrow$", r"\delta subst nodes")
                       , ("estimated_difficulty", r"*$\times\downarrow$", "estimated_difficulty (s X f)")
                       , ("pol_total_states", "pt s", "")
                       , ("pol_total_trans" , "pt t", "")
                       , ("pol_max_states"  , "pm s", "")
                       , ("pol_max_trans"   , "pm t", "")
                       , ("pol_seed_paths"  , "ps p", "paths possible")
                       , ("pol_used_paths"  , "pu p", "paths explored")
                       , ("pol_used_bundles", "pu b", "")
                       , ("root_cat_discards", "r discards", "rt discards")
                       , ("cpu_time", "CPU t", "CPU time")
                        ]
        for (f,t,l) in replacePairs:
            if stat == f:
                if not long:
                    stat = t
                elif l != "":
                    stat = l
        return re.compile('_').sub(' ', stat)

    def __sumSeries(self, series1, series2):
        return map(lambda (v1, v2): map(lambda (x,y): x+y, zip(v1,v2)), zip(series1, series2))

    def __writeLatexFigure(self, o ,caption, graphicFile):
        o.write(r'\begin{figure}[!htb]' '\n')
        o.write(r'\includegraphics[keepaspectratio=true]{%s}' '\n' % graphicFile)
        o.write(r'\caption{%s}' '\n' % caption)
        o.write(r'\end{figure}' '\n\n')

    def __writeLatexEnumerate(self, o, lst):
        if len(lst) > 0:
            o.write(r'\begin{enumerate}' '\n')
            for i in lst:
                o.write(r'\item{%s}' '\n' % i)
            o.write(r'\end{enumerate}' '\n\n')

    def __latexEsc(self, s):
        return s.replace('_', r'\_').replace('^', r'\^').replace('#', r'\#')

    def __addSection(self, name, useInclude = True):
        self.__latexSections.append((name, useInclude))

    def __newOctaveScript(self, name):
        self.__octaveScripts.append(name)
        return open(name, 'w')


# vim: sw=4 expandtab:
