#!/usr/bin/env python

import unittest
import rostest
import rospy
from geometry_msgs.msg import TransformStamped
from tf2_msgs.msg import TFMessage
from tf2_ros import Buffer, TransformListener, StaticTransformBroadcaster
from tf2_server.tf2_subtree_listener import TransformSubtreeListener


# Test scenario selector. The test methods compare it against "odom", "body"
# and "update"; it is replaced with the value of the "~test_type" parameter
# when this file is run as a script.
test_type = None


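# There are really only three tests, one per `test_type` configuration ("odom",
# "body" and "update"), but both test methods are run for each configuration,
# which yields six test executions. test_basic covers the "odom" and "body"
# configs, whereas test_update covers only the "update" config; each method
# performs no checks under a configuration it does not cover.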

class TestTF2SubtreeListener(unittest.TestCase):
    """Check which transforms a tf2 listener can resolve in each scenario.

    The scenario is chosen by the module-level ``test_type`` variable:

    * ``"odom"``   - handled by :meth:`test_basic`
    * ``"body"``   - handled by :meth:`test_basic`
    * ``"update"`` - handled by :meth:`test_update`

    A test method that does not handle the current scenario returns without
    asserting anything, so it is reported as passed.
    """

    # (target_frame, source_frame) pairs, listed in the order they are checked.
    _MAP_PAIRS = (
        ("map", "odom"),
        ("map", "base_link"),
        ("map", "left_track"),
        ("map", "front_left_flipper"),
        ("map", "front_left_flipper_endpoint"),
        ("map", "right_track"),
    )
    _ODOM_PAIRS = (
        ("odom", "base_link"),
        ("odom", "left_track"),
        ("odom", "front_left_flipper"),
        ("odom", "front_left_flipper_endpoint"),
        ("odom", "right_track"),
    )
    _BODY_PAIRS = (
        ("base_link", "left_track"),
        ("base_link", "front_left_flipper"),
        ("base_link", "front_left_flipper_endpoint"),
        ("left_track", "front_left_flipper"),
        ("left_track", "front_left_flipper_endpoint"),
        ("front_left_flipper", "front_left_flipper_endpoint"),
        ("base_link", "right_track"),
        ("base_link", "front_right_flipper"),
    )

    def _assert_can_transform(self, tf_buffer, frame_pairs, expected):
        """Assert ``can_transform`` gives ``expected`` for every frame pair.

        :param tf_buffer: buffer queried with ``can_transform``.
        :param frame_pairs: iterable of ``(target_frame, source_frame)``.
        :param expected: ``True`` if every pair must be resolvable, ``False``
            if none of them may be.
        :raises AssertionError: on the first pair whose result is unexpected;
            the message names the offending frames.
        """
        for target_frame, source_frame in frame_pairs:
            result = tf_buffer.can_transform(
                target_frame, source_frame, rospy.Time(0))
            message = "can_transform(%r, %r) returned %r, expected %s" % (
                target_frame, source_frame, result,
                "a true value" if expected else "a false value")
            # Truthiness is checked (not equality) exactly like the plain
            # assertTrue/assertFalse calls this helper replaces.
            if expected:
                self.assertTrue(result, message)
            else:
                self.assertFalse(result, message)

    def test_basic(self):
        """Check transform visibility for the "odom" and "body" scenarios."""
        if test_type == "update":
            return

        tf_buffer = Buffer()
        # The listener feeds tf_buffer; keep it referenced during the checks.
        listener = TransformListener(tf_buffer)

        # Allow some time for transforms to arrive.
        rospy.sleep(rospy.Duration(1))

        if test_type == "odom":
            # Only odom -> base_link is expected; nothing else may resolve.
            self._assert_can_transform(tf_buffer, self._ODOM_PAIRS[:1], True)
            self._assert_can_transform(
                tf_buffer, self._ODOM_PAIRS[1:] + self._BODY_PAIRS, False)
        elif test_type == "body":
            # Nothing relative to odom may resolve, but everything in the
            # base_link subtree must.
            self._assert_can_transform(tf_buffer, self._ODOM_PAIRS, False)
            self._assert_can_transform(tf_buffer, self._BODY_PAIRS, True)

    def test_update(self):
        """Check that transforms appear after map -> odom is published.

        Only runs for the "update" scenario. The listener is re-pointed at the
        "/tf2_buffer_server/update" topics, every checked frame pair is first
        expected to be unresolvable, then a static map -> odom transform is
        broadcast and every pair is expected to become resolvable.
        """
        if test_type != "update":
            return

        all_pairs = self._MAP_PAIRS + self._ODOM_PAIRS + self._BODY_PAIRS

        tf_buffer = Buffer()
        listener = TransformListener(tf_buffer)
        # Drop the listener's own subscriptions and anything already received,
        # then subscribe its callbacks to the update topics instead.
        listener.unregister()
        tf_buffer.clear()
        listener.tf_sub = rospy.Subscriber("/tf2_buffer_server/update", TFMessage, listener.callback, queue_size=None, buff_size=65536, tcp_nodelay=None)
        listener.tf_static_sub = rospy.Subscriber("/tf2_buffer_server/update/static", TFMessage, listener.static_callback, queue_size=None, buff_size=65536, tcp_nodelay=None)

        rospy.sleep(rospy.Duration(1))

        # Before map -> odom exists, none of the pairs may be resolvable.
        self._assert_can_transform(tf_buffer, all_pairs, False)

        # Broadcast an identity static transform map -> odom.
        br = StaticTransformBroadcaster()
        map_to_odom = TransformStamped()
        map_to_odom.header.frame_id = "map"
        map_to_odom.child_frame_id = "odom"
        map_to_odom.transform.translation.x = map_to_odom.transform.translation.y = map_to_odom.transform.translation.z = 0
        map_to_odom.transform.rotation.x = map_to_odom.transform.rotation.y = map_to_odom.transform.rotation.z = 0
        map_to_odom.transform.rotation.w = 1
        br.sendTransform(map_to_odom)

        # Allow time for the new transform to reach tf_buffer.
        rospy.sleep(rospy.Duration(2))

        br.pub_tf.unregister()

        # Now every pair must be resolvable.
        self._assert_can_transform(tf_buffer, all_pairs, True)


if __name__ == '__main__':
    # The same name is used for the ROS node and for the rostest test name.
    node_name = "test_tf2_subtree_listener_py"
    rospy.init_node(node_name)

    # Let the server buffer fill before any test queries it.
    startup_delay = rospy.Duration(3)
    rospy.sleep(startup_delay)

    # Scenario selector; the test methods read it through the module global.
    test_type = rospy.get_param("~test_type")
    rostest.run("tf2_server", node_name, TestTF2SubtreeListener)
