#!/usr/bin/env python
# Copyright (c) 2016-2019 The UUV Simulator Authors.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import rospy
import numpy as np
from uuv_auv_control_allocator.msg import AUVCommand
import os
import casadi
from uuv_auv_actuator_interface import ActuatorManager


class AUVControlAllocator(ActuatorManager):
    """Allocate a commanded AUV wrench to one thruster and a set of fins.

    The node subscribes to ``control_input`` (``AUVCommand``), keeps the most
    recent commanded force/torque vector and surge speed, and periodically
    converts them into a thrust value plus one deflection per fin, which are
    handed to ``publish_commands``.

    The attributes ``thruster``, ``fins``, ``n_fins``, ``fin_config``,
    ``thruster_config``, ``fin_lower_limit``, ``fin_upper_limit``,
    ``base_link_ned_to_enu`` and the method ``publish_commands`` are not
    defined in this class; they are presumably provided by
    ``ActuatorManager`` (verify against that class).
    """

    def __init__(self):
        """Read the node parameters, build the actuator model and run the loop.

        NOTE: this constructor blocks. After the setup it enters the periodic
        allocation loop and only returns once ROS is shutting down.

        Private ROS parameters read here:
            ~output_dir: optional, must be an existing directory.
            ~update_rate: loop frequency passed to ``rospy.Rate``
                (default 10), must be positive.
            ~timeout: maximum age in seconds of the last command before the
                inputs are reset to zero (default -1). Values that are not
                greater than -1 disable the reset.

        Raises:
            rospy.ROSException: if ``~output_dir`` is not a directory, if
                ``~update_rate`` is not positive, or if ``init`` fails.
        """
        super(AUVControlAllocator, self).__init__()

        # Retrieve the output file path to store the CAM
        # matrix for future use
        self.output_dir = None
        if rospy.has_param('~output_dir'):
            self.output_dir = rospy.get_param('~output_dir')
            if not os.path.isdir(self.output_dir):
                raise rospy.ROSException(
                    'Invalid output directory, output_dir=' +
                    str(self.output_dir))
            rospy.loginfo('output_dir=' + str(self.output_dir))

        self.update_rate = rospy.get_param('~update_rate', 10)

        self.timeout = rospy.get_param('~timeout', -1)

        if self.update_rate <= 0:
            raise rospy.ROSException('Update rate must be a positive number')

        # Symbolic variables of the actuator model: surge speed, thrust,
        # fin deflections (created in init) and the requested wrench
        self.u = casadi.SX.sym('u')
        self.thrust = casadi.SX.sym('thrust')
        self.delta = None
        self.output_tau = casadi.SX.sym('tau', 6, 1)
        self.actuator_model = 0

        # Latest command received, consumed by the allocation loop
        self.input_surge_speed = 0.0
        self.input_tau = np.zeros(6)

        if not self.init():
            raise rospy.ROSException('No thruster and/or fins found')

        # Set the timestamp before subscribing so that it always exists
        # (and is never overwritten) once the callback can start firing
        self.last_update = rospy.Time.now()

        self.sub = rospy.Subscriber(
            'control_input', AUVCommand, self.control_callback)

        self._run_allocation_loop()

    def _run_allocation_loop(self):
        """Allocate and publish actuator commands until ROS shuts down."""
        rate = rospy.Rate(self.update_rate)

        while not rospy.is_shutdown():
            if self.timeout > -1:
                # Time - Time yields a Duration, which cannot be compared
                # with a plain number; convert it to seconds first
                elapsed = (rospy.Time.now() - self.last_update).to_sec()
                if elapsed > self.timeout:
                    # The command is stale, fall back to a zero command
                    self.input_surge_speed = 0.0
                    self.input_tau = np.zeros(6)
            output = self.allocate(self.input_tau, self.input_surge_speed)
            self.publish_commands(output)
            rate.sleep()

    def init(self):
        """Build the symbolic actuator model and the allocation cost function.

        The model is the thruster contribution (``tam_column`` scaled by the
        thrust variable) plus, for every fin, a lift force proportional to
        the fin deflection and to the squared surge speed, together with the
        torque of that force about the fin position. The cost function is
        the 2-norm of the difference between the requested wrench and this
        model.

        Returns:
            bool: always True.
        """
        # Build the casadi model
        self.delta = casadi.SX.sym('delta', self.n_fins)
        actuator_model = \
            self.thruster.tam_column.reshape((6, 1)) * self.thrust
        for i in self.fins:
            f_lift = (0.5 * self.fin_config['fluid_density'] *
                      self.fin_config['lift_coefficient'] *
                      self.fin_config['fin_area'] *
                      self.delta[i] * self.u**2)

            # Lift force vector and the torque it produces
            f_lift = f_lift * self.fins[i].lift_vector
            t_lift = casadi.cross(self.fins[i].pos, f_lift)

            actuator_model += casadi.vertcat(f_lift, t_lift)

        # Keep the assembled model on the instance instead of leaving the
        # attribute at its placeholder value
        self.actuator_model = actuator_model
        self.cost_function = casadi.norm_2(self.output_tau - actuator_model)

        return True

    def allocate(self, tau, u):
        """Compute the actuator commands for a requested wrench.

        Args:
            tau: numpy array with 6 elements, forces followed by torques.
            u: surge speed used in the fin lift model.

        Returns:
            numpy.ndarray: array with ``n_fins + 1`` elements; element 0 is
            the thrust and the remaining elements are the fin deflections.
            All zeros if the surge speed is zero.
        """
        output = np.zeros(self.n_fins + 1)

        if u == 0:
            return output

        if np.linalg.norm(tau) == np.abs(tau[0]):
            # Only the surge force component is non-zero, the thruster
            # alone is used and no optimization is needed.
            # NOTE(review): this multiplies by tam_column[0]; confirm that a
            # division is not what is intended when that entry is not 1.
            output[0] = self.thruster.tam_column[0] * tau[0]
        else:
            # Replace the symbolic wrench and surge speed by their current
            # values, leaving thrust and fin deflections as the unknowns
            model = casadi.substitute(
                self.cost_function,
                casadi.vertcat(self.output_tau, self.u),
                tau.tolist() + [u])

            nlp = dict(
                x=casadi.vertcat(self.thrust, self.delta),
                f=model)

            opts = {'verbose': False, 'print_time': False,
                    'ipopt.print_level': 0}

            solver = casadi.nlpsol('solver', 'ipopt', nlp, opts)

            # Bounds: symmetric thrust limit followed by the fin limits
            max_thrust = self.thruster_config['max_thrust']
            sol = solver(
                lbx=[-max_thrust] + [self.fin_lower_limit] * self.n_fins,
                ubx=[max_thrust] + [self.fin_upper_limit] * self.n_fins)

            for i in range(self.n_fins + 1):
                output[i] = sol['x'][i]
        return output

    def control_callback(self, msg):
        """Store the latest command received on ``control_input``.

        If the frame ID of the message header contains ``ned``, the force
        and the torque vectors are both rotated with
        ``base_link_ned_to_enu`` before being stored.

        Args:
            msg: ``AUVCommand`` message holding the surge speed and the
                commanded force and torque.
        """
        tau = np.array([msg.command.force.x,
                        msg.command.force.y,
                        msg.command.force.z,
                        msg.command.torque.x,
                        msg.command.torque.y,
                        msg.command.torque.z])

        if 'ned' in msg.header.frame_id:
            tau[0:3] = np.dot(self.base_link_ned_to_enu, tau[0:3])
            # The torque occupies elements 3 to 5
            tau[3:6] = np.dot(self.base_link_ned_to_enu, tau[3:6])

        # Publish the fully converted vector in a single assignment so the
        # allocation loop never reads a half-transformed command
        self.input_surge_speed = msg.surge_speed
        self.input_tau = tau

        self.last_update = rospy.Time.now()
            

def main():
    """Start the ``auv_control_allocator`` node and run it until shutdown."""
    rospy.init_node('auv_control_allocator')

    try:
        # Keep a reference to the allocator for as long as the node spins
        allocator = AUVControlAllocator()
        rospy.spin()
    except rospy.ROSInterruptException:
        print('AUVControlAllocator::Exception')

    print('Leaving AUVControlAllocator')


if __name__ == '__main__':
    main()

        
