#!/usr/bin/env python

import threading
from collections import OrderedDict

import rospy

from tf2_msgs.msg import TFMessage
from std_msgs.msg import Bool


class StaticTransformMux(object):
    """Merge everything published on /tf_static into a single latched message.

    This node subscribes to /tf_static, collects all the transforms that are ever published there, and re-sends a
    message that contains not only the transforms from the last publisher, but all transforms ever encountered.

    Transforms are cached per ``child_frame_id``, so a later transform for the same child frame replaces the cached
    one. The cache keeps insertion order, which is also the order of the transforms in the published message.

    Private ROS parameters read on construction:

    * ``~update_only_with_newer`` (default ``False``): if true, a cached transform is only replaced by a transform
      with a strictly newer header stamp.
    * ``~forbidden_callerid_prefix`` (default ``None``): messages from nodes whose caller id starts with this prefix
      are ignored.
    * ``~publisher_topic`` (default ``/tf_static``): topic on which the merged message is published.
    """

    def __init__(self):
        """Read the parameters and initialize all publishers/subscribers."""

        # Set to True once the first merged message (and the "ready" notification) has been published.
        self._ready = False

        # child_frame_id -> most recently accepted transform for that frame.
        self._transforms = OrderedDict()

        # rospy invokes subscriber callbacks from the receive thread of each publisher connection, so several
        # /tf_static publishers can trigger tf_static_cb() concurrently. The lock keeps "merge + publish" atomic.
        self._lock = threading.Lock()

        self._update_only_with_newer = rospy.get_param("~update_only_with_newer", False)
        self._forbidden_callerid_prefix = rospy.get_param("~forbidden_callerid_prefix", None)
        self._pub_topic = rospy.get_param("~publisher_topic", "/tf_static")

        self._tf_publisher = rospy.Publisher(self._pub_topic, TFMessage, latch=True, queue_size=100)
        self._ready_publisher = rospy.Publisher("static_transform_mux/ready", Bool, latch=True, queue_size=1)
        # Subscribe last: the callback may fire as soon as the subscriber exists and needs everything above.
        self._tf_subscriber = rospy.Subscriber("/tf_static", TFMessage, self.tf_static_cb, queue_size=100)

    def _is_ignored_sender(self, callerid):
        """Tell whether messages from the node with the given caller id must be dropped.

        :param str callerid: Caller id taken from the connection header of the received message.
        :return: True for this node itself and for nodes matching the forbidden caller id prefix.
        :rtype: bool
        """
        prefix = self._forbidden_callerid_prefix
        if prefix is not None and callerid.startswith(prefix):
            return True
        return callerid == rospy.get_name()

    def _merge_transform(self, transform):
        """Store a single transform in the cache unless the configuration says to keep the cached one.

        The caller is expected to hold ``self._lock``.

        :param transform: One element of ``TFMessage.transforms``.
        """
        key = transform.child_frame_id
        cached = self._transforms.get(key)

        # Keep the cached transform only if one exists, newer-only mode is on, and the incoming one is not newer.
        if cached is not None and self._update_only_with_newer and \
                not (cached.header.stamp < transform.header.stamp):
            return

        rospy.logdebug("Updated transform %s->%s. Old timestamp was %s, new timestamp is %s.",
                       transform.header.frame_id, key,
                       cached.header.stamp if cached is not None else "None",
                       transform.header.stamp)

        self._transforms[key] = transform

    def tf_static_cb(self, tf_msg):
        """Callback for messages from /tf_static.

        Add all transforms to the cache and then publish a TF message with all the transforms.
        :param TFMessage tf_msg: A /tf_static message.
        """

        # Protect against reacting to this node's own published messages, otherwise it'd end up in an infinite loop.
        if self._is_ignored_sender(tf_msg._connection_header['callerid']):
            return

        with self._lock:
            # Process the incoming transforms, merge them with our cache.
            for transform in tf_msg.transforms:
                self._merge_transform(transform)

            # Publish the cached messages. The list() call takes a snapshot: the latched message must not be a live
            # view of the cache (which is what dict.values() returns on Python 3).
            msg = TFMessage()
            msg.transforms = list(self._transforms.values())
            self._tf_publisher.publish(msg)

            rospy.logdebug("Sent %i transforms: %s", len(msg.transforms), list(self._transforms))

            # After publishing the first message, send also this ~/ready message, so that nodes that rely on this
            # node's operation know that it has been brought up.
            if not self._ready:
                self._ready = True
                ready_msg = Bool()
                ready_msg.data = True
                self._ready_publisher.publish(ready_msg)
                rospy.logdebug("Sent ready message")


if __name__ == "__main__":
    # The node has to be initialized before the mux is constructed.
    rospy.init_node("static_transform_mux")

    # Constructing the mux sets up its publishers and its /tf_static subscriber; keep a reference to it.
    mux = StaticTransformMux()

    # Hand control over to rospy; all work happens in the subscriber callback.
    rospy.spin()
