#!/usr/bin/env python

import pyxhook
import time
import rospy
from geometry_msgs.msg import Twist
from std_msgs.msg import Bool
from threading import Thread
import getpass
import sys, os, termios

# Action identifiers. They are used as the keys of every key-map dictionary.
FORWARD, BACKWARD = 'forward', 'backward'
LEFT, RIGHT = 'left', 'right'
BRAKE, PAUSE, FAST = 'brake', 'pause', 'fast'

# Default bindings: each action maps to the list of key names that trigger it.
# Key names are written in lower case.
default_key_map = dict([
    (FORWARD, ['w']),
    (BACKWARD, ['s']),
    (LEFT, ['a']),
    (RIGHT, ['d']),
    (BRAKE, ['space']),
    (PAUSE, ['escape']),
    (FAST, ['shift_l']),
])

class MultiKeyTeleop:
    """Turn the set of currently held keys into teleoperation messages.

    Key presses and releases are received through a ``pyxhook.HookManager``
    and recorded, lower-cased, in ``pressed_keys``.  ``twist()`` and
    ``brake()`` then derive messages from whichever actions are currently
    held.  The object is a context manager: entering it starts the hook
    manager, leaving it cancels the hook manager.
    """

    def __init__(self, linear_speed, angular_speed, speed_factor, key_map):
        """Store the configuration and register the keyboard callbacks.

        linear_speed  -- magnitude applied to ``linear.x`` for FORWARD/BACKWARD.
        angular_speed -- magnitude applied to ``angular.z`` for LEFT/RIGHT.
        speed_factor  -- multiplier applied to both while FAST is held.
        key_map       -- dict mapping each action name to the key name(s)
                         bound to it (a list of names, or a single name).
                         Names are matched case-insensitively.
        """
        self.is_paused = False

        self.linear_speed = linear_speed
        self.angular_speed = angular_speed
        self.speed_factor = speed_factor

        self.key_map = key_map
        self.pressed_keys = set()

        self.hookman = pyxhook.HookManager()
        self.hookman.HookKeyboard()
        self.hookman.KeyDown = self.on_key_down
        self.hookman.KeyUp = self.on_key_up

    def __enter__(self):
        """Start the hook manager and return this object."""
        self.hookman.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Cancel the hook manager; exceptions are not suppressed."""
        self.hookman.cancel()

    def _bound_keys(self, action):
        """Return the lower-cased key names bound to *action*.

        An action missing from the key map has no bindings.  A binding given
        as a bare string (e.g. ``'space'``) is treated as one key name rather
        than being iterated character by character.
        """
        keys = self.key_map.get(action, ())
        if hasattr(keys, 'lower'):
            # Only strings have .lower(); wrap a lone key name in a tuple.
            keys = (keys,)
        return [key.lower() for key in keys]

    def on_key_down(self, event):
        """Record the pressed key and toggle the pause flag on a PAUSE key."""
        key = event.Key.lower()
        self.pressed_keys.add(key)

        # Same case-insensitive lookup as is_triggered(), so a binding such
        # as 'Escape' toggles pause just like 'escape' does.
        if key in self._bound_keys(PAUSE):
            self.is_paused = not self.is_paused

    def on_key_up(self, event):
        """Forget the released key (no error if it was not recorded)."""
        self.pressed_keys.discard(event.Key.lower())

    def is_triggered(self, action):
        """Return True if at least one key bound to *action* is held."""
        return any(key in self.pressed_keys
                   for key in self._bound_keys(action))

    def twist(self):
        """Build a Twist message from the currently held movement keys.

        Opposite keys cancel each other.  While BRAKE is held both velocities
        stay at zero.  While FAST is held both are scaled by ``speed_factor``.
        """
        x = 0.0
        yaw = 0.0

        if not self.is_triggered(BRAKE):

            if self.is_triggered(FORWARD):
                x += self.linear_speed
            if self.is_triggered(BACKWARD):
                x -= self.linear_speed
            if self.is_triggered(LEFT):
                yaw += self.angular_speed
            if self.is_triggered(RIGHT):
                yaw -= self.angular_speed

        if self.is_triggered(FAST):
            x *= self.speed_factor
            yaw *= self.speed_factor

        msg = Twist()
        msg.linear.x = x
        msg.angular.z = yaw
        return msg

    def brake(self):
        """Build a Bool message whose data is True while BRAKE is held."""
        brake_msg = Bool()
        brake_msg.data = self.is_triggered(BRAKE)
        return brake_msg

def hide_inputs():
    """Disable terminal echo until ROS shuts down, then restore it.

    The controlling terminal (/dev/tty) is preferred; if it cannot be opened,
    stdin is used instead.  The function blocks until
    ``rospy.is_shutdown()`` becomes true, so it is meant to run in its own
    thread.  Failures to read or change the terminal settings are logged and
    otherwise ignored (best effort).
    """
    owns_fd = False
    try:
        # Always try the controlling terminal first.  Only the descriptor is
        # needed, so no file object is created around it.
        fd = os.open('/dev/tty', os.O_RDWR | os.O_NOCTTY)
        owns_fd = True
    except EnvironmentError:
        # If that fails, see if stdin can be controlled.
        try:
            fd = sys.stdin.fileno()
        except (AttributeError, ValueError):
            rospy.logwarn("No terminal available, console input stays visible")
            return

    try:
        old = termios.tcgetattr(fd)  # a copy to restore afterwards
        new = old[:]
        new[3] &= ~termios.ECHO  # 3 == 'lflags'
        try:
            termios.tcsetattr(fd, termios.TCSADRAIN, new)
            # Sleep between polls instead of spinning, so this thread does
            # not burn a core or starve the publishing loop.
            while not rospy.is_shutdown():
                time.sleep(0.1)
        finally:
            termios.tcsetattr(fd, termios.TCSADRAIN, old)
    except termios.error as e:
        rospy.logwarn("Could not change terminal echo settings: %s", e)
    finally:
        # Close only the descriptor opened here, never stdin.
        if owns_fd:
            os.close(fd)

if __name__ == '__main__':

    # Worker that keeps console echo disabled; started only when requested.
    thread = Thread(target=hide_inputs)

    try:
        rospy.init_node("multikey_teleop", log_level=rospy.INFO)

        twist_pub = rospy.Publisher('cmd_vel', Twist, queue_size=1)
        brake_pub = rospy.Publisher('brake', Bool, queue_size=1)

        linear_speed = rospy.get_param('~linear_speed', 1.0)
        angular_speed = rospy.get_param('~angular_speed', 1.0)
        speed_factor = rospy.get_param('~speed_factor', 1.5)
        # When false, all-zero twists are not published.
        publish_null = rospy.get_param('~publish_null', False)
        rate = rospy.Rate(rospy.get_param('~rate', 10))
        should_hide_inputs = rospy.get_param('~hide_inputs', True)

        # Start from a copy of the defaults so the module-level map is not
        # mutated, and treat an unset ~key_map as "no overrides" instead of
        # letting get_param raise KeyError.
        key_map = dict(default_key_map)
        key_map.update(rospy.get_param('~key_map', {}))

        with MultiKeyTeleop(linear_speed, angular_speed, speed_factor, key_map) as multikey_teleop:

            if should_hide_inputs:
                rospy.logwarn("Console input won't be displayed anymore")
                thread.start()

            while not rospy.is_shutdown():
                if not multikey_teleop.is_paused:
                    brake_pub.publish(multikey_teleop.brake())
                    twist_msg = multikey_teleop.twist()
                    if publish_null or twist_msg.linear.x or twist_msg.angular.z:
                        twist_pub.publish(twist_msg)
                rate.sleep()

    except rospy.ROSInterruptException:
        pass
    except rospy.exceptions.ROSInitException:
        pass
    finally:
        if thread.is_alive():
            # The worker only returns once rospy reports shutdown.  Request
            # it explicitly so join() cannot hang when the loop ended for
            # another reason (e.g. an unexpected exception).
            rospy.signal_shutdown("multikey_teleop exiting")
            thread.join()
