#!/usr/bin/env python

import serial
import time
import binascii

import rospy
from ainstein_radar_msgs.msg import RadarTargetArray
from ainstein_radar_msgs.msg import RadarTarget

# Two-byte marker that opens every frame; the receive loops compare the
# high byte (0x55) first, then the low byte (0xAA).
SRD_D1_SINGLE_HEADER = 0x55AA
# Two-byte marker that closes every frame; compared high byte (0x5A) first,
# then low byte (0xA5).
SRD_D1_SINGLE_TAIL = 0x5AA5
# NOTE(review): not referenced anywhere else in this file. Presumably the
# length in bytes of a frame carrying a single target - confirm against the
# radar protocol documentation.
SRD_D1_PACKETSIZE = 15

class SRD_D1_MULTIPLE_UART(object):
    """Read radar frames from a serial port and publish them as ROS targets.

    A frame starts with the two bytes of ``SRD_D1_SINGLE_HEADER`` and ends
    with the two bytes of ``SRD_D1_SINGLE_TAIL`` (high byte first in both
    cases).  The last four bytes of a frame are: target count, checksum,
    tail high byte, tail low byte.  The checksum is the low 8 bits of the sum
    of every byte located after the header and before the checksum byte.

    Two payload layouts are understood, one per receive/parse method pair:

    * ``*_multiple``: target data starts at frame index 3.
    * ``*_us_launch``: target data starts at frame index 4.

    Each target occupies 8 bytes in both layouts.
    """

    # Number of bytes occupied by one target in both payload layouts.
    _TARGET_SIZE = 8
    # Trailer length: target count + checksum + two tail bytes.
    _TRAILER_SIZE = 4

    def __init__(self, uart, frame_id):
        """Create the publisher and the reusable outgoing message.

        Args:
            uart: Open serial port object; only its ``read()`` method is
                used, and it is expected to return one byte per call (or an
                empty value when the read times out).
            frame_id: Value stored in the header ``frame_id`` of every
                published message.
        """
        # Initialise the node before creating the publisher so the private
        # name '~targets/raw' is resolved against the real node name.
        # Calling init_node again with identical arguments is a no-op.
        rospy.init_node('srd_d1_node')
        self.uart = uart
        self.pub = rospy.Publisher('~targets/raw', RadarTargetArray, queue_size=10)
        self.radar_data_msg_raw = RadarTargetArray()
        # NOTE(review): the stamp is set once here and never refreshed, so
        # every published message carries time 0 - confirm this is intended.
        self.radar_data_msg_raw.header.stamp = rospy.Time(0)
        self.radar_data_msg_raw.header.frame_id = frame_id

    # ------------------------------------------------------------------
    # Serial framing helpers
    # ------------------------------------------------------------------
    def _read_byte(self):
        """Read one byte from the serial port and return it as an int.

        Raises:
            IOError: If the read returned no data (read timeout).
        """
        raw = self.uart.read()
        if not raw:
            raise IOError('Timed out waiting for data from the radar')
        return ord(raw)

    def _read_frame(self):
        """Return the next complete frame as a list of ints.

        Bytes are examined one at a time with a two-byte sliding window, so
        the header and the tail are recognised at any stream offset, even
        when they directly follow a byte equal to their own first byte.

        Returns:
            The frame, header and tail included, or ``None`` when ROS is
            shutting down before a full frame was received.
        """
        header = [(SRD_D1_SINGLE_HEADER >> 8) & 0xff, SRD_D1_SINGLE_HEADER & 0xff]
        tail = [(SRD_D1_SINGLE_TAIL >> 8) & 0xff, SRD_D1_SINGLE_TAIL & 0xff]

        frame = []
        synced = False
        while not rospy.is_shutdown():
            frame.append(self._read_byte())
            if not synced:
                # Still hunting for the header: keep only the last two bytes.
                frame = frame[-2:]
                synced = (frame == header)
            elif frame[-2:] == tail:
                return frame
        return None

    def _receive_and_publish(self, parse):
        """Read frames until shutdown, parse each one and publish the result.

        Args:
            parse: Callable ``parse(frame, msg)`` that fills ``msg.targets``.

        A message is published only when the parsed frame holds at least one
        target.  Errors (read timeouts included) are printed and the loop
        restarts with a fresh frame, so a bad frame never stops the node.
        """
        msg = self.radar_data_msg_raw
        while not rospy.is_shutdown():
            try:
                frame = self._read_frame()
                if frame is None:
                    continue
                parse(frame, msg)
                if msg.targets:
                    self.pub.publish(msg)
            except Exception as e:
                print(str(e))

    # ------------------------------------------------------------------
    # Frame decoding helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _signed(value, bits):
        """Reinterpret an unsigned ``bits``-wide integer as two's complement."""
        if value >= 1 << (bits - 1):
            value -= 1 << bits
        return value

    def _target_count(self, frame, first_target_idx, checksum_error_text):
        """Validate ``frame`` and return its target count.

        Args:
            frame: Sequence of ints, header and tail included.
            first_target_idx: Frame index of the first byte of target data.
            checksum_error_text: Text printed when the checksum mismatches.

        Returns:
            The target count stored in the frame, or ``None`` (after printing
            the reason) when the frame is too short, fails the checksum, or
            claims more targets than it can hold.
        """
        size = len(frame)
        if size < self._TRAILER_SIZE:
            print(u'Error Frame: frame too short')
            return None

        expected = frame[size - 3]
        # Sum covers everything after the header up to, but excluding, the
        # checksum byte itself (the target-count byte is included).
        if (sum(frame[2:size - 3]) & 0xFF) != expected:
            print(checksum_error_text)
            return None

        count = frame[size - 4]
        if first_target_idx + count * self._TARGET_SIZE > size:
            print(u'Error Frame: target count exceeds frame length')
            return None
        return count

    # ------------------------------------------------------------------
    # Layout with target data starting at index 3
    # ------------------------------------------------------------------
    def data_receive_srd_d1_multiple(self):
        """Receive, parse and publish frames of the index-3 layout until shutdown."""
        self._receive_and_publish(self.data_parse_srd_d1_multiple)

    def data_parse_srd_d1_multiple(self, frame, msg):
        """Decode one frame of the index-3 layout into ``msg.targets``.

        ``msg.targets`` is always reset first, so an invalid frame leaves it
        empty instead of keeping the targets of an earlier frame.

        Byte offsets inside each 8-byte target slot:
            0-1  range, unsigned 16 bit big-endian, divided by 100
            2    speed, signed 8 bit, divided by 10
            3    azimuth high byte
            4    elevation, signed 8 bit
            5    serial port receiver flag (not used)
            6    azimuth low byte (signed 16 bit with offset 3, divided by 10)
            7    target ID reported by the radar (not used)
        Physical units are not stated by the code - confirm against the
        radar protocol documentation.

        Args:
            frame: Sequence of ints, header and tail included.
            msg: Message object whose ``targets`` list is replaced.
        """
        msg.targets = []
        first_idx = 3
        count = self._target_count(frame, first_idx, u'err frame: checksum err')
        if count is None:
            return

        targets = []
        for i in range(count):
            base = first_idx + i * self._TARGET_SIZE
            target = RadarTarget()
            target.target_id = i
            target.snr = 0.0
            target.range = (frame[base] * 256 + frame[base + 1]) / 100.0
            target.speed = self._signed(frame[base + 2], 8) / 10.0
            # The azimuth low byte is not adjacent to its high byte.
            azimuth_raw = frame[base + 3] * 256 + frame[base + 6]
            target.azimuth = self._signed(azimuth_raw, 16) / 10.0
            target.elevation = self._signed(frame[base + 4], 8)
            targets.append(target)
        msg.targets = targets

    # ------------------------------------------------------------------
    # Layout with target data starting at index 4
    # ------------------------------------------------------------------
    def data_receive_srd_d1_us_launch(self):
        """Receive, parse and publish frames of the index-4 layout until shutdown."""
        self._receive_and_publish(self.data_parse_srd_d1_us_launch)

    def data_parse_srd_d1_us_launch(self, frame, msg):
        """Decode one frame of the index-4 layout into ``msg.targets``.

        ``msg.targets`` is always reset first, so an invalid frame leaves it
        empty instead of keeping the targets of an earlier frame.

        Frame bytes 2 and 3 (device id and firmware version) are covered by
        the checksum but otherwise unused.  Byte offsets inside each 8-byte
        target slot:
            0-1  range, unsigned 16 bit big-endian, divided by 100
            2-3  speed, signed 16 bit big-endian, divided by 10
            4-5  azimuth, signed 16 bit big-endian, divided by 10
            6    SNR
            7    reserved
        Elevation is not transmitted in this layout and is reported as 0.
        Physical units are not stated by the code - confirm against the
        radar protocol documentation.

        Args:
            frame: Sequence of ints, header and tail included.
            msg: Message object whose ``targets`` list is replaced.
        """
        msg.targets = []
        first_idx = 4
        count = self._target_count(frame, first_idx, u'Error Frame: checksum error')
        if count is None:
            return

        targets = []
        for i in range(count):
            base = first_idx + i * self._TARGET_SIZE
            target = RadarTarget()
            target.target_id = i
            target.range = (frame[base] * 256 + frame[base + 1]) / 100.0
            speed_raw = frame[base + 2] * 256 + frame[base + 3]
            target.speed = self._signed(speed_raw, 16) / 10.0
            azimuth_raw = frame[base + 4] * 256 + frame[base + 5]
            target.azimuth = self._signed(azimuth_raw, 16) / 10.0
            target.elevation = 0
            target.snr = frame[base + 6]
            targets.append(target)
        msg.targets = targets

if __name__ == "__main__":
    # Supported firmware variants as
    # (firmware_version value, serial baud rate, receive-loop method name).
    firmware_profiles = (
        ('old', 460800, 'data_receive_srd_d1_multiple'),
        ('new', 115200, 'data_receive_srd_d1_us_launch'),
    )

    try:
        rospy.init_node('srd_d1_node')

        # Private node parameters with their defaults.
        radar_frame = rospy.get_param('~frame_id', 'map')
        port_name = rospy.get_param('~device_id', '/dev/ttyUSB0')
        requested_firmware = rospy.get_param('~firmware_version', 'old')

        # Pick the first profile whose name equals the requested firmware.
        selected = None
        for profile in firmware_profiles:
            if requested_firmware == profile[0]:
                selected = profile
                break

        if selected is None:
            rospy.logerr('Invalid firmware_version parameter.')
        else:
            baud_rate, receive_method = selected[1], selected[2]
            # The port is closed automatically when the receive loop returns.
            with serial.Serial(port_name, baud_rate, timeout=1) as port:
                radar = SRD_D1_MULTIPLE_UART(port, radar_frame)
                getattr(radar, receive_method)()

    except rospy.ROSInterruptException:
        # Raised on node shutdown; exit quietly.
        pass

