#!/usr/bin/env python
from __future__ import print_function, division

import scipy.optimize as opt
import matplotlib.pyplot as plt
import pyexotica as exo
import numpy as np
from time import time, sleep

# Initialise the ROS side of EXOTica and pause briefly before continuing.
exo.Setup.init_ros()
sleep(0.2)

# Load the problem description from XML and instantiate the problem.
prob = exo.Initializers.load_xml(
    '{exotica_examples}/resources/configs/example_dynamic_time_indexed_problem.xml')
problem = exo.Setup.create_problem(prob)

print("Constant acceleration to show evolution of system dynamics:")
# The same control (0.1 on every input) is applied at every step, so it is
# built once up front instead of once per iteration.
u_const = np.full((problem.N, 1), 0.1)
rollout_start = time()
for step in range(problem.T - 1):
    problem.update(u_const, step)
    # Publish the frames after each step and wait one problem.tau so the
    # motion can be followed while it is being rolled out.
    problem.get_scene().get_kinematic_tree().publish_frames()
    sleep(problem.tau)
rollout_duration = time() - rollout_start
print("Time taken to roll-out:", rollout_duration)

'''
# Visualize: We should see a quadratically increasing position and linearly increasing velocity.
plt.plot(problem.X[:problem.N, :].T)
plt.title("Positions")
plt.show()

plt.plot(problem.X[problem.N:, :].T)
plt.title("Velocities")
plt.show()
'''

# Set costs: zero the state-cost weight at every time index except the last
# one, so that only the end of the trajectory is penalised (Q_f).
for step in range(problem.T - 1):
    running_Q = problem.get_Q(step)
    problem.set_Q(0. * running_Q, step)

# Set random guess (initial control trajectory): uniform samples scaled by
# 0.1, with the same shape as the problem's current control matrix.
problem.U = 0.1 * np.random.rand(*problem.U.shape)

# Set target (final state): overwrite the first seven entries of the last
# column of X_star and hand the modified copy back to the problem.
goal_q = np.array([0, 0, 0, np.pi/2., 0, -0.5, 0])
X_star = problem.X_star.copy()
X_star[:7, -1] = goal_q
problem.X_star = X_star

# Compute cost of roll-out of initial control trajectory.
# The states currently held by the problem were produced by whatever controls
# were last passed to problem.update(), not by the control trajectory now
# stored in problem.U, so the initial guess is rolled out first.
# NOTE(review): assumes problem.U is laid out as (controls, time steps), as
# the plotting of problem.U.T at the end of this script suggests - confirm.
u_guess = problem.U.copy()
for t in range(problem.T - 1):
    problem.update(u_guess[:, t], t)

# State cost is accumulated over all T time steps, control cost over the
# T - 1 steps at which a control is applied.
cost = 0.0
for t in range(problem.T):
    cost += problem.get_state_cost(t)
for t in range(problem.T - 1):
    cost += problem.get_control_cost(t)
print("Initial cost:", cost)

# Set up simple optimization problem
def cost_function(u):
    """Roll out a flat control vector and return the total trajectory cost.

    ``u`` is read in consecutive chunks of ``problem.N`` values; chunk ``t``
    is passed to ``problem.update`` as the control for time step ``t``.

    Returns the sum of the state costs over all ``problem.T`` time steps plus
    the control costs over the first ``problem.T - 1`` time steps.
    """
    # Use the problem's own dimension instead of a hard-coded 7 so the
    # chunking agrees with the roll-out of the optimiser result later in
    # this script, which slices by problem.N.
    n_u = problem.N
    for t in range(problem.T - 1):
        problem.update(u[n_u * t: n_u * (t + 1)], t)
    cost = 0.0
    for t in range(problem.T):
        cost += problem.get_state_cost(t)
    for t in range(problem.T - 1):
        cost += problem.get_control_cost(t)
    return cost

'''
# TODO: This needs to use a recursive update (backwards)
def cost_jacobian(u):
    for t in range(problem.T - 1):
        problem.update(u[7 * t: 7 * t + 7], t)
    # Update value function recursively - TODO
    jac = np.zeros(u.shape[0],)
    for t in range(1, problem.T):
        # print(t, problem.get_state_cost(t), problem.get_state_cost_jacobian(t))
        jac[7 * t - 7: 7 * t] += problem.get_state_cost_jacobian(t)
    # for t in range(0, problem.T-1):
    #    jac[7 * t : 7 * t + 7] += problem.get_control_cost_jacobian(t)
    return jac

# from scipy.optimize import check_grad
# print(check_grad(cost_function, cost_jacobian, problem.U.flatten()))
'''

s = time()
# Flatten the initial guess column by column (time-major) so that each
# consecutive chunk of the flat vector holds the controls of a single time
# step, which is how cost_function and the roll-out below read it. The
# default row-major flatten() would mix different time steps in each chunk.
# NOTE(review): assumes problem.U has shape (controls, time steps) - confirm.
u_init = problem.U.flatten(order='F')
res = opt.minimize(cost_function, x0=u_init, method='BFGS')
e = time()
print("Optimal solution found in", e-s)

# Roll out the optimised controls so that problem.X and problem.U hold the
# trajectory corresponding to the optimiser result.
for t in range(problem.T - 1):
    problem.update(res.x[problem.N * t:problem.N * t + problem.N], t)

from pyexotica.publish_trajectory import publish_trajectory

# Publish the position rows of the state trajectory, one row per time step
# after the transpose. The slice uses problem.N (not a hard-coded 7) so it
# matches the "Positions" plot below.
publish_trajectory(problem.X[:problem.N, :].T, problem.T * problem.tau, problem)

# Visualize final result - the velocity should be smooth.
# Each entry is shown in its own figure, in order; the data is transposed so
# that every row of the original matrix becomes one plotted line.
for title, series in (
        ("Positions", problem.X[:problem.N, :]),
        ("Velocities", problem.X[problem.N:, :]),
        ("Accelerations", problem.U),
):
    plt.plot(series.T)
    plt.title(title)
    plt.show()
