#!/usr/bin/env python
#
# License: BSD
#   https://raw.github.com/robotics-in-concert/rocon_multimaster/license/LICENSE
#

##############################################################################
# Imports
##############################################################################

import sys
import argparse
import re
import rospy
import rocon_gateway
import rocon_gateway_utils
import gateway_msgs.msg as gateway_msgs
import gateway_msgs.srv as gateway_srvs
import rocon_console.console as console
import rocon_python_comms


class Flags(object):
    '''
      String constants identifying the operation the user asked for.
    '''
    # create a new flip rule
    flip = "flip"
    # cancel an existing flip rule
    cancel = "cancel"

##############################################################################
# Logging
##############################################################################


class Logger(object):
    '''
      Minimal coloured console logger. Debug output is only printed when
      debug_flag is set; the other levels always print.
    '''

    def __init__(self):
        # Debug output is off until something (e.g. the --debug command line
        # option) switches it on. This must be a real constructor, otherwise
        # debug() fails with AttributeError when called before the flag is set.
        self.debug_flag = False

    def debug(self, msg):
        '''
          Print msg in green, but only if debug_flag is enabled.
        '''
        if self.debug_flag:
            console.pretty_print("%s\n" % msg, console.green)

    def green(self, msg):
        '''
          Print msg in green, unconditionally.
        '''
        console.pretty_print("%s\n" % msg, console.green)

    def warning(self, msg):
        '''
          Print msg in yellow.
        '''
        console.pretty_print("%s\n" % msg, console.yellow)

    def error(self, msg):
        '''
          Print msg in red.
        '''
        console.pretty_print("%s\n" % msg, console.red)

# Single shared logger instance used by the functions in this script.
logger = Logger()

##############################################################################
# Functions
##############################################################################


def parse_arguments():
    '''
      Parse the command line and switch the module logger's debug output
      on or off accordingly.

      ROS remapping arguments (e.g. __name:=foo, __ns:=bar) are stripped
      before parsing; argparse would otherwise abort on them as
      unrecognised arguments.

      @return the argparse namespace (only the 'debug' attribute is defined)
    '''
    parser = argparse.ArgumentParser(description='Interactive tool for flipping a local connection to a remote gateway')
    parser.add_argument('-d', '--debug', action='store_true', help='print debugging information')
    # rospy.myargv() returns sys.argv minus remappings, program name included.
    args = parser.parse_args(rospy.myargv()[1:])
    logger.debug_flag = args.debug
    return args


def resolve_remote_gateways(gateway_namespace):
    '''
      Query the local gateway's remote_gateway_info service and let the user
      choose which of the returned remote gateways to target.

      If exactly one gateway is returned it is selected automatically.
      Otherwise the user enters either an index into the printed list or a
      regex that is matched (re.match) against the gateway names.

      @param gateway_namespace : prefix of the local gateway's service names
      @return non-empty list of the selected remote gateway names
      @raise rocon_gateway.GatewayError: if no remote gateways or no matching gateways available,
             or if the entered index / regex is invalid.
    '''
    remote_gateway_info = rospy.ServiceProxy(gateway_namespace + '/remote_gateway_info', gateway_srvs.RemoteGatewayInfo)
    req = gateway_srvs.RemoteGatewayInfoRequest()
    req.gateways = []
    resp = remote_gateway_info(req)
    names = [remote_gateway.name for remote_gateway in resp.gateways]
    if not names:
        raise rocon_gateway.GatewayError("No remote gateways to flip/unflip to, aborting.")
    if len(names) == 1:
        logger.warning("Automatically selecting target gateway %s (only one visible)" % names[0])
        return names

    print("Select a target gateway:")
    for index, name in enumerate(names):
        print("    " + str(index) + ") " + name)
    max_index = len(names) - 1
    input_response = raw_input("Enter [0-%s] or regex [e.g. pir.* or .*]: " % str(max_index))
    try:
        index = int(input_response)
    except ValueError:  # It's a regex
        try:
            remote_gateways = [name for name in names if re.match(input_response, name)]
        except re.error as e:
            raise rocon_gateway.GatewayError("Invalid regex '%s' (%s), aborting." % (input_response, str(e)))
        if not remote_gateways:
            raise rocon_gateway.GatewayError("No remote gateways matching '%s', aborting." % input_response)
        return remote_gateways
    # Reject out-of-range input explicitly; a negative number would otherwise
    # silently select from the end of the list.
    if not 0 <= index <= max_index:
        raise rocon_gateway.GatewayError("Gateway index %s is out of range [0-%s], aborting." % (index, max_index))
    return [names[index]]


def resolve_flags():
    '''
      Prompt until the user answers 'f' or 'c' and translate the answer.

      @return Flags.flip for 'f', Flags.cancel for 'c'
    '''
    prompt = "> Create or cancel a flip rule (f/c): "
    while True:
        answer = raw_input(prompt)
        if answer == 'f':
            return Flags.flip
        if answer == 'c':
            return Flags.cancel
        # anything else: ask again with the shorter reminder prompt
        prompt = "> Valid options are (f/c): "


def resolve_flipped_interface(gateway_namespace):
    '''
      Obtain a GatewayInfo message through a SubscriberProxy on
      <gateway_namespace>/gateway_info and extract the flip state from it.

      @param gateway_namespace : prefix of the local gateway's names
      @return tuple (flipped_connections, flip_watchlist) taken from the message
    '''
    info_name = gateway_namespace + '/gateway_info'
    info_proxy = rocon_python_comms.SubscriberProxy(info_name, gateway_msgs.GatewayInfo)
    info = info_proxy()
    return (info.flipped_connections, info.flip_watchlist)


def resolve_unflips(flip_watchlist):
    '''
      Let the user choose which flip rule from the watchlist to cancel.

      A watchlist with a single entry is selected automatically. Otherwise
      the entries are listed and the user enters the index of one of them.

      Side effect: while listing, any rule whose node is the string "None"
      has its node rewritten in place to '.*' (flip_rules later turns '.*'
      into an empty string before sending).
      NOTE(review): the single-entry shortcut does not apply this rewrite;
      left as in the original - confirm whether that is intended.

      @param flip_watchlist : remote rules currently on the gateway's flip watchlist
      @return list containing the single selected rule
      @raise rocon_gateway.GatewayError: if the watchlist is empty or the
             entered value is not a valid index.
    '''
    if not flip_watchlist:
        raise rocon_gateway.GatewayError("No flip rules to cancel, aborting.")
    max_index = len(flip_watchlist) - 1
    if max_index == 0:
        logger.warning("Automatically selecting %s (only one flip rule in the watchlist)" % flip_watchlist[0].rule.name)
        return [flip_watchlist[0]]

    print("Select a flip rule to cancel (gateway-name-type-node):")
    for index, connection in enumerate(flip_watchlist):
        if connection.rule.node == "None":
            connection.rule.node = r'.*'
        print("    " + str(index) + ") "),
        console.pretty_print(connection.gateway, console.red)
        print("-"),
        console.pretty_print(connection.rule.name, console.green)
        print("-"),
        console.pretty_print(connection.rule.type, console.cyan)
        print("-"),
        console.pretty_print(connection.rule.node + "\n", console.yellow)
    input_response = raw_input("Enter [0-%s]: " % str(max_index))
    # Turn bad input into the error type the caller already handles instead
    # of letting ValueError/IndexError escape as a traceback.
    try:
        index = int(input_response)
    except ValueError:
        raise rocon_gateway.GatewayError("'%s' is not a valid index, aborting." % input_response)
    if not 0 <= index <= max_index:
        raise rocon_gateway.GatewayError("Index %s is out of range [0-%s], aborting." % (index, max_index))
    return [flip_watchlist[index]]


def is_in_flipped_connection_list(remote_gateways, connection, flipped_connections):
    '''
      Check whether a connection is already flipped to every one of the
      given remote gateways.

      A flipped connection counts as a match when its remote rule has the
      same name, type and node as the connection's rule and its gateway is
      one of remote_gateways. Matches are collected per distinct gateway, so
      repeated entries for the same gateway cannot stand in for a gateway
      that has no matching entry.

      @param remote_gateways : names of the gateways being targeted
      @param connection : object with a .rule (name, type, node)
      @param flipped_connections : objects with .remote_rule.gateway and .remote_rule.rule
      @return True only if every targeted gateway has a matching flipped
              connection; False otherwise (including when remote_gateways is empty)
    '''
    targets = set(remote_gateways)
    if not targets:
        return False
    rule = connection.rule
    matched_gateways = set()
    for flipped_connection in flipped_connections:
        remote_rule = flipped_connection.remote_rule
        same_rule = (remote_rule.rule.name == rule.name and
                     remote_rule.rule.type == rule.type and
                     remote_rule.rule.node == rule.node)
        if same_rule and remote_rule.gateway in targets:
            matched_gateways.add(remote_rule.gateway)
            if matched_gateways == targets:
                return True
    return False


def resolve_connections(gateway_namespace, remote_gateways, flipped_connections):
    '''
      List the local connections that can still be flipped and let the user
      pick which of them to flip to the given remote gateways.

      Connections are taken from the local master's connection state.
      Those already flipped to all of remote_gateways are dropped, as are
      those whose name matches the blacklist below. The user then enters
      either an index into the printed list or a regex matched (re.match)
      against the connection names.

      @param gateway_namespace : prefix of the local gateway's names; connections under it are never offered
      @param remote_gateways : names of the gateways to flip to
      @param flipped_connections : connections the gateway reports as already flipped
      @return list of gateway_msgs.RemoteRule, one per (selected connection, remote gateway) pair
      @raise rocon_gateway.GatewayError: if nothing is left to flip, nothing
             matches the regex, or the entered index / regex is invalid.
    '''
    master = rocon_gateway.LocalMaster()
    connection_dictionary = master.get_connection_state()
    # some things you should never flip
    blacklist = (
        '.*set_logger_level',
        '.*get_loggers',
        '.*rosout',
        # escaped so the namespace is matched literally, not as a pattern
        re.escape(gateway_namespace) + '/.*',
        '.*zeroconf.*',
    )
    connections = []
    for connection_type in connection_dictionary:
        for connection in connection_dictionary[connection_type]:
            if is_in_flipped_connection_list(remote_gateways, connection, flipped_connections):
                continue
            if any(re.match(pattern, connection.rule.name) for pattern in blacklist):
                continue
            connections.append(connection)
    if not connections:
        raise rocon_gateway.GatewayError("Nothing to flip (all either already flipped or blacklisted).")

    max_index = len(connections) - 1
    print("Select connection by index or regex name based pattern")
    for index, connection in enumerate(connections):
        print("    " + str(index) + ") "),
        console.pretty_print(connection.rule.name, console.green)
        print("-"),
        console.pretty_print(connection.rule.type, console.cyan)
        print("-"),
        console.pretty_print(connection.rule.node + "\n", console.yellow)
    input_response = raw_input("Enter [0-%s] or regex keyed by name [e.g. /cha.* or .*]: " % str(max_index))
    try:
        index = int(input_response)
    except ValueError:  # It's a regex
        try:
            selected = [connection for connection in connections
                        if re.match(input_response, connection.rule.name)]
        except re.error as e:
            raise rocon_gateway.GatewayError("Invalid regex '%s' (%s), aborting" % (input_response, str(e)))
        if not selected:
            raise rocon_gateway.GatewayError("No matching connections found, aborting")
    else:
        # Reject out-of-range input explicitly; a negative number would
        # otherwise silently select from the end of the list.
        if not 0 <= index <= max_index:
            raise rocon_gateway.GatewayError("Connection index %s is out of range [0-%s], aborting" % (index, max_index))
        selected = [connections[index]]
    return [gateway_msgs.RemoteRule(remote_gateway, connection.rule)
            for connection in selected
            for remote_gateway in remote_gateways]


def flip_rules(gateway_namespace, remote_rules, flag):
    '''
      Send each remote rule to the local gateway's flip service, either to
      create the flip or to cancel it.

      Every rule goes out in its own request, so each one is submitted
      exactly once. (Appending to a single shared request while calling the
      service inside the loop would re-send all earlier rules on every call.)

      Side effect: a rule node of '.*' is rewritten in place to ''.

      @param gateway_namespace : prefix of the local gateway's service names
      @param remote_rules : rules to flip / cancel
      @param flag : Flags.cancel to cancel; any other value creates the flips
    '''
    flip_service = rospy.ServiceProxy(gateway_namespace + '/flip', gateway_srvs.Remote)
    cancel = (flag == Flags.cancel)
    for remote_rule in remote_rules:
        if remote_rule.rule.node == r'.*':  # ugly hack
            remote_rule.rule.node = ''
        req = gateway_srvs.RemoteRequest()
        req.cancel = cancel
        req.remotes.append(remote_rule)
        # The service response is intentionally not inspected here.
        flip_service(req)

##############################################################################
# Main
##############################################################################

if __name__ == '__main__':

    # Interactive flow: find the local gateway, ask flip vs cancel, work out
    # the rules concerned, show a summary, and only act on confirmation.
    rospy.init_node('flip')
    args = parse_arguments()
    gateway_namespace = None
    remote_gateways = None
    flag = None
    remote_rules = None

    try:
        gateway_namespace = rocon_gateway_utils.resolve_local_gateway()
        flag = resolve_flags()
        flipped_connections, flip_watchlist = resolve_flipped_interface(gateway_namespace)
        if flag == Flags.cancel:
            remote_rules = resolve_unflips(flip_watchlist)
        else:
            remote_gateways = resolve_remote_gateways(gateway_namespace)
            remote_rules = resolve_connections(gateway_namespace, remote_gateways, flipped_connections)
    except rocon_gateway.GatewayError as e:  # 'as' form: valid on python 2.6+ and 3
        logger.error(str(e))
        sys.exit(1)

    # Summary of what is about to be sent; the first rule shares the line
    # with the "Flip Rules" label, later ones get an aligned continuation.
    console.pretty_print("Information\n", console.bold)
    print("  Local Gateway: %s" % gateway_namespace)
    print("  Operation    : %s" % flag)
    first_rule = True
    print("  Flip Rules   : "),
    for rule in remote_rules:
        if first_rule:
            first_rule = False
        else:
            print("               : "),
        console.pretty_print(rule.gateway, console.red)
        print("-"),
        console.pretty_print(rule.rule.name, console.green)
        print("-"),
        console.pretty_print(rule.rule.type, console.cyan)
        print("-"),
        console.pretty_print(rule.rule.node + "\n", console.yellow)
    proceed = raw_input("Proceed? (y/n): ")
    if proceed == 'y':
        flip_rules(gateway_namespace, remote_rules, flag)
    # NOTE(review): presumably gives outgoing traffic a moment before the
    # node exits - confirm; runs whether or not the user proceeded.
    rospy.rostime.wallsleep(1)
