#!/usr/bin/env python
"""Tool to create convergence plots for Amp."""

import matplotlib
# Select the non-interactive 'Agg' backend *before* any other matplotlib
# import so that the script can run headless (without a display).
matplotlib.use('Agg')

from optparse import OptionParser
from matplotlib import pyplot
from matplotlib.backends.backend_pdf import PdfPages

from amp.analysis import read_trainlog, plot_convergence


# Command-line interface: one required positional argument (the Amp log
# file) and one optional positional argument (the output PDF filename).
parser = OptionParser(
    usage='usage: %prog logfile [plotfile]\n Create convergence plot.'
          ' logfile is amp log file; plotfile is an optional filename '
          'for the output that must end in pdf. Creates a multi-page'
          ' PDF in the case of many convergence attempts in a single '
          'file.')
options, args = parser.parse_args()

if len(args) not in (1, 2):
    # parser.error() prints the usage text and the message to stderr and
    # exits with status 2, instead of dumping a traceback on the user.
    parser.error('Bad number of arguments.')

# Read the positional arguments without mutating ``args``.
logfile = args[0]
plotfile = args[1] if len(args) == 2 else 'convergence.pdf'

# multiple=True is requested so that each convergence attempt found in the
# log is returned as its own entry.
data_dictionaries = read_trainlog(logfile, verbose=False, multiple=True)
print('{:d} convergence attempt(s) found.'.format(len(data_dictionaries)))

# Write one PDF page per convergence attempt.
with PdfPages(plotfile) as pdf:
    for index, data in enumerate(data_dictionaries):
        print('Plotting data series {:d}.'.format(index))
        # plotfile=None: the returned figure is saved into the shared PDF
        # here (presumably this keeps plot_convergence from writing its
        # own file -- confirm against amp.analysis).
        fig = plot_convergence(data=data, plotfile=None)
        pdf.savefig(fig)
        # Close each figure once saved so figures do not accumulate in
        # pyplot while looping over many attempts.
        pyplot.close(fig)
