#!/usr/bin/env python

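# Usage:
#   python example_dynamic_time_indexed [problem_id]
#
# problem_id is the 1-based position of the configuration to run in the
# `configs` list below (e.g. 1 selects 01_ilqr_cartpole.xml).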
from __future__ import print_function, division

import pyexotica as exo
import exotica_ilqr_solver_py, exotica_ddp_solver_py, exotica_ilqg_solver_py
from pyexotica.publish_trajectory import *
import signal, sys
from time import time, sleep
import numpy as np

# Call pyexotica's init_ros() setup hook up front, before any solver or
# problem is loaded further down in this script.
exo.Setup.init_ros()
# NOTE(review): presumably this short pause gives ROS time to finish starting
# up before the solver is loaded -- confirm it is still required.
sleep(0.2)

# Every example configuration lives in the same resource directory, so the
# list is assembled from that shared prefix plus the individual file names.
# The order matters: the command-line problem id indexes into this list.
_CONFIG_DIR = '{exotica_examples}/resources/configs/dynamic_time_indexed/'
_CONFIG_FILES = (
    '01_ilqr_cartpole.xml',
    '02_lwr_task_maps.xml',
    '03_ilqr_valkyrie.xml',
    '04_analytic_ddp_cartpole.xml',
    '05_analytic_ddp_lwr.xml',
    '06_analytic_ddp_valkyrie.xml',
    '07_control_limited_ddp_cartpole.xml',
    '08_control_limited_ddp_lwr.xml',
    '09_control_limited_ddp_valkyrie.xml',
    '10_ilqg_cartpole.xml',
    '11_ilqg_lwr.xml',
    '12_ilqg_valkyrie.xml',
    '13_control_limited_ddp_quadrotor.xml',
    '14_rrt_cartpole.xml',
    '15_rrt_quadrotor.xml',
)
configs = [_CONFIG_DIR + _config_file for _config_file in _CONFIG_FILES]

# Pick the configuration from the optional, 1-based [problem_id] argument.
# With no argument the first configuration is used. Anything that is not an
# integer in 1..len(configs) aborts with a readable message instead of raising
# an IndexError or silently wrapping around through a negative list index
# (e.g. an id of 0 would otherwise select the last configuration).
if len(sys.argv) > 1:
    try:
        problem_id = int(sys.argv[1])
    except ValueError:
        sys.exit('problem_id must be an integer, got {0!r}'.format(sys.argv[1]))
else:
    problem_id = 1

if not 1 <= problem_id <= len(configs):
    sys.exit('problem_id must be between 1 and {0}, got {1}'.format(
        len(configs), problem_id))

# Convert the user-facing 1-based id to a 0-based list index.
problem_idx = problem_id - 1
print('Running with config {0}'.format(configs[problem_idx]))

# Load the solver described by the selected configuration, fetch the problem
# it was built for, and run the solve.
solver = exo.Setup.load_solver(configs[problem_idx])
problem = solver.get_problem()
solution = solver.solve()

# A solution with zero rows means there is nothing to plot or replay.
if solution.shape[0] == 0:
    print('No solution found!')
    # Use sys.exit rather than the bare `exit` helper: `exit` is injected by
    # the `site` module for interactive sessions and is not guaranteed to be
    # available in programs. The exit status stays 0, as before.
    sys.exit(0)

import matplotlib.pyplot as plt

# Plot the solver's cost against the iteration number.
# NOTE(review): only element 1 of the value returned by get_cost_evolution()
# is plotted; presumably element 0 holds the matching timestamps -- confirm.
costs = problem.get_cost_evolution()
cost_per_iteration = costs[1]
iteration_numbers = range(len(cost_per_iteration))

plt.figure()
plt.plot(iteration_numbers, cost_per_iteration)
plt.xlabel('Iteration')
plt.ylabel('Cost')
# The script continues only once plt.show() returns.
plt.show()


# Hand the position part of the state trajectory to the MoveIt visualization.
visualization = exo.VisualizationMoveIt(problem.get_scene())
# The first num_positions rows of problem.X are taken and transposed, so each
# row of the result is one column of problem.X.
position_trajectory = problem.X[:problem.num_positions, :].T
visualization.display_trajectory(position_trajectory)

# publish_trajectory(problem.X[:problem.num_positions, :].T, problem.T * problem.tau, problem)
# exit(0)

# Fetch the scene's dynamics solver and read its `dt` attribute.
# NOTE(review): neither `ds` nor `dt` is referenced again in the visible part
# of this script -- the playback loop below paces itself with problem.tau.
# Confirm whether these two lines are still needed.
ds = problem.get_scene().get_dynamics_solver()
dt = ds.dt

# Replay the solution over and over until interrupted with Ctrl-C.
# Each pass starts from the first column of problem.X and simulates 1000
# steps: stored controls are applied first, then the solver's feedback
# control once the step index reaches problem.T - 2.
while True:
    try:
        state = problem.X[:, 0]

        for step in range(1000):
            feedback_index = problem.T - 2
            if step >= feedback_index:
                # Past the stored controls: ask the solver for a feedback
                # control evaluated at the current state.
                control = solver.get_feedback_control(state, feedback_index)
            else:
                control = solution[step, :]

            step_started = time()

            # Write the current state into column 1 of the trajectory, update
            # the problem with the control at index 1, then read the resulting
            # state back from column 2.
            trajectory = problem.X
            trajectory[:, 1] = state
            problem.X = trajectory
            problem.update(control, 1)
            state = problem.X[:, 2]

            problem.get_scene().get_kinematic_tree().publish_frames()

            # Sleep for whatever is left of one problem.tau interval so the
            # playback does not run faster than that step duration.
            remaining = problem.tau - (time() - step_started)
            if remaining > 0:
                sleep(remaining)

        # Short pause before the next replay pass.
        sleep(1)
    except KeyboardInterrupt:
        break

