#!/usr/bin/env python
from __future__ import print_function, division

import matplotlib.pyplot as plt
import pyexotica as exo
import numpy as np
from collections import OrderedDict


def test_solver(solver_name, new_boxqp=False, config=None):
    """Run one solver on a benchmark problem and return its cost evolution.

    Args:
        solver_name: Name of the solver initializer to instantiate, as
            listed by ``exo.Setup.get_initializers()``.
        new_boxqp: Value forwarded as the solver's ``UseNewBoxQP`` property.
        config: Path of the XML configuration to load the problem from.
            Defaults to the cart-pole swing-up scenario. Another scenario
            can be benchmarked by passing e.g.
            '{exotica_examples}/resources/configs/dynamic_time_indexed/13_control_limited_ddp_quadrotor.xml'.

    Returns:
        Whatever ``problem.get_cost_evolution()`` returns after the solve.
        The plotting code in this script uses element 0 as the time axis
        and element 1 as the cost values.
    """
    if config is None:
        config = '{exotica_examples}/resources/configs/dynamic_time_indexed/07_control_limited_ddp_cartpole.xml'

    # Only the problem part of the XML is used; the solver is created
    # separately below so that different solvers can be swapped in.
    _, problem_config = exo.Initializers.load_xml_full(config)
    problem = exo.Setup.create_problem(problem_config)

    solver_properties = {
        u'Name': u'MySolver',
        u'UseNewBoxQP': new_boxqp,
        u'BoxQPUseCholeskyFactorization': True,
    }
    solver = exo.Setup.create_solver((solver_name, solver_properties))
    solver.specify_problem(problem)

    # The returned solution is not needed here; only the cost history
    # recorded on the problem is of interest.
    solver.solve()

    print('Solver terminated with:', problem.termination_criterion)
    print('Solver took:', solver.get_planning_time())

    return problem.get_cost_evolution()


if __name__ == "__main__":
    # Benchmark every available solver of the selected families on the same
    # problem and plot cost against iteration count and against time.
    solver_families = ['DDP']  # extend with e.g. 'ILQG', 'ILQR' to compare more families

    # Collect the concrete solvers to benchmark. Abstract base initializers
    # and 'Sparse' variants are skipped. Using any() ensures a solver whose
    # name matches several families is only benchmarked once.
    solvers_to_test = []
    for initializer in exo.Setup.get_initializers():
        initializer_name = initializer[0]
        if 'Abstract' in initializer_name or 'Sparse' in initializer_name:
            continue
        if any(family in initializer_name for family in solver_families):
            solvers_to_test.append(initializer_name)

    # Maps a human-readable label to the cost evolution of that run; ordered
    # so that the legend entries follow the order in which solvers were run.
    results = OrderedDict()

    for solver_name in solvers_to_test:
        pretty_name = solver_name.replace("exotica/", "").replace("Solver", "")
        if 'ControlLimited' in solver_name:
            # Control-limited solvers are run once per BoxQP implementation.
            for new_boxqp in [True, False]:
                print('Testing', pretty_name, 'with *new* BoxQP' if new_boxqp else 'with *old* BoxQP')
                pretty_boxqp = ' (Crocoddyl-BoxQP)' if new_boxqp else ' (Exotica-BoxQP)'
                results[pretty_name + pretty_boxqp] = test_solver(solver_name, new_boxqp)
        else:
            print('Testing', pretty_name)
            results[pretty_name] = test_solver(solver_name)

    fig = plt.figure(1, (12, 6))
    # Use a figure-level title: plt.title() at this point would attach the
    # title to an implicit full-figure axes that the subplots below then
    # replace or cover, so it would not appear as the figure heading.
    fig.suptitle('Comparison of DDP-like solvers on a cart-pole swing-up task')

    # Left panel: cost per iteration.
    plt.subplot(1, 2, 1)
    for setting in results:
        plt.plot(results[setting][1], label=setting)
    plt.yscale('log')
    plt.ylabel('Cost')
    plt.xlabel('Iterations')
    plt.legend()

    # Right panel: cost against elapsed time.
    plt.subplot(1, 2, 2)
    for setting in results:
        plt.plot(results[setting][0], results[setting][1], label=setting)
    plt.yscale('log')
    plt.ylabel('Cost')
    plt.xlabel('Time (s)')
    plt.legend()

    # Leave the top strip of the figure free so the subplots do not overlap
    # the figure-level title.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    # To also write the figure to disk, call
    # fig.savefig('ddp_family_algorithm_comparison.png') before plt.show().
    plt.show()
