#!/usr/bin/env python

from __future__ import print_function

import os
import sys

import rosbag
import rospy

from rospy.client import load_command_line_node_params


# Private node parameters given on the command line as `_name:=value`.
params = load_command_line_node_params(sys.argv)

in_bag_file = params.get('in_bag', None)
out_bag_file = params.get('out_bag', None)
topics = params.get('topics', "")
delay = params.get('delay', 0.0)
overwrite_existing = params.get('overwrite_existing', False)

# Build the list of topics to work on; an empty list means "all topics".
# The parameter values are YAML-parsed, so `topics` may arrive either as a
# comma-separated string or as an actual YAML sequence; accept both.
if topics is None:
    raw_topics = []
elif isinstance(topics, (list, tuple)):
    raw_topics = [str(name) for name in topics]
else:
    raw_topics = str(topics).split(",")

# Strip surrounding whitespace and drop empty entries so that inputs such as
# '/a, /b' or '/a,/b,' still match the intended topic names.
topics_list = [name.strip() for name in raw_topics if name.strip()]

# Validate the command-line arguments before touching any bag file.
# Exit codes: 1 = missing mandatory parameter, 2 = input bag missing,
# 3 = output bag exists and overwriting was not allowed,
# 4 = input and output refer to the same file.
if in_bag_file is None or out_bag_file is None:
    print("Usage: "
          "fix_bag_timestamps _in_bag:=/path/to.bag _out_bag:=/path/to/out.bag "
          "[_topics:='/topic1,/topic2'] [_delay:=0.0] [_overwrite_existing:=false]\n\n"
          "Goes through `in_bag` and for all messages with a header field sets their publication time to the time "
          "stored in their header.stamp plus `delay`.\nIf `topics` are set, only work on messages on the listed topics."
          "\nAn existing `out_bag` is overwritten only if `overwrite_existing` is set to true.", file=sys.stderr)
    sys.exit(1)

if not os.path.exists(in_bag_file):
    print("Error: `in_bag` '{}' does not exist.".format(in_bag_file), file=sys.stderr)
    sys.exit(2)

if os.path.exists(out_bag_file) and not overwrite_existing:
    print("Error: `out_bag` '{}' exists, will not overwrite it.".format(out_bag_file), file=sys.stderr)
    sys.exit(3)

# Opening the output bag for writing truncates the file; if it is the very
# file being read, the input would be destroyed before it can be copied.
if os.path.realpath(in_bag_file) == os.path.realpath(out_bag_file):
    print("Error: `in_bag` and `out_bag` refer to the same file '{}'.".format(in_bag_file), file=sys.stderr)
    sys.exit(4)

# A non-standard extension is only worth a warning, not an error.
if not out_bag_file.endswith(".bag"):
    print("Warning: `out_bag` '{}' does not end with '.bag'.".format(out_bag_file), file=sys.stderr)

# Offset that is added to each header stamp to obtain the new record time.
delay_duration = rospy.Duration.from_sec(delay)
num_processed_messages = 0
num_changed_timestamps = 0

# Copy every message from the input bag into the output bag. Messages on the
# selected topics (or on all topics when no selection was given) are written
# with a record time derived from their header stamp; all other messages keep
# their original record time.
with rosbag.Bag(in_bag_file) as in_bag:
    with rosbag.Bag(out_bag_file, 'w') as out_bag:
        for topic, msg, t in in_bag.read_messages():
            num_processed_messages += 1
            if len(topics_list) == 0 or topic in topics_list:
                if hasattr(msg, 'header'):
                    t = msg.header.stamp + delay_duration
                    num_changed_timestamps += 1
                elif topic == "/tf" and msg.transforms:
                    # /tf messages carry no top-level header; use the stamp
                    # of the first contained transform instead.
                    t = msg.transforms[0].header.stamp + delay_duration
                    # Count these as well so the summary below reflects
                    # every message whose record time was rewritten.
                    num_changed_timestamps += 1
            out_bag.write(topic, msg, t)

print("Finished rewriting bagfile, processed {} messages, changed timestamps for {} of them.".format(
    num_processed_messages, num_changed_timestamps))
