#!/usr/bin/env python
"""An experiment comparing numerical optimization algorithms.

This example illustrates pyexperiment's utilities by comparing several of the
solvers available through scipy's numerical optimization function `minimize`.

Note that this is by no means a real benchmark of optimization methods, but
just a slightly more realistic usage example for the `pyexperiment` library.

Written by Peter Duerr
"""

# Sensible imports from the future (optional)
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import

import numpy as np
from scipy.optimize import minimize
# from itertools import count as count_up
from datetime import datetime
from functools import partial
import seaborn
from matplotlib import pyplot as plt

# For python2.x compatibility
# from six.moves import range  # pylint: disable=redefined-builtin, import-error

from pyexperiment import experiment
from pyexperiment import state
from pyexperiment import conf
from pyexperiment import log
from pyexperiment.replicate import replicate, collect_results
# from pyexperiment.utils.plot import AsyncPlot
from pyexperiment.utils.plot import setup_figure

# Names of the solvers to compare; each entry is handed on unchanged as the
# `method` argument of scipy's `minimize`.
METHODS = ['Nelder-Mead',
           'Powell',
           'CG',
           'BFGS']
"""Optimization algorithms that do not require explicit Jacobian
"""

# The dimensionality of the problem (length of the random start vector drawn
# in `run_method`)
conf['N'] = 10

# Do a few replicates
# NOTE(review): this value is not passed to `replicate` anywhere in this file
# -- confirm how the replicate count is actually picked up.
conf['n_replicates'] = 50

# Plot progress
# NOTE(review): no active code in this file acts on these two options;
# the update interval is presumably in seconds -- confirm before relying on it.
conf['progress_plot'] = False
conf['progress_plot_update_interval'] = 0.5

# Load the state before the experiment is run and save it after the
# experiment is done
conf['pyexperiment.load_state'] = True
conf['pyexperiment.save_state'] = True
conf['pyexperiment.state_filename'] = 'optimize_result.h5f'

# Fix the random seed, so results are reproducible
# NOTE(review): nothing visible in this file applies this seed to numpy's
# random number generator -- confirm that it is consumed elsewhere.
conf['random_seed'] = 123


def objective(x):
    """Evaluate the objective function to be optimized.

    Insert any interesting function here. The current choice is the
    Rosenbrock function

        f(x) = sum_i alpha * (x[i]**2 - x[i+1])**2 + (1 - x[i])**2

    with ``alpha = 100``; it is zero at ``x = (1, ..., 1)`` and positive
    everywhere else.

    Args:
        x: The point to evaluate, as a numpy array or any array-like
            (list, tuple) of numbers.

    Returns:
        The objective value at `x`. For the usual 1-D input this is a
        scalar; for higher-dimensional input the terms are reduced along
        the first axis only.
    """
    # Accept plain sequences as well as arrays
    x = np.asarray(x)
    alpha = 1e2
    terms = alpha * (x[:-1]**2 - x[1:])**2 + (1. - x[:-1])**2
    # Vectorized reduction instead of Python-level iteration over the array;
    # axis=0 matches what the builtin `sum` does for multi-dimensional input
    return np.sum(terms, axis=0)


def run_method(method):
    """Run a single minimization of the objective with the given method.

    Starts scipy's `minimize` on `objective` from a random point of
    dimension ``conf['N']``, measures the wall-clock time of the call and
    stores the outcome in ``state['result']``.

    Args:
        method: Name of the solver, handed on as the `method` argument of
            `minimize`.
    """
    log.debug("Running '%s'", method)

    # Every solver used here takes `maxiter`, but only Nelder-Mead and Powell
    # take `maxfev`; passing it to the other solvers only triggers an
    # "unknown solver options" warning.
    options = {'maxiter': 100000}
    if str(method).lower() in ('nelder-mead', 'powell'):
        options['maxfev'] = 100000

    # NOTE(review): this draws every coordinate uniformly from [-2, 0). If a
    # start in [-1, 1) was intended, the offset should be 0.5 -- confirm.
    start = 2 * (np.random.rand(conf['N']) - 1.0)

    tic = datetime.now()
    result = minimize(objective, start, method=method, options=options)
    # Attach the elapsed time (in seconds) to the result object so that it is
    # stored together with the solver's output
    result.used_time = (datetime.now() - tic).total_seconds()

    state['result'] = result
    log.debug("%s -> %f", result.message, result.fun)

def single_method(method=None):
    """Minimize the objective function using the given algorithm.

    Runs `run_method` for the chosen algorithm through `replicate` (in
    parallel), gathers the per-replicate ``'result'`` entries with
    `collect_results` and stores them in ``state[method]``.

    Args:
        method: Name of the algorithm, one of `METHODS`. Defaults to
            'Nelder-Mead' if None.

    Returns:
        A list with the final objective value (`fun`) of every collected
        result, or None if `method` is not one of `METHODS`.
    """
    if method is None:
        method = 'Nelder-Mead'
    elif method not in METHODS:
        # str() keeps the listing free of u''/b'' prefixes on both Python 2
        # (with unicode_literals) and Python 3
        print("Method not available, choose from %s"
              % [str(m) for m in METHODS])
        return None

    # NOTE: conf['progress_plot'] currently has no effect here, no progress
    # callback is hooked into the replicated runs.
    # NOTE(review): the number of replicates is not passed to `replicate`
    # -- confirm that it picks up the intended count on its own.
    replicate(partial(run_method, method), parallel=True)
    state[method] = collect_results('result')

    return [result.fun for result in state[method]]


def _boxplot_figure(title, data, labels):
    """Draw one notched boxplot figure with one box per entry of `labels`.
    """
    setup_figure(title)
    seaborn.boxplot(np.array(data).T, notch=True)
    # NOTE(review): tick positions start at 1 here, which assumes that the
    # installed seaborn places the boxes at 1..n -- confirm.
    plt.xticks(np.arange(1, len(labels) + 1), labels)


def plot_results():
    """Compare the results of previous runs in a set of plots.

    For every method in `METHODS` that has results stored in `state`, this
    shows boxplots of the reached minima, the number of function
    evaluations and the used time, and a bar plot of the success rate.
    Methods without any stored result are skipped; if no method has
    results, a message is printed and nothing is plotted.

    Blocks until the figures are closed.
    """
    labels = []
    fmin_data = []
    success_rate = []
    feval_data = []
    time_data = []
    for key in METHODS:
        if key not in state:
            continue
        # Look the results up once and reuse them for all statistics
        results = state[key]
        if len(results) == 0:
            # Nothing to show, and the success rate would divide by zero
            continue
        labels.append(key)
        fmin_data.append([objective(result.x) for result in results])
        feval_data.append([result.nfev for result in results])
        time_data.append([result.used_time for result in results])
        n_success = sum(1 for result in results if result.success)
        success_rate.append(n_success / len(results))

    if not labels:
        print("No results available to plot, run an optimization first")
        return

    _boxplot_figure('Results: Minima', fmin_data, labels)

    setup_figure('Results: Success Rate')
    seaborn.barplot(np.arange(0, len(success_rate)),
                    np.array(success_rate))
    plt.xticks(np.arange(0, len(labels)), labels)

    _boxplot_figure('Results: Function Evaluations', feval_data, labels)
    _boxplot_figure('Results: Time', time_data, labels)

    plt.show(block=True)


def run_comparison():
    """Run every algorithm listed in `METHODS` on the objective function and
    plot the collected results afterwards
    """
    # The values returned by `single_method` are not needed here
    for algorithm in METHODS:
        single_method(algorithm)

    plot_results()


if __name__ == '__main__':
    # Hand the three entry points of this file to pyexperiment as commands
    experiment.main(
        commands=[single_method, run_comparison, plot_results],
        description="Compare optimization methods")
