#!/usr/bin/python
import rospy
from rospkg import RosPack
import yaml
import sys
import os
from python_qt_binding.QtWidgets import QApplication, QWidget, QVBoxLayout,QHBoxLayout,QGridLayout, QLabel, QSlider, QLineEdit, QPushButton, QCheckBox
from python_qt_binding.QtCore import Signal, Qt,  pyqtSlot
from python_qt_binding.QtGui import QFont
from threading import Thread
import signal
from geometry_msgs.msg import Quaternion
from functools import reduce
from numpy import pi
from tf.transformations import quaternion_from_euler

# Bold 9 pt font, applied by Control to the value display and the slider
font = QFont("Helvetica", 9, QFont.Bold)
# Bold 10 pt font intended for topic titles
topic_font = QFont("Helvetica", 10, QFont.Bold)
# Shared RosPack instance, used to resolve the path of message packages
rospack = RosPack()

def rsetattr(obj, attr, val):
    """Set a (possibly dotted) attribute path of ``obj`` to ``val``.

    For a nested path such as ``'a.b.c'`` the owner ``obj.a.b`` is resolved
    with ``rgetattr`` and ``val`` is stored as-is in its ``c`` attribute.
    For a top-level attribute, ``val`` is first cast to the type of the
    value currently held by that attribute.

    Always returns ``None`` (the result of ``setattr``).
    """
    parent_path, _sep, leaf = attr.rpartition('.')

    if not parent_path:
        # top-level field: convert to the type of the current value
        expected_type = type(getattr(obj, leaf))
        setattr(obj, leaf, expected_type(val))
        return None

    owner = rgetattr(obj, parent_path)
    setattr(owner, leaf, val)
    return None

def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path starting from ``obj``.

    An empty path returns ``obj`` itself. Any extra positional argument is
    forwarded to every ``getattr`` call, so a single extra argument acts as
    the default returned for a missing attribute at any level.
    """
    if attr == '':
        return obj

    current = obj
    for name in attr.split('.'):
        current = getattr(current, name, *args)
    return current

def split_field(key):
    """Split ``key`` at its last dot into (parent path, last component).

    Returns a two-element list when ``key`` contains a dot, otherwise the
    tuple ``('', key)``.
    """
    parent, dot, last = key.rpartition('.')
    if not dot:
        return '', key
    return [parent, last]

def isRPY(key, msg):
    """Tell whether ``key`` designates a roll/pitch/yaw angle of a Quaternion.

    ``key`` is a dotted path inside ``msg``. It is an RPY key when its last
    component is 'roll', 'pitch' or 'yaw' and the parent path resolves to a
    ``Quaternion`` (an empty parent path designates ``msg`` itself).

    Returns a bool.
    """
    field, axis = split_field(key)
    if axis not in ('roll', 'pitch', 'yaw'):
        return False
    return isinstance(rgetattr(msg, field), Quaternion)

def robust_eval(val):
    """Convert a value read from the configuration into a number when possible.

    - lists and tuples are converted element-wise and returned as a list
    - strings equal (ignoring case and surrounding blanks) to 'pi', '-pi',
      'pi/N' or '-pi/N' with N in 2..12 become the corresponding float
    - any other string is returned unchanged
    - everything else goes through ``float()``
    """
    kind = type(val)
    if kind in (list, tuple):
        return [robust_eval(item) for item in val]
    if kind != str:
        return float(val)

    expr = val.strip().lower()

    # peel off an optional leading minus sign
    if expr.startswith('-'):
        sign, body = -1, expr[1:]
    else:
        sign, body = 1, expr

    if body == 'pi':
        return sign*pi

    # Pi fractions, only for the canonical spelling of denominators 2..12
    if body.startswith('pi/'):
        denom_txt = body[3:]
        if denom_txt in [str(denom) for denom in range(2, 13)]:
            return sign * pi/int(denom_txt)

    return val

def key_tag(topic, key):
    """Build the unique identifier '<topic>/<key>' of a control."""
    return '/'.join((topic, key))

def get_interface(pkg, interface, name):
    """Import ``pkg.interface`` and return its attribute ``name``.

    ``interface`` is passed as the import "fromlist" so that the submodule
    gets loaded together with the package.
    """
    package = __import__(pkg, globals(), locals(), [interface])
    submodule = getattr(package, interface)
    return getattr(submodule, name)

def get_type(msg, key):
    """Classify the field ``key`` of ``msg`` as 'double', 'int' or 'bool'.

    ``key`` is a dotted path, resolved one level at a time. Indexed keys
    (containing '[') are assumed to be 'double', as are 'roll', 'pitch' and
    'yaw' when they are not actual attributes of the message.

    The classification relies on the current value of the field. The type
    name is also inspected so that sized scalar types (names such as
    'float64', 'int32' or 'uint8') are recognised, but only the name of the
    type itself is considered: a class or module that merely contains 'int'
    or 'float' in its name (e.g. ``Point``) is not mistaken for a number.
    Anything that is not recognised as a number is reported as 'bool'.
    """
    if '[' in key:
        # TODO parse file to get what is inside this vector
        return 'double'

    if '.' in key:
        main, nested = key.split('.', 1)
        return get_type(getattr(msg, main), nested)

    if not hasattr(msg, key) and key in ('roll', 'pitch', 'yaw'):
        return 'double'

    value = getattr(msg, key)
    type_name = type(value).__name__

    # bool is a subclass of int: rule it out before the numeric checks
    if isinstance(value, bool):
        return 'bool'
    if isinstance(value, float) or type_name.startswith('float') or 'double' in type_name:
        return 'double'
    if isinstance(value, int) or type_name.startswith(('int', 'uint')):
        return 'int'
    return 'bool'

class Control:
    """State of one GUI control (slider or checkbox) bound to a message field.

    Attributes set by the constructor:
      type    -- 'int', 'double' or 'bool', as returned by ``get_type``
      min/max -- bounds of the field (numeric types only)
      range   -- number of slider steps: 1000 for doubles, max-min for
                 ints, None for a checkbox
      default -- value restored by ``reset``
      value   -- current value, refreshed from the widget by ``refresh``

    The widgets (``slider`` and ``display``, or ``box``) are only created by
    ``init_slider`` / ``init_box``.
    """

    def __init__(self, msg, info):
        """Build the control described by ``info`` for a field of ``msg``.

        ``info`` must have a 'to' entry (path of the field in ``msg``) and,
        for numeric fields, 'min' and 'max' entries. 'default' is optional:
        it falls back to the middle of the interval, or False for a bool.
        """
        self.type = get_type(msg, info['to'])

        if self.type in ('int', 'double'):
            self.min = self.conv(robust_eval(info['min']))
            self.max = self.conv(robust_eval(info['max']))
            self.range = 1000 if self.type == 'double' else self.max-self.min
            if 'default' in info:
                self.default = self.conv(robust_eval(info['default']))
            else:
                self.default = self.conv((self.max+self.min)/2)
        else:
            self.range = None
            if 'default' in info:
                self.default = self.conv(robust_eval(info['default']))
            else:
                self.default = False

        self.value = self.default

    def conv(self, v):
        """Cast ``v`` to the Python type matching ``self.type``.

        Prints a message and returns None for an unknown type.
        """
        if self.type == 'int':
            return int(v)
        if self.type == 'double':
            return float(v)
        if self.type == 'bool':
            return bool(v)
        print('Unknown type "{}"'.format(self.type))
        return None

    def reset(self):
        """Put the widget back to the position matching the default value."""
        if self.range is None:
            self.box.setChecked(self.default)
            return

        if self.type == 'double':
            span = self.max-self.min
            # degenerate interval (min == max): every position maps to min
            ratio = (self.default-self.min)/span if span else 0.
            slider_val = ratio*self.range
        else:
            slider_val = self.default-self.min

        # round to the nearest step: plain truncation would land one step
        # too low whenever the float computation falls just below an integer
        self.slider.setValue(int(round(slider_val)))

    def refresh(self):
        """Read the widget and update ``value`` (and the displayed text)."""
        if self.range is None:
            self.value = self.box.isChecked()
            return

        slider_val = self.slider.value()
        if self.type == 'double':
            self.value = self.min + slider_val*(self.max - self.min)/self.range
            self.display.setText(f'{round(self.value, 2)}')
        else:
            self.value = self.min + slider_val
            self.display.setText(f'{self.value}')

    def init_slider(self):
        """Create the read-only value display and the horizontal slider."""
        self.display = QLineEdit(f'{round(self.value, 2)}')
        self.display.setAlignment(Qt.AlignRight)
        self.display.setFont(font)
        self.display.setReadOnly(True)

        self.slider = QSlider(Qt.Horizontal)
        self.slider.setFont(font)
        self.slider.setRange(0, self.range)
        self.reset()

    def init_box(self):
        """Create the checkbox and set it to the default state."""
        self.box = QCheckBox()
        self.reset()

class Publisher:
    """Fill a message (or service request) from the controls and send it.

    Attributes:
      topic  -- topic or service name
      msg    -- message / request instance that gets updated and sent
      pub    -- rospy.Publisher, or None when calling a service
      client -- rospy.ServiceProxy, or None when publishing
      rpy    -- {quaternion path: {'roll': r, 'pitch': p, 'yaw': y}} for
                the quaternions driven through Euler angles
      map    -- {control tag: path of the field in the message}
    """

    def __init__(self, topic, msg, info, req = None):
        """Create the publisher or the service client for ``topic``.

        ``msg`` is the message class, or the service class when ``req`` (the
        request class) is given; in the latter case the constructor blocks
        until the service is available.

        ``info`` is the description of the topic: entries whose value is a
        dict describe a control and must have a 'to' key; any other entry is
        a constant that is written to the message right away. Constant
        entries are removed from ``info``, so that only controls remain in
        it when the constructor returns.
        """
        self.topic = topic

        if req is None:
            self.msg = msg()
            self.pub = rospy.Publisher(topic, msg, queue_size=1)
            self.client = None
        else:
            self.msg = req()
            self.pub = None
            print('Waiting for service ' + topic)
            rospy.wait_for_service(topic)
            self.client = rospy.ServiceProxy(topic, msg)

        self.rpy = {}

        # init map from GUI to message
        self.map = {}
        constants = []
        for key, entry in info.items():
            if isinstance(entry, dict):
                target = entry['to']
                if isRPY(target, self.msg):
                    self._track_rpy(split_field(target)[0])
                self.map[key_tag(topic, key)] = target
                continue

            if key != 'type':
                # init non-zero defaults
                value = robust_eval(entry)
                if isRPY(key, self.msg):
                    field, axis = split_field(key)
                    self._track_rpy(field)[axis] = value
                else:
                    self.write(key, value)
            constants.append(key)

        # removal is deferred: a dict cannot change size while iterated
        for key in constants:
            info.pop(key)

    def _track_rpy(self, field):
        """Return the Euler angles of quaternion ``field``, zeroed on first use."""
        return self.rpy.setdefault(field, {'roll': 0, 'pitch': 0, 'yaw': 0})

    def write(self, key, val):
        """Write ``val`` to the field ``key`` of the message.

        - an Euler angle of a tracked quaternion is stored in ``self.rpy``
          (it is turned into a quaternion by ``update``)
        - 'field[idx]' writes one element of a list, which is first padded
          (with '' if the field name contains 'name', else with 0) when it
          is too short
        - any other key is a plain (possibly dotted) attribute
        """
        field, axis = split_field(key)
        # only roll/pitch/yaw are Euler angles: other children of a tracked
        # quaternion must not end up in the angle dictionary
        if field in self.rpy and axis in self.rpy[field]:
            self.rpy[field][axis] = val
        elif '[' in key:
            field, idx = key[:-1].split('[')
            idx = int(idx)
            current = rgetattr(self.msg, field)
            missing = idx + 1 - len(current)
            if missing > 0:
                padding = '' if 'name' in field else 0
                current += [padding] * missing
            current[idx] = val
            rsetattr(self.msg, field, current)
        else:
            rsetattr(self.msg, key, val)

    def update(self, values):
        """Refresh the message from the controls, then publish or call.

        ``values`` maps control tags to objects exposing a ``value``
        attribute. A failed service call is reported (at most every 5 s)
        instead of raising, so that the caller's loop keeps running.
        """
        for tag, target in self.map.items():
            self.write(target, values[tag].value)

        # write RPY's to Quaternions
        for field, angles in self.rpy.items():
            q = quaternion_from_euler(angles['roll'], angles['pitch'], angles['yaw'])
            for idx, axis in enumerate(('x', 'y', 'z', 'w')):
                if field:
                    rsetattr(self.msg, field + '.' + axis, q[idx])
                else:
                    # the message itself is the quaternion
                    setattr(self.msg, axis, q[idx])

        # update time if stamped msg
        if hasattr(self.msg, "header"):
            self.write('header.stamp', rospy.Time.now())

        if self.pub is not None:
            self.pub.publish(self.msg)
        else:
            # service call, dont care about the result
            try:
                self.client(self.msg)
            except rospy.ServiceException as exc:
                rospy.logwarn_throttle(
                    5., 'Service call to {} failed: {}'.format(self.topic, exc))

class SliderPublisher(QWidget):
    """Window with one slider or checkbox per controllable message field.

    Attributes:
      publishers -- {topic: Publisher}
      controls   -- {'<topic>/<key>': Control}
    """

    def __init__(self, content):
        """Build the publishers and the GUI from ``content``.

        ``content`` is the text of the YAML description: a mapping from
        topic (or service) name to a mapping with a 'type' entry
        ('pkg/Msg' or 'pkg/msg|srv/Msg') plus constants and controls.
        The raw text is also used to lay the widgets out in the order in
        which the keys appear in the file.
        """
        super(SliderPublisher, self).__init__()

        content = content.replace('\t', '    ')

        self.publishers = {}
        self.controls = {}

        # to keep track of key ordering in the yaml file:
        # [(topic, [(position in the file, key), ...]), ...]
        order = []

        # an empty file is loaded as None
        description = yaml.safe_load(content) or {}

        for topic, info in description.items():

            msg = info.pop('type')
            if msg.count('/') == 2:
                pkg, interface, msg = msg.split('/')
            else:
                pkg, msg = msg.split('/')
                interface = None

            if interface is None:
                # guess msg or srv from the files found in the package
                share = rospack.get_path(pkg)
                here = {kind: os.path.exists(f'{share}/{kind}/{msg}.{kind}')
                        for kind in ('msg', 'srv')}

                if here['msg'] and here['srv']:
                    print(f'Specify srv/msg in the yaml file: both files exist for {msg}')
                    sys.exit(0)
                interface = 'msg' if here['msg'] else 'srv'

            if interface == 'srv':
                req = get_interface(pkg, 'srv', msg+'Request')
            else:
                req = None

            msg = get_interface(pkg, interface, msg)

            # Publisher removes the constant entries: only controls remain
            self.publishers[topic] = Publisher(topic, msg, info, req)

            keys = []
            for key in info:
                tag = key_tag(topic, key)
                # identify key type -> slider (continuous or discrete) / checkbox
                self.controls[tag] = Control(self.publishers[topic].msg, info[key])
                keys.append((content.find(' ' + key + ':'), key))
            keys.sort()
            order.append((topic, keys))

        def first_position(entry):
            # a topic made of constants only has no control to sort on:
            # fall back to the position of the topic itself
            topic, keys = entry
            if keys:
                return keys[0][0]
            return content.find(topic + ':')

        order.sort(key=first_position)

        # build sliders - thanks joint_state_publisher
        self.vlayout = QVBoxLayout(self)
        self.gridlayout = QGridLayout()
        key_font = QFont("Helvetica", 9, QFont.Bold)
        title_font = QFont("Helvetica", 10, QFont.Bold)

        y = 0
        for topic, keys in order:
            topic_layout = QVBoxLayout()
            label = QLabel(topic)
            label.setFont(title_font)
            topic_layout.addWidget(label)
            self.gridlayout.addLayout(topic_layout, y, 0)
            y += 1
            for _, key in keys:
                control = self.controls[key_tag(topic, key)]
                key_layout = QVBoxLayout()
                row_layout = QHBoxLayout()
                label = QLabel(key)
                label.setFont(key_font)
                row_layout.addWidget(label)

                if control.range is not None:
                    # numeric field: value display next to the label,
                    # slider underneath
                    control.init_slider()
                    row_layout.addWidget(control.display)
                    key_layout.addLayout(row_layout)
                    key_layout.addWidget(control.slider)
                    control.slider.valueChanged.connect(self.onValueChanged)
                else:
                    control.init_box()
                    row_layout.addWidget(control.box)
                    key_layout.addLayout(row_layout)
                    control.box.clicked.connect(self.onValueChanged)

                self.gridlayout.addLayout(key_layout, y, 0)
                y += 1

        self.vlayout.addLayout(self.gridlayout)

        self.reset_button = QPushButton('Reset', self)
        self.reset_button.clicked.connect(self.reset)
        self.vlayout.addWidget(self.reset_button)

    def onValueChanged(self, event):
        """Slot: re-read every widget (``event`` is not used)."""
        for control in self.controls.values():
            control.refresh()

    def reset(self, event):
        """Slot: put every control back to its default value."""
        for control in self.controls.values():
            control.reset()
        self.onValueChanged(event)

    def loop(self):
        """Send every message at 10 Hz until ROS shuts down."""
        rate = rospy.Rate(10)
        while not rospy.is_shutdown():
            for publisher in self.publishers.values():
                publisher.update(self.controls)
            try:
                rate.sleep()
            except rospy.ROSInterruptException:
                # shutdown requested during the sleep
                break
            
if __name__ == "__main__":

    rospy.init_node('slider_publisher')

    # The description file is either the last command-line argument or the
    # ~file parameter. ROS remappings (__name:=..., _file:=...) are stripped
    # first, and args[0] is skipped: it is this script, which always exists.
    args = rospy.myargv(argv=sys.argv)
    filename = args[-1] if len(args) > 1 else ''
    if not os.path.exists(filename):
        if not rospy.has_param("~file"):
            rospy.logerr("Pass a yaml file (~file param or argument)")
            sys.exit(0)
        filename = rospy.get_param("~file")

    # the raw text is needed as well, to get the order of the keys
    with open(filename) as f:
        content = f.read()

    # build GUI
    title = rospy.get_name().split('/')[-1]
    app = QApplication([title])
    sp = SliderPublisher(content)

    # daemon thread: the publishing loop only stops on ROS shutdown and
    # would otherwise keep the process alive once the window is closed
    Thread(target=sp.loop, daemon=True).start()

    # let Ctrl-C kill the application while the Qt event loop is running
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    sp.show()
    sys.exit(app.exec_())
    
