#!/usr/bin/env python

# ##########################################################################
# Copyright (c) 2014 Shadow Robot Company Ltd.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 2 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along
# with this program. If not, see <http://www.gnu.org/licenses/>.
#
# ###########################################################################

import rospy, actionlib
from math import degrees, radians, sqrt

from dynamic_reconfigure.server import Server
from sr_hand.shadowhand_ros import ShadowHand_ROS
from sr_grasp.cfg import GraspConfig
from sr_robot_msgs.msg import ShadowPST, ShadowContactStateStamped, GraspAction
from gazebo_msgs.msg import ContactsState

class InterpolatedJoint(object):
    """
    Linear interpolation of one joint value between a start and an end.

    Progress is kept as a percentage counter ``i`` running from 0 to 100;
    ``current`` is the interpolated value for that progress. ``active``
    marks whether the joint should still be stepped.
    """

    def __init__(self, name, start=0, end=0, active=True):
        self.name = name
        self.start = start
        self.end = end
        self.active = active
        # Percentage of the start -> end motion completed so far.
        self.i = 0
        self.current = self.start

    def inc(self, amt=1):
        """
        Advance the interpolation by amt percent and return the new value.

        Once progress reaches 100 the counter is pinned to 100, the joint
        is marked inactive and the value snaps exactly to ``end``.
        """
        self.i += amt
        if not self.at_target():
            travel = self.end - self.start
            self.current = self.start + travel * (self.i / 100.0)
            return self.current
        self.i = 100
        self.active = False
        self.current = self.end
        return self.current

    def at_target(self):
        """Return True once the progress counter has reached 100 percent."""
        return self.i >= 100

class GraspNode(object):
    def __init__(self):
        """
        @brief Construct a new GraspNode, setting up its dynamic
        reconfigure server, hand interface, action server and tactile
        subscribers.

        Blocks for a while: it sleeps 4 seconds waiting for the hand
        controllers, then probes each fingertip ContactState topic with a
        0.2 second timeout. If none of those topics delivers a message it
        falls back to gazebo contact topics (when the hand type is
        "gazebo") or to PST tactiles (when the tactile type is 'PST').
        """

        # Config
        self.interpolation_rate = 20 # hz (percent per second)
        self.force_threshold = 0.3
        self.squeeze = 4
        self.first_contact_only = False
        # Current grasp
        self.grasp = None
        # Dict of target positions in degrees keyed on joint name
        self.targets = None
        # Joint positions read at the start of the current posture run
        self.start_position = None
        # InterpolatedJoint objects for the posture currently being run.
        # Defined here so stop_finger/start_finger always find the attribute.
        self.joints = []
        # Key is finger, value force, normalised to a positive value for all sensor types
        self.tactiles = {}
        self.pst_bias = None

        # ROS setup
        #
        self.config_server = Server(GraspConfig, self.config_cb)

        # We need to wait for the controllers to be loaded before we call the ShadowHand_ROS library
        # otherwise some of the joints might be missing.
        # We need to find a better way to determine when that happens (provided that we don't know
        # which controllers will exist and which will not).
        # For the time being we only wait for a while ( 4 sec)
        rospy.loginfo("Waiting for hand to pop up and hand controllers to be launched...")
        rospy.sleep(4.0)
        # Hand access. Do early (before subscribing, advertising actions etc),
        # as it blocks for a while detecting hands.
        rospy.loginfo("Detecting hand...")
        self.hand = ShadowHand_ROS()
        rospy.loginfo("Found %s hand with %s tactile."%(
                self.hand.hand_type, self.hand.get_tactile_type()))

        # Action server for running this node
        self.grasp_server = actionlib.SimpleActionServer(
                'grasp', GraspAction, auto_start = False)
        self.grasp_server.register_goal_callback(self.grasp_goal_cb)
        self.grasp_server.register_preempt_callback(self.grasp_preempt_cb)
        self.grasp_server.start()

        rospy.loginfo("Looking for tactile sensors")
        # Shadow style tactile sensor
        self.sub_tactile_shadow = {}
        found_fancy = 0
        finger_map = {
            'rf': 'rfdistal/ContactState',
            'th': 'thdistal/ContactState',
            'ff': 'ffdistal/ContactState',
            'mf': 'mfdistal/ContactState',
            'lf': 'lfdistal/ContactState',
        }
        for finger, topic in finger_map.items():
            # A finger whose topic produces no message within the timeout is
            # simply skipped; the failure is only reported as a warning.
            try:
                rospy.wait_for_message(topic, ShadowContactStateStamped, timeout = 0.2)
                rospy.loginfo("Found fancy tactile on %s finger"%finger)
                found_fancy += 1
                self.sub_tactile_shadow[finger] = rospy.Subscriber(topic, ShadowContactStateStamped, self.shadow_tactile_cb, finger)
            except Exception as e:
                rospy.logwarn(e)

        # We didn't find nano type sensor so try fallbacks
        if found_fancy == 0:
            if self.hand.hand_type == "gazebo":
                rospy.loginfo("Using gazebo tactiles")
                self.sub_tactile_gazebo = {}
                for finger in ['lf', 'rf', 'mf', 'ff', 'th']:
                    self.sub_tactile_gazebo[finger] = rospy.Subscriber(
                            'contacts/'+finger+'/distal', ContactsState,
                            self.gazebo_tactile_cb, finger)
            elif self.hand.get_tactile_type() == 'PST':
                rospy.loginfo("Using PST tactiles")
                self.sub_pst = rospy.Subscriber('tactile', ShadowPST, self.pst_tactile_cb)
            else:
                rospy.logerr("Don't know how to handle %s tactiles!"%(self.hand.get_tactile_type()))

        rospy.loginfo("Ready to grasp")

    def config_cb(self, config, level):
        """
        @brief Handles incoming dynamic reconfigure requests, updating the
        objects attributes with the new config.
        @param config Config object, a GraspConfig
        """
        self.interpolation_rate = config.interpolation_rate
        self.force_threshold    = config.force_threshold
        self.squeeze            = config.squeeze
        self.first_contact_only = config.first_contact_only
        return config

    def gazebo_tactile_cb(self, msg, finger):
        """Fills self.tactiles with force value."""
        if len(msg.states) == 0:
            self.tactiles[finger] = 0
            return
        f = msg.states[0].total_wrench.force
        self.tactiles[finger] = sqrt(f.x**2 + f.y**2 + f.z**2)

    def shadow_tactile_cb(self, msg, finger):
        """Fills self.tactiles with force value."""
        # Force is negative as it is into the finger, so normalise to positive
        self.tactiles[finger] = (-1.0*msg.Fnormal)

    def pst_tactile_cb(self, msg):
        if self.pst_bias is None:
            rospy.loginfo("Getting pst bias")
            self.pst_bias = {}
            self.pst_bias['ff'] = msg.pressure[0]
            self.pst_bias['mf'] = msg.pressure[1]
            self.pst_bias['rf'] = msg.pressure[2]
            self.pst_bias['lf'] = msg.pressure[3]
            self.pst_bias['th'] = msg.pressure[4]
        # TODO Clamp to 0
        self.tactiles['ff'] = msg.pressure[0] - self.pst_bias['ff']
        self.tactiles['mf'] = msg.pressure[1] - self.pst_bias['mf']
        self.tactiles['rf'] = msg.pressure[2] - self.pst_bias['rf']
        self.tactiles['lf'] = msg.pressure[3] - self.pst_bias['lf']
        self.tactiles['th'] = msg.pressure[4] - self.pst_bias['th']
        #rospy.loginfo(self.tactiles)

    def grasp_goal_cb(self):
        """
        Action server goal callback: accept the new goal and run it.

        Selects the pre-grasp posture when the goal asks for it, otherwise
        the grasp posture. The goal is aborted if the posture has no joints
        or if running the posture raises; otherwise it is set succeeded
        once run_posture returns.
        """
        rospy.loginfo("Grasp Goal received")
        goal = self.grasp_server.accept_new_goal()
        self.grasp = goal.grasp
        if goal.pre_grasp:
            posture = self.grasp.pre_grasp_posture
        else:
            posture = self.grasp.grasp_posture
        if not posture.joint_names:
            rospy.logerr("Empty grasp")
            self.grasp_server.set_aborted(text="Empty grasp")
            return

        try:
            self.run_posture(posture)
        except rospy.ROSInterruptException:
            # Node is shutting down; let the interrupt propagate untouched.
            raise
        except Exception as e:
            # The goal has already been accepted, so it must be given a
            # terminal state; otherwise the client would never hear back.
            rospy.logerr("Grasp failed: %s" % e)
            self.grasp_server.set_aborted(text="Grasp failed: %s" % e)
            return

        rospy.loginfo("Grasp finished")
        self.grasp_server.set_succeeded()

    def grasp_preempt_cb(self):
        """
        Action server preempt callback: log the request and mark the
        current goal as preempted.
        """
        rospy.loginfo("Grasp preempted, stopping")
        server = self.grasp_server
        server.set_preempted()

    def joint_trajectory_to_position(self, traj):
        """
        Takes a trajectory_msgs/JointTrajectory, returns dict of
        the joint name/position (angle). Converts radians to degrees
        for easy use with ShadowHand_ROS.
        """
        joints = {}
        for i in range(len(traj.joint_names)):
            joints[traj.joint_names[i]] = degrees(traj.points[0].positions[i])
        return joints

    def fix_joints(self, targets):
        """
        For a dict of joint/value mappings fix up J0 if the dict uses J1
        and J2 instead, by adding J0 as the sum of J1 and J2, which are removed.

        Remove, with warnings, unknown joints.

        Remove the wrist joints, we don't consider them part of the grasp, they
        would be used to get the hand in the right pose to start the grasp.
        """
        # Fixup J0
        for j in ['LFJ', 'RFJ', 'MFJ', 'FFJ']:
            if j+"1" in targets and j+"2" in targets:
                targets[j+"0"] = targets[j+"1"] + targets[j+"2"]
                del targets[j+"1"]
                del targets[j+"2"]

        # Remove unknown joints
        jnames = [j.name for j in self.hand.allJoints]
        for j in targets.keys():
            if j not in jnames:
                rospy.logwarn("Unknown joint %s, ignoring."%j)
                del targets[j]

    def run_posture(self, posture):
        """
        Move the hand from its current position to the given posture by
        stepped interpolation, then apply a small extra squeeze.

        Each cycle every active joint advances by one percent of its
        travel. Unless the current goal is a pre-grasp, fingers whose
        tactile value exceeds self.force_threshold are stopped; fingers
        below it are restarted unless self.first_contact_only is set.

        Target joints that have no current position reading are reported
        and left out of the motion.

        @param posture Trajectory message with joint_names and at least one
        point (see joint_trajectory_to_position).
        @raise RuntimeError if the current hand position cannot be read.
        """
        self.targets = self.joint_trajectory_to_position(posture)
        self.fix_joints(self.targets)
        self.start_position = self.hand.read_all_current_positions()

        if self.start_position is None:
            # Without start positions there is nothing to interpolate from.
            # Fail explicitly rather than with a TypeError further down.
            msg = "Unable to read current hand position. Check if position controllers are running."
            rospy.logerr(msg)
            raise RuntimeError(msg)

        unknown = [name for name in self.targets
                   if name not in self.start_position]
        if unknown:
            rospy.logerr("Unknown joints in targets (they will be ignored): %s"%unknown)

        # Convert the targets into objects, skipping the joints reported
        # above so that they really are ignored.
        self.joints = [
            InterpolatedJoint(name, self.start_position[name], target)
            for name, target in self.targets.items()
            if name in self.start_position
        ]

        # Run the interpolated motion, testing for contact as we go.
        # IDEA: Run the motion on a thread and use the tactile callbacks to
        # trigger the stop.
        r = rospy.Rate(self.interpolation_rate)
        active_joints = len(self.joints)
        while active_joints > 0:
            # Test for contact
            if not self.grasp_server.current_goal.get_goal().pre_grasp:
                for finger, value in self.tactiles.items():
                    if value > self.force_threshold:
                        self.stop_finger(finger)
                    elif not self.first_contact_only:
                        self.start_finger(finger)
            # Move the active joints
            active_joints = 0
            update = {}
            for j in self.joints:
                if not j.active:
                    continue
                j.inc(1)
                # inc() deactivates a joint once it reaches its target, so
                # only joints that still have travel left are counted.
                if j.active:
                    active_joints += 1
                update[j.name] = j.current
            self.hand.sendupdate_from_dict(update)

            r.sleep()

        # Little squeeze to hold firm
        rospy.loginfo("Squeeze")
        update = {}
        for j in self.joints:
            j.inc(self.squeeze)
            update[j.name] = j.current
        self.hand.sendupdate_from_dict(update)

    def stop_finger(self, finger):
        # TODO Set all joints in finger to current pos?
        prefix = finger.upper()
        has_active = False
        for j in self.joints:
            if j.name.startswith(prefix):
                if j.active:
                    has_active = True
                    j.active = False
        if has_active:
            rospy.loginfo("Stopping finger %s"%finger)

    def start_finger(self, finger):
        prefix = finger.upper()
        changed = False
        for j in self.joints:
            if j.name.startswith(prefix):
                if not j.active and not j.at_target():
                    changed = True
                    j.active = True
        if changed:
            rospy.loginfo("Starting finger %s"%finger)


# Script entry point: start the "grasp" ROS node and hand control to rospy.
if __name__ == "__main__":
    try:
        rospy.init_node("grasp")
        # All setup (config server, hand detection, action server, tactile
        # subscribers) happens in the constructor.
        node = GraspNode()
        # Keep the process alive so the registered callbacks can run.
        rospy.spin()
    except rospy.ROSInterruptException:
        # Interrupted by ROS (presumably node shutdown); exit quietly.
        pass
