#!/usr/bin/env python
from __future__ import print_function

import math
from time import time, sleep

import rospy
from visualization_msgs.msg import Marker, MarkerArray

import numpy as np
import pyexotica as exo
from pyexotica.publish_trajectory import *

import sys

from collections import OrderedDict
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Example script: solve the Valkyrie sparse-DDP problem with the loss
    # selected on the command line, plot the resulting control trajectory,
    # then replay it in a loop while publishing the palm target as a marker.

    # Map each supported loss name to the solver configuration it loads.
    # The '{exotica_examples}' token is passed through verbatim to
    # exo.Setup.load_solver.
    solver_configs = {
        'l2': '{exotica_examples}/resources/configs/example_sparseddp_valkyrie_l2.xml',
        'pseudo_huber': '{exotica_examples}/resources/configs/example_sparseddp_valkyrie_pseudo_huber.xml',
    }
    usage = 'Usage: ./example_sparseddp_valkyrie l2|pseudo_huber'

    # Argument errors go to stderr and exit with a non-zero status so that
    # calling shells/launch scripts can detect the failure.
    if len(sys.argv) != 2:
        print('Invalid number of input arguments.', file=sys.stderr)
        print(usage, file=sys.stderr)
        sys.exit(1)

    loss = sys.argv[1]
    if loss not in solver_configs:
        print('Unknown loss {0}.'.format(loss), file=sys.stderr)
        print(usage, file=sys.stderr)
        sys.exit(1)

    exo.Setup.init_ros()

    solver = exo.Setup.load_solver(solver_configs[loss])

    problem = solver.get_problem()
    scene = problem.get_scene()
    # NOTE(review): 'vis' is not used below; the construction is kept in case
    # creating the visualizer has side effects - confirm before removing.
    vis = exo.VisualizationMoveIt(scene)
    solution = solver.solve()

    # This is now set in the XML
    # Goal of the 'PalmPosition' cost task at time index -1.
    palm_target = problem.cost.get_goal('PalmPosition', -1)

    rospy.init_node('my_vis')
    sleep(0.5)
    pub = rospy.Publisher('/exotica/CollisionShapes', MarkerArray, queue_size=1)
    sleep(0.5)

    # Red sphere (0.1 scale on every axis) placed at the palm target.
    m = Marker()
    m.header.frame_id = 'exotica/world_frame'
    m.action = Marker.ADD
    m.type = Marker.SPHERE
    m.scale.x = m.scale.y = m.scale.z = 0.1
    m.pose.position.x = palm_target[0]
    m.pose.position.y = palm_target[1]
    m.pose.position.z = palm_target[2]
    m.pose.orientation.w = 1.0
    m.color.a = 1.0
    m.color.r = 1.0
    m.id = 1
    ma = MarkerArray()
    ma.markers = [m]

    # A joint counts as "used" when its row of problem.X is not ~0
    # (atol=1e-3) over the whole trajectory.
    # NOTE(review): the row offset 38 is hard-coded; presumably it skips the
    # leading block of the state vector for this robot model - confirm.
    state_row_offset = 38
    joint_used = [
        not np.allclose(problem.X[i + state_row_offset, :], 0, atol=1e-3)
        for i in range(problem.num_controls)
    ]
    print('{0} joints used'.format(np.sum(joint_used)))

    # plt.show() blocks until the figure window is closed; the replay below
    # only starts afterwards.
    plt.figure()
    plt.plot(solution, '.-')
    plt.show()

    # Replay the solution until ROS shuts down. rospy.init_node installs its
    # own signal handlers, so Ctrl-C only flips the shutdown flag; polling
    # rospy.is_shutdown() is what lets this loop terminate.
    while not rospy.is_shutdown():
        for t in range(problem.T - 1):
            if rospy.is_shutdown():
                break

            u = solution[t, :]

            problem.update(u, t)
            problem.get_scene().get_kinematic_tree().publish_frames('exotica')

            pub.publish(ma)
            sleep(0.05)

        # Short pause before the trajectory restarts.
        sleep(0.5)
