#!/usr/bin/env python

# Usage:
#   python example_dynamic_time_indexed [problem_id]
#
# problem_id is the 1-based position of the example in the `configs` list below.
from __future__ import print_function, division

import pyexotica as exo
import exotica_ilqr_solver_py, exotica_ddp_solver_py, exotica_ilqg_solver_py
from pyexotica.publish_trajectory import *
import signal, sys
from time import time, sleep
import numpy as np

# Initialise EXOTica's ROS integration before any solver is loaded or
# anything is published.
exo.Setup.init_ros()
# NOTE(review): short pause, presumably to give the ROS node time to finish
# connecting before it is used — confirm.
sleep(0.2)

# Example configurations, listed in the order in which they are selected by
# the command-line problem id. Every file lives in the same resource
# directory, so only the file names are spelled out and the shared prefix is
# prepended when the list is built.
_CONFIG_DIR = '{exotica_examples}/resources/configs/dynamic_time_indexed/'
_CONFIG_FILES = (
    '01_ilqr_cartpole.xml',
    '02_lwr_task_maps.xml',
    '03_ilqr_valkyrie.xml',
    '04_analytic_ddp_cartpole.xml',
    '05_analytic_ddp_lwr.xml',
    '06_analytic_ddp_valkyrie.xml',
    '07_control_limited_ddp_cartpole.xml',
    '08_control_limited_ddp_lwr.xml',
    '09_control_limited_ddp_valkyrie.xml',
    '10_ilqg_cartpole.xml',
    '11_ilqg_lwr.xml',
    '12_ilqg_valkyrie.xml',
    '13_control_limited_ddp_quadrotor.xml',
    '14_rrt_cartpole.xml',
    '15_rrt_quadrotor.xml',
    '16_kpiece_cartpole.xml',
    '17_quadrotor_collision_avoidance.xml',
    '18_ilqr_pendulum.xml',
    '19_ddp_quadrotor_sphere.xml',
    '20_sparse_ddp_pendulum.xml',
    '21_sparse_ddp_cartpole.xml',
    '22_boxfddp_cartpole.xml',
    '23_quadrotor_rotating_inplace.xml',
)
configs = [_CONFIG_DIR + file_name for file_name in _CONFIG_FILES]

# Pick the example to run from the optional, 1-based ``problem_id`` argument.
# The usage line marks the argument as optional, so fall back to the first
# config when it is omitted instead of failing with an IndexError.
try:
    problem_id = int(sys.argv[1]) if len(sys.argv) > 1 else 1
except ValueError:
    sys.exit('problem_id must be an integer between 1 and {0}, got {1!r}'.format(
        len(configs), sys.argv[1]))

# Reject out-of-range ids explicitly: 0 or a negative id would otherwise be
# accepted by Python's negative indexing and silently select a config from
# the end of the list.
if not 1 <= problem_id <= len(configs):
    sys.exit('problem_id must be between 1 and {0}, got {1}'.format(
        len(configs), problem_id))

problem_idx = problem_id - 1
print('Running with config {0}'.format(configs[problem_idx]))

# Load the solver described by the selected config, fetch the problem it is
# attached to, and run the solve.
solver = exo.Setup.load_solver(configs[problem_idx])
problem = solver.get_problem()
solution = solver.solve()

print('Solver terminated with:', problem.termination_criterion)
print('Solver took:', solver.get_planning_time())
if problem.termination_criterion == exo.TerminationCriterion.Divergence:
    # Divergence is only reported; execution continues so that whatever was
    # returned is still plotted and played back below.
    print('Solver Diverged!')

if solution.shape[0] == 0:
    print('No solution found!')
    # sys.exit instead of the site-provided exit() helper: the latter is meant
    # for the interactive shell and is missing when the site module is not
    # loaded. The exit status (0) is unchanged.
    sys.exit(0)

costs = problem.get_cost_evolution()
import matplotlib.pyplot as plt

# Each figure below is shown with a blocking plt.show() before the next one
# is created, so they appear one after another.

# Figure 1: the solver's solution (titled as controls), with a zero line for
# reference spanning the number of rows in the solution.
controls_fig, controls_ax = plt.subplots(figsize=(7, 7))
controls_ax.set_title('Solution (Controls)')
controls_ax.plot(solution)
controls_ax.grid()
controls_ax.hlines(0, 0, solution.shape[0])
plt.show()

# Figure 2: the first `num_positions` rows of the state matrix, transposed
# for plotting.
position_rows = problem.get_scene().num_positions
state_fig, state_ax = plt.subplots()
state_ax.set_title('State')
state_ax.plot(problem.X[:position_rows, :].T)
plt.show()

# Figure 3: second entry of the cost evolution against its index, on a
# logarithmic y axis.
# NOTE(review): costs[1] is presumably the per-iteration cost values — confirm.
cost_values = costs[1]
num_iterations = len(cost_values)
cost_fig, cost_ax = plt.subplots()
cost_ax.plot(range(num_iterations), cost_values)
cost_ax.set_xlabel('Iteration')
cost_ax.set_ylabel('Cost')
cost_ax.set_yscale('log')
cost_ax.set_xlim(0, num_iterations - 1)
cost_ax.grid()
plt.show()

# Send the position part of the state trajectory to the MoveIt visualisation,
# then hand the same trajectory to publish_trajectory for playback.
visualization = exo.VisualizationMoveIt(problem.get_scene())

# First `num_positions` rows of the state matrix, transposed.
# NOTE(review): presumably one row per time step after the transpose — confirm.
position_trajectory = problem.X[:problem.num_positions, :].T
visualization.display_trajectory(position_trajectory)

# Total duration passed to the playback helper: number of steps times step size.
playback_duration = problem.T * problem.tau
publish_trajectory(position_trajectory, playback_duration, problem)
# exit(0)

# Repeatedly roll the solution out through the problem, publishing frames at
# each step, until the user interrupts with Ctrl-C.
# NOTE(review): sig_int_handler is not defined in this file; presumably it is
# provided by the pyexotica.publish_trajectory star import — confirm.
signal.signal(signal.SIGINT, sig_int_handler)
replaying = True
while replaying:
    try:
        rollout_start = time()
        for step in range(problem.T - 1):
            step_start = time()
            control = solution[step, :]
            problem.update(control, step)
            problem.get_scene().get_kinematic_tree().publish_frames()
            # Sleep for whatever is left of this time step so that the
            # roll-out is paced by problem.tau; skip the sleep when the
            # update already took at least that long.
            remaining = problem.tau - (time() - step_start)
            if remaining > 0:
                sleep(remaining)
        rollout_end = time()
        # Pause between consecutive roll-outs; the reported duration excludes
        # this pause.
        sleep(1)
        print("Time taken to roll-out:", rollout_end - rollout_start)
    except KeyboardInterrupt:
        replaying = False
