#!/usr/bin/env python
#
# License: BSD
#   https://raw.github.com/robotics-in-concert/rocon_multimaster/license/LICENSE
#

##############################################################################
# Imports
##############################################################################

import sys
import argparse
import re
import rospy
import rocon_gateway
import rocon_gateway_utils
import gateway_msgs.msg as gateway_msgs
import gateway_msgs.srv as gateway_srvs
import rocon_console.console as console
import rocon_python_comms

##############################################################################
# Flags
##############################################################################


class Flags(object):
    pull = 'pull'
    cancel = 'cancel'

##############################################################################
# Logging
##############################################################################


class Logger(object):

    def __init(self):
        self.debug_flag = False

    def debug(self, msg):
        if self.debug_flag:
            console.pretty_print("%s\n" % msg, console.green)

    def green(self, msg):
        console.pretty_print("%s\n" % msg, console.green)

    def warning(self, msg):
        console.pretty_print("%s\n" % msg, console.yellow)

    def error(self, msg):
        console.pretty_print("%s\n" % msg, console.red)

logger = Logger()

##############################################################################
# Functions
##############################################################################


def parse_arguments():
    parser = argparse.ArgumentParser(description='Interactive tool for pulling from remote gateway advertisements')
    parser.add_argument('-d', '--debug', action='store_true', help='print debugging information')
    args = parser.parse_args()
    logger.debug_flag = args.debug
    return args


def resolve_remote_gateways(gateway_namespace):
    '''
      @raise rocon_gateway.GatewayError: if no remote gateways or no matching gateways available.
    '''
    remote_gateway_info = rospy.ServiceProxy(gateway_namespace + '/remote_gateway_info', gateway_srvs.RemoteGatewayInfo)
    req = gateway_srvs.RemoteGatewayInfoRequest()
    req.gateways = []
    resp = remote_gateway_info(req)
    if len(resp.gateways) == 0:
        raise rocon_gateway.GatewayError("No remote gateways to pull from, aborting.")
    index = 0
    max_index = len(resp.gateways) - 1
    if max_index == 0:
        logger.debug("Automatically selecting target gateway %s (only one visible)" % resp.gateways[0].name)
        return resp.gateways[0].name, resp.gateways[0].public_interface
    else:
        print("Select a target gateway:")
        for remote_gateway in resp.gateways:
            print("    " + str(index) + ") " + remote_gateway.name)
            index += 1
        input_response = raw_input("Enter [0-%s]: " % str(max_index))
        index = int(input_response)
        return resp.gateways[index].name, resp.gateways[index].public_interface


def resolve_flags():
    response = raw_input("> Pull or cancel a pulled connections (p/c): ")
    while response != 'p' and response != 'c':
        response = raw_input("> Valid options are (p/c): ")
    if response == 'p':
        return Flags.pull
    else:
        return Flags.cancel


def resolve_pulled_interface(gateway_namespace):
    gateway_info = rocon_python_comms.SubscriberProxy(gateway_namespace + '/gateway_info', gateway_msgs.GatewayInfo)()
    return gateway_info.pulled_connections, gateway_info.pull_watchlist


def resolve_unpulls(pull_watchlist):
    if len(pull_watchlist) == 0:
        raise rocon_gateway.GatewayError("No pull rules to cancel, aborting.")
    index = 0
    max_index = len(pull_watchlist) - 1
    unpulls = []
    if max_index == 0:
        logger.debug("Automatically selecting %s (only one visible)" % pull_watchlist[0].rule.name)
        unpulls.append(gateway_msgs.RemoteRule(pull_watchlist[0].gateway, pull_watchlist[0].rule))
    else:
        print("Select a pull rule to cancel (gateway-name-type-node):")
        for connection in pull_watchlist:
            if connection.rule.node == "None":
                connection.rule.node = r'.*'
            print("    " + str(index) + ") "),
            console.pretty_print(connection.gateway, console.red)
            print("-"),
            console.pretty_print(connection.rule.name, console.green)
            print("-"),
            console.pretty_print(connection.rule.type, console.cyan)
            print("-"),
            console.pretty_print(connection.rule.node + "\n", console.yellow)
            index += 1
        input_response = raw_input("Enter [0-%s]: " % str(max_index))
        index = int(input_response)
        unpulls.append(pull_watchlist[index])
    return unpulls


def is_in_pulled_connection_list(remote_gateway, advertisement, pulled_connections):
    for pulled_connection in pulled_connections:
        if pulled_connection.rule.name == advertisement.name and \
           pulled_connection.rule.type == advertisement.type and \
           pulled_connection.rule.node == advertisement.node:
            if pulled_connection.gateway == remote_gateway:
                return True
    return False


def resolve_connections(gateway_namespace, remote_gateway, remote_advertisements, pulled_connections):

    advertisements = []
    advertisements[:] = [advertisement for advertisement in remote_advertisements
                         if not is_in_pulled_connection_list(remote_gateway, advertisement, pulled_connections)]
    advertisements[:] = [advertisement for advertisement in advertisements
                         if not re.match('.*set_logger_level', advertisement.name)]
    advertisements[:] = [advertisement for advertisement in advertisements
                         if not re.match('.*get_loggers', advertisement.name)]
    advertisements[:] = [advertisement for advertisement in advertisements
                         if not re.match('.*rosout', advertisement.name)]
    advertisements[:] = [advertisement for advertisement in advertisements
                         if not re.match(gateway_namespace + '/.*', advertisement.name)]
    advertisements[:] = [advertisement for advertisement in advertisements
                         if not re.match('.*zeroconf.*', advertisement.name)]
    max_index = len(advertisements) - 1
    pulls = []
    if max_index == 0:
        logger.warning("Automatically selecting %s (only one visible)" % advertisements[0].name)
        pulls.append(gateway_msgs.RemoteRule(remote_gateway, advertisements[0]))
    elif max_index < 0:
        raise rocon_gateway.GatewayError("Nothing to pull (either already flipped or blacklisted).")
    else:
        index = 0
        print("Select a advertisement to pull (name-type-node):")
        for advertisement in advertisements:
            print("    " + str(index) + ") "),
            console.pretty_print(advertisement.name, console.green)
            print("-"),
            console.pretty_print(advertisement.type, console.cyan)
            print("-"),
            console.pretty_print(advertisement.node + "\n", console.yellow)
            index += 1
        input_response = raw_input("Enter [0-%s] or regex keyed by name [e.g. /cha.* or .*]: " % str(max_index))
        try:
            index = int(input_response)
            pulls.append(gateway_msgs.RemoteRule(remote_gateway, advertisements[index]))
        except ValueError:  # It's a regex
            for advertisement in advertisements:
                if re.match(input_response, advertisement.rule.name):
                    for remote_gateway in remote_gateways:
                        pulls.append(gateway_msgs.RemoteRule(remote_gateway, advertisement.rule))
    return pulls


def pull_rules(gateway_namespace, remote_rules, flag):
    pull_service = rospy.ServiceProxy(gateway_namespace + '/pull', gateway_srvs.Remote)
    req = gateway_srvs.RemoteRequest()
    for remote_rule in remote_rules:
        req.cancel = True if flag == Flags.cancel else False
        if remote_rule.rule.node == r'.*':  # ugly hack
            remote_rule.rule.node = ''
        req.remotes.append(remote_rule)
        unused_response = pull_service(req)

##############################################################################
# Main
##############################################################################

if __name__ == '__main__':

    rospy.init_node('pull')
    args = parse_arguments()
    gateway_namespace = None
    remote_gateways = None
    flag = None
    remote_rules = None

    try:
        gateway_namespace = rocon_gateway_utils.resolve_local_gateway()
        flag = resolve_flags()
        pulled_connections, pull_watchlist = resolve_pulled_interface(gateway_namespace)
        if flag == Flags.cancel:
            remote_rules = resolve_unpulls(pull_watchlist)
        else:
            remote_gateway, remote_advertisements = resolve_remote_gateways(gateway_namespace)
            remote_rules = resolve_connections(
                gateway_namespace, remote_gateway, remote_advertisements, pulled_connections)
    except rocon_gateway.GatewayError, e:
        logger.error(str(e))
        sys.exit(1)
    console.pretty_print("Information\n", console.bold)
    print("  Local Gateway: %s" % gateway_namespace)
    print("  Operation    : %s" % flag)
    first_rule = True
    print("  Pull Rules   : "),
    for remote_rule in remote_rules:
        if first_rule:
            first_rule = False
        else:
            print("               : "),
        console.pretty_print(remote_rule.gateway, console.red)
        print("-"),
        console.pretty_print(remote_rule.rule.name, console.green)
        print("-"),
        console.pretty_print(remote_rule.rule.type, console.cyan)
        print("-"),
        console.pretty_print(remote_rule.rule.node + "\n", console.yellow)
    proceed = raw_input("Proceed? (y/n): ")
    if proceed == 'y':
        pull_rules(gateway_namespace, remote_rules, flag)
    rospy.rostime.wallsleep(1)
