#!/usr/bin/env python

# Andy Zelenak, 2017

# Find a pose for a camera to look towards req.target_pose.
# All poses/vectors should be in the initial camera frame.
# new_cam_x should be forward towards the target, i.e. it is the lens vector
# new_cam_z should be up (as much as possible) for the image to be upright
# req.up is a vector that defines 'up.'


# Service definitions
from look_at_pose.srv import *

import geometry_msgs.msg
import numpy
import rospy
import sys
import tf  # For transformations

def handle_look_at_pose(req):
  """Service callback: compute a camera pose that looks at req.target_pose.

  The returned pose keeps the origin of the initial camera pose and only
  changes the orientation, so that the camera x-axis (the lens vector)
  points at the target while the camera z-axis is as upright as possible
  with respect to req.up.

  Args:
    req: request carrying initial_cam_pose and target_pose (each with a
      header and a pose) and up (a header and a vector). All three must
      use the same frame_id, and initial_cam_pose must be the identity
      pose (0,0,0) (0,0,0,1) of that frame.

  Returns:
    LookAtPoseResponse wrapping the new camera PoseStamped.

  Raises:
    rospy.ServiceException: if any input check fails. The message is also
      logged with rospy.logerr.
  """
  # NOTE(review): the request is handed to the helper functions through
  # module globals, so overlapping requests would overwrite each other's
  # state -- confirm the service is only ever called sequentially.
  global initial_cam_pose, target_pose, up_vector
  target_pose = req.target_pose
  initial_cam_pose = req.initial_cam_pose
  up_vector = req.up  # Tell us what direction is up

  ################
  # Input checking
  ################
  # Inputs should all be in the same frame as the initial camera pose.
  # Thus, I avoid dealing with tf in this script.
  # The client should handle that externally.
  #
  # Failures raise rospy.ServiceException instead of calling sys.exit():
  # the caller then receives an error reply that says what was wrong.
  if up_vector.header.frame_id != initial_cam_pose.header.frame_id:
    rospy.logerr(up_vector.header.frame_id)
    rospy.logerr(initial_cam_pose.header.frame_id)
    message = "The vector that defines UP (req.up) should be in the same frame as the initial camera pose."
    rospy.logerr(message)
    raise rospy.ServiceException(message)

  if target_pose.header.frame_id != initial_cam_pose.header.frame_id:
    message = "The target pose (req.target_pose) should be in the same frame as the initial camera pose."
    rospy.logerr(message)
    raise rospy.ServiceException(message)

  # The initial camera pose should be (0,0,0) (0,0,0,1) in its own frame
  cam_position = initial_cam_pose.pose.position
  cam_orientation = initial_cam_pose.pose.orientation
  position_is_zero = (cam_position.x == 0 and cam_position.y == 0 and cam_position.z == 0)
  orientation_is_identity = (cam_orientation.x == 0 and cam_orientation.y == 0 and
                             cam_orientation.z == 0 and cam_orientation.w == 1)
  if not position_is_zero or not orientation_is_identity:
    message = "The initial camera pose should be (0,0,0) (0,0,0,1) in its own frame."
    rospy.logerr(message)
    raise rospy.ServiceException(message)

  # The camera sits at the origin (checked above), so a target at the origin
  # gives a zero-length lens vector that cannot be normalized.
  target_position = target_pose.pose.position
  if target_position.x == 0 and target_position.y == 0 and target_position.z == 0:
    message = "The target pose (req.target_pose) coincides with the camera position, so there is no direction to look in."
    rospy.logerr(message)
    raise rospy.ServiceException(message)

  #################################################################
  # Find the vector (v) from current camera pose to the target pose
  #################################################################
  calc_v()

  ###########################################################################
  # Add a new pose w/ the camera axis (x) aligned to v and the z-axis upright
  ###########################################################################
  calc_new_cam_unit_vectors()
  calc_rot_matrix()  # Stores the rotation matrix in the module global R

  ############################
  # Return the new camera pose
  ############################
  new_cam_pose = geometry_msgs.msg.PoseStamped()
  new_cam_pose.header.frame_id = initial_cam_pose.header.frame_id

  # Origin of the new camera pose is the same -- we're just rotating
  new_cam_pose.pose.position = initial_cam_pose.pose.position

  # The rotation of the new camera pose is encapsulated in R, the rotation matrix.
  # Embed it in a 4x4 homogeneous transform with zero translation.
  rows = [ [ R.item(i, 0), R.item(i, 1), R.item(i, 2), 0 ] for i in range(3) ]
  rows.append( [ 0, 0, 0, 1 ] )
  T = numpy.matrix(rows)

  # Get the rotation as a tf.quaternion type
  scale, shear, rpy_angles, translation_vector, perspective = tf.transformations.decompose_matrix(T)
  print('Roll-pitch-yaw from original cam frame to new cam frame: ', rpy_angles)
  q = tf.transformations.quaternion_from_euler( rpy_angles[0], rpy_angles[1], rpy_angles[2] )
  # tf.transformations.quaternion_from_matrix(T) would be a faster way to do
  # the conversion, but it's less intuitive than RPY.
  new_cam_pose.pose.orientation.x = q[0]
  new_cam_pose.pose.orientation.y = q[1]
  new_cam_pose.pose.orientation.z = q[2]
  new_cam_pose.pose.orientation.w = q[3]

  # For a manual check of R, call test() here.

  return LookAtPoseResponse(new_cam_pose)

def calc_v():
  """Store the displacement from the initial camera position to the target.

  Reads the module globals target_pose and initial_cam_pose, and writes the
  result (a geometry_msgs Vector3) to the module global v.
  """
  global v
  target = target_pose.pose.position
  camera = initial_cam_pose.pose.position

  displacement = geometry_msgs.msg.Vector3()
  displacement.x = target.x - camera.x
  displacement.y = target.y - camera.y
  displacement.z = target.z - camera.z

  v = displacement

def calc_new_cam_unit_vectors():
  """Compute the unit axes of the new camera frame.

  Reads the module globals v (displacement to the target) and up_vector, and
  writes the module globals new_cam_x, new_cam_y and new_cam_z. The three
  axes are expressed in the initial camera frame and form a right-handed,
  orthonormal basis in every branch:
    new_cam_x points along v (the lens vector),
    new_cam_z is perpendicular to new_cam_x and as close to 'up' as possible,
    new_cam_y completes the basis (new_cam_y = new_cam_z x new_cam_x).
  """
  #################################################
  # The new x-axis is aligned along v. Normalize it
  #################################################
  global new_cam_x, new_cam_y, new_cam_z
  new_cam_x = normalize(v)

  ###################################################################
  # Rotate about new_cam_x to point new_cam_z up, as much as possible
  # (In my notes, I call this "minimize theta")
  ###################################################################
  up = up_vector.vector
  cross = cross_product(new_cam_x, up)

  # |new_cam_x x up| / |up| is the sine of the angle between the lens vector
  # and 'up'. Compare it against a small tolerance rather than exact zero so
  # that a nearly vertical lens vector does not get its roll from round-off.
  parallel_tolerance = 1e-9
  cross_mag = dot_product(cross, cross)**0.5
  up_mag = dot_product(up, up)**0.5

  if cross_mag <= parallel_tolerance*up_mag:
    # new_cam_x points either straight up or straight down (or 'up' is the
    # zero vector). Therefore there's no optimal alignment of new_cam_z
    # (which we want to point up), so we give it an arbitrary alignment.
    #
    # Pick the axis of the initial frame that is least aligned with
    # new_cam_x; it can never be parallel to new_cam_x, so the cross product
    # below is well defined for every lens direction.
    components = [abs(new_cam_x.x), abs(new_cam_x.y), abs(new_cam_x.z)]
    least_aligned = components.index(min(components))
    helper = geometry_msgs.msg.Vector3()
    helper.x = 1.0 if least_aligned == 0 else 0.0
    helper.y = 1.0 if least_aligned == 1 else 0.0
    helper.z = 1.0 if least_aligned == 2 else 0.0

    # new_cam_y is perpendicular to new_cam_x by construction, and
    # new_cam_z = new_cam_x x new_cam_y keeps the basis right-handed.
    new_cam_y = normalize(cross_product(helper, new_cam_x))
    new_cam_z = cross_product(new_cam_x, new_cam_y)
    print('Edge case: the target pose is either straight up or straight down')
    print('Applying an arbitrary rotation')

  # Cross product was not zero, so we can rotate the z axis to be upright-ish
  else:
    new_cam_y = normalize(cross)
    new_cam_z = cross_product(new_cam_x, new_cam_y)
    # Does new_cam_z have the right sign? If not, flip it and recalculate new_cam_y, too.
    # (x cross (x cross up) points away from 'up', so this flip is expected.)
    if dot_product(new_cam_z, up) < 0: # It is anti-parallel with UP
      new_cam_z.x = -new_cam_z.x
      new_cam_z.y = -new_cam_z.y
      new_cam_z.z = -new_cam_z.z
      new_cam_y = cross_product( new_cam_z, new_cam_x)

  # We now know the unit vectors that define the new camera pose, in the initial camera frame

# The unit vectors of the new camera coordinate frame are known at this point.
# The rotation matrix between initial_camera_frame and new_camera_frame is the
# sum of the dyads (outer products) of matching basis vectors.
# See https://math.stackexchange.com/questions/1125203/finding-rotation-axis-and-angle-to-align-two-3d-vector-bases?rq=1
def calc_rot_matrix():
  """Build the rotation matrix R from the new camera axes.

  Reads the module globals new_cam_x, new_cam_y and new_cam_z and writes the
  3x3 numpy.matrix R, which takes a vector from the new camera frame to the
  initial camera frame. Returns None.
  """
  global R

  # Unit vectors of new_cam_frame, expressed in the initial_cam_frame
  new_basis = [ numpy.array([axis.x, axis.y, axis.z])
                for axis in (new_cam_x, new_cam_y, new_cam_z) ]

  # Unit vectors of the initial camera frame
  initial_basis = [ numpy.array([1, 0, 0]),
                    numpy.array([0, 1, 0]),
                    numpy.array([0, 0, 1]) ]

  # One dyad per axis pair: entry (i, j) is new_axis[i] * initial_axis[j]
  dyads = [ numpy.outer(new_axis, initial_axis)
            for new_axis, initial_axis in zip(new_basis, initial_basis) ]

  R = numpy.matrix( dyads[0] + dyads[1] + dyads[2] )

  return

def cross_product( u, v):
  """Return the cross product u x v as a new geometry_msgs Vector3.

  u and v only need numeric x, y and z attributes; neither is modified.
  """
  result = geometry_msgs.msg.Vector3()
  result.x, result.y, result.z = (
      u.y*v.z - v.y*u.z,
      v.x*u.z - u.x*v.z,
      u.x*v.y - u.y*v.x,
  )
  return result

def dot_product( u, v):
  """Return the scalar (dot) product of u and v.

  u and v only need numeric x, y and z attributes.
  """
  x_term = u.x*v.x
  y_term = u.y*v.y
  z_term = u.z*v.z
  return x_term + y_term + z_term

def normalize( input_vector ):
  """Return a unit-length copy of input_vector as a new geometry_msgs Vector3.

  Args:
    input_vector: object with numeric x, y and z attributes. Not modified.

  Returns:
    A Vector3 with the same direction as input_vector and magnitude 1.

  Raises:
    ZeroDivisionError: if input_vector has zero magnitude, because such a
      vector has no direction to preserve.
  """
  mag = (input_vector.x**2+input_vector.y**2+input_vector.z**2)**0.5

  # Fail with an explicit message (same exception type the division below
  # would raise) before allocating the result.
  if mag == 0:
    raise ZeroDivisionError("Cannot normalize a zero-length vector: it has no direction.")

  unit_vector = geometry_msgs.msg.Vector3()
  unit_vector.x = input_vector.x/mag
  unit_vector.y = input_vector.y/mag
  unit_vector.z = input_vector.z/mag

  return unit_vector

def test():
  """Print a manual sanity check of the rotation matrix R.

  Reads the module global R. Rotates one sample vector from the initial
  camera frame into the new camera frame (using the inverse of R) and one
  the other way (using R), printing each result next to the value expected
  when running test_client.
  """
  # initial_cam_frame -> new_cam_frame
  in_initial_frame = numpy.matrix( [[1.13], [-0.17], [-0.27]] )
  print('')
  print('(1.13,-0.17,-0.27) in initial_cam_frame, rotated to new_cam_frame: ', numpy.linalg.inv(R)*in_initial_frame )
  print('When running test_client, should be (1.174, 0, 0)')
  print('')

  # new_cam_frame -> initial_cam_frame
  in_new_frame = numpy.matrix( [[1.174], [0], [0]] )
  print('')
  print('(1.174,0,0) in new_cam_frame, rotated to initial_cam_frame: ', R*in_new_frame )
  print('When running test_client, should be (1.13, -0.17, -0.27)')
  print('')

def look_at_pose_server():
  """Start the 'look_at_pose' node and advertise the 'look_at_pose' service.

  Requests are handled by handle_look_at_pose. The call to rospy.spin()
  keeps this function from returning while the node is running.
  """
  rospy.init_node('look_at_pose')
  # Hold on to the service object for as long as the node spins.
  service = rospy.Service('look_at_pose', LookAtPose, handle_look_at_pose)
  rospy.spin()

# Start the service node only when this file is executed as a script.
if __name__ == "__main__":
  look_at_pose_server()
