#!/usr/bin/env python

# For license see file LICENSE in the root of this repo.

from builtins import int

import json
import re
import requests

import rospy

from mikrotik_swos_tools.msg import Statistics


class SwOSAPI:
    """Poll the web interface of a Mikrotik SwOS switch and publish statistics.

    Connection settings are read from private ROS parameters in the
    constructor. `run` then periodically queries the switch and publishes one
    `Statistics` message per cycle on the `~statistics` topic.
    """

    def __init__(self):
        """Read node parameters and optionally wait until the switch responds.

        Private ROS parameters:
            ~rate: polling frequency passed to `rospy.Rate` (default 0.1).
            ~username: user name for HTTP digest auth (default "admin").
            ~password: password for HTTP digest auth (default "").
            ~address: base URL of the switch (default "http://192.168.88.1/").
                A missing scheme is replaced by "http://" and a trailing
                slash is appended when absent.
            ~timeout: timeout in seconds for each HTTP request (default 5.0).
            ~wait_for_switch: when true, block until the switch answers the
                first request or ROS shuts down (default False).
        """
        self._rate = rospy.Rate(rospy.get_param("~rate", 0.1))
        self._username = rospy.get_param("~username", "admin")
        self._password = rospy.get_param("~password", "")
        # Without a timeout `requests` may block forever on a dead switch.
        self._timeout = rospy.get_param("~timeout", 5.0)

        self._address = rospy.get_param("~address", "http://192.168.88.1/")
        # Accept either scheme; only prepend one when none was given, so that
        # an "https://..." address is not turned into "http://https://...".
        if not self._address.startswith(("http://", "https://")):
            self._address = "http://" + self._address
        if not self._address.endswith("/"):
            self._address += "/"

        # Lazily filled caches; see the properties of the same names.
        self._identity = None
        self._port_names = None
        self._enabled_ports = None
        self._enabled_port_names = None
        self._enabled_port_numbers = None

        if rospy.get_param("~wait_for_switch", False):
            # Retry once per second until the first request succeeds.
            while not rospy.is_shutdown():
                try:
                    self._collect_sys_tab_data()
                    break
                except requests.exceptions.RequestException as e:
                    rospy.logwarn("Waiting for switch management interface %s to become available." % self._address)
                    rospy.logdebug("Switch management interface error: " + str(e))
                    rospy.sleep(rospy.Duration(1, 0))
            if rospy.is_shutdown():
                # No publisher is created in this case; `run` does nothing
                # once ROS is shut down.
                return

        self._stats_pub = rospy.Publisher('~statistics', Statistics, queue_size=1)

        rospy.loginfo("Set up Mikrotik SwOS API on %s" % (self._address, ))

    def _request(self, url):
        """Perform an authenticated GET of `url` and return the response.

        :raises requests.exceptions.RequestException: on connection errors,
            timeouts, or an HTTP error status (via `raise_for_status`).
        """
        response = requests.get(
            url,
            auth=requests.auth.HTTPDigestAuth(self._username, self._password),
            timeout=self._timeout)
        response.raise_for_status()
        return response

    def _api_request(self, url_command):
        """Fetch `url_command` relative to the switch address and parse it.

        The response text is first passed through `_fix_broken_json`, then
        parsed with `json.loads`.

        :return: The parsed response.
        """
        raw = self._request("%s%s" % (self._address, url_command))
        return json.loads(self._fix_broken_json(raw.text))

    def _collect_sys_tab_data(self):
        """Query "sys.b" and cache the decoded switch identity."""
        response = self._api_request("sys.b")
        self._identity = self._decode_string(response["id"])
        rospy.loginfo("The connected switch is called '%s'" % self._identity)

    @property
    def identity(self):
        """Identity string of the switch, fetched on first access."""
        if self._identity is None:
            self._collect_sys_tab_data()
        return self._identity

    def _collect_link_tab_data(self):
        """Query "link.b" and cache port names and the enabled-port info.

        The "en" field is parsed as a hexadecimal bitmask in which bit `i`
        tells whether the port at index `i` of "nm" is enabled.
        """
        response = self._api_request("link.b")
        self._port_names = [self._decode_string(s) for s in response["nm"]]
        rospy.loginfo("The connected switch has ports '%r'" % self._port_names)

        enabled_mask = int(response["en"], base=16)
        self._enabled_ports = [
            (enabled_mask & (1 << i)) != 0 for i in range(len(self._port_names))]

        # Zero-based indices and names of the enabled ports, in port order.
        enabled_indices = [
            i for i, is_enabled in enumerate(self._enabled_ports) if is_enabled]
        self._enabled_port_names = [self._port_names[i] for i in enabled_indices]
        self._enabled_port_numbers = enabled_indices

        # Port numbers are logged one-based, but stored zero-based.
        rospy.loginfo("Ports %r with numbers %r are enabled" % (self._enabled_port_names, [i+1 for i in self._enabled_port_numbers]))

    @property
    def port_names(self):
        """Decoded names of all ports, fetched on first access."""
        if self._port_names is None:
            self._collect_link_tab_data()
        return self._port_names

    @property
    def enabled_ports(self):
        """List of booleans, one per port, telling whether it is enabled."""
        if self._enabled_ports is None:
            self._collect_link_tab_data()
        return self._enabled_ports

    @property
    def enabled_port_names(self):
        """Names of the enabled ports only, fetched on first access."""
        if self._enabled_port_names is None:
            self._collect_link_tab_data()
        return self._enabled_port_names

    @property
    def enabled_port_numbers(self):
        """Zero-based indices of the enabled ports, fetched on first access."""
        if self._enabled_port_numbers is None:
            self._collect_link_tab_data()
        return self._enabled_port_numbers

    def _collect_statistics_tab_data(self):
        """Query "!stats.b" and convert it to a `Statistics` message.

        Only values belonging to enabled ports are kept.

        :return: The filled `Statistics` message, stamped with the current
            ROS time and with the switch identity as `frame_id`.
        """
        raw_stats = self._api_request('!stats.b')
        enabled_flags = self.enabled_ports
        # Drop the entries of disabled ports from every per-port list.
        stats = {
            key: [val for val, enabled in zip(values, enabled_flags) if enabled]
            for key, values in raw_stats.items()
        }

        result = Statistics()

        result.header.stamp = rospy.Time.now()
        result.header.frame_id = self.identity
        result.port_names = self.enabled_port_names
        result.port_numbers = self.enabled_port_numbers

        # rate
        # NOTE(review): `/` is true division on Python 3, so the byte rates
        # are floats there - confirm this matches the message field type.
        result.rx_rate_bytes = [self._decode_byte_rate(d) / 8 for d in stats['rrb']]
        result.rx_rate_packets = [self._decode_packet_rate(d) for d in stats['rrp']]
        result.tx_rate_bytes = [self._decode_byte_rate(d) / 8 for d in stats['trb']]
        result.tx_rate_packets = [self._decode_packet_rate(d) for d in stats['trp']]

        # rx
        result.rx_bytes = self._get_long_stats(stats, 'rb')
        result.rx_packets = [self._decode_int(d) for d in stats['rtp']]
        result.rx_unicasts = self._get_long_stats(stats, 'rup')
        result.rx_broadcasts = self._get_long_stats(stats, 'rbp')
        result.rx_multicasts = self._get_long_stats(stats, 'rmp')
        result.rx_64 = [self._decode_int(d) for d in stats['r64']]
        result.rx_127 = [self._decode_int(d) for d in stats['r65']]
        result.rx_255 = [self._decode_int(d) for d in stats['r128']]
        result.rx_511 = [self._decode_int(d) for d in stats['r256']]
        result.rx_1023 = [self._decode_int(d) for d in stats['r512']]
        result.rx_1518 = [self._decode_int(d) for d in stats['r1k']]
        result.rx_jumbo = [self._decode_int(d) for d in stats['rmax']]

        # tx
        result.tx_bytes = self._get_long_stats(stats, 'tb')
        result.tx_packets = [self._decode_int(d) for d in stats['ttp']]
        result.tx_unicasts = self._get_long_stats(stats, 'tup')
        result.tx_broadcasts = self._get_long_stats(stats, 'tbp')
        result.tx_multicasts = self._get_long_stats(stats, 'tmp')
        result.tx_64 = [self._decode_int(d) for d in stats['t64']]
        result.tx_127 = [self._decode_int(d) for d in stats['t65']]
        result.tx_255 = [self._decode_int(d) for d in stats['t128']]
        result.tx_511 = [self._decode_int(d) for d in stats['t256']]
        result.tx_1023 = [self._decode_int(d) for d in stats['t512']]
        result.tx_1518 = [self._decode_int(d) for d in stats['t1k']]
        result.tx_jumbo = [self._decode_int(d) for d in stats['tmax']]

        return result

    def run(self):
        """Publish statistics at the configured rate until ROS shuts down.

        Communication errors are logged and the loop carries on with the next
        cycle.
        """
        while not rospy.is_shutdown():
            try:
                self._stats_pub.publish(self._collect_statistics_tab_data())
            except requests.exceptions.RequestException as e:
                rospy.logerr("Error communicating with Mikrotik SwOS API: " + str(e))
            self._rate.sleep()

    @staticmethod
    def _decode_byte_rate(s):
        """Parse hexadecimal string `s` and multiply the value by 10."""
        return int(s, base=16) * 10

    @staticmethod
    def _decode_packet_rate(s):
        """Parse hexadecimal string `s` as an integer."""
        return int(s, base=16)

    @staticmethod
    def _decode_int(s):
        """Parse hexadecimal string `s` as an integer."""
        return int(s, base=16)

    @staticmethod
    def _decode_long(s_low, s_high):
        """Combine two "0x"-prefixed hex strings into one integer.

        The digits of `s_high` are placed in front of the digits of `s_low`.

        NOTE(review): this is only a correct high/low word combination when
        `s_low` is zero-padded to its full width - confirm against the device
        output.
        """
        return int("0x" + s_high[2:] + s_low[2:], base=16)

    def _get_long_stats(self, stats, param):
        """Decode the counter `param` whose high words are in `param + 'h'`.

        :return: One integer per port, combining `stats[param]` (low part)
            with `stats[param + 'h']` (high part).
        """
        return [self._decode_long(d, dh) for d, dh in zip(stats[param], stats[param + 'h'])]

    @staticmethod
    def _decode_string(s):
        """Decode a hex-encoded string, two hex digits per character.

        A trailing unpaired digit is ignored.
        """
        return "".join(
            chr(int(s[i:i + 2], base=16)) for i in range(0, len(s) - 1, 2))

    @staticmethod
    def _fix_broken_json(broken_json):
        """Turn the switch's JSON-like response text into valid JSON.

        Three textual substitutions are applied in order: bare keys following
        `{` or `,` are double-quoted, single quotes become double quotes, and
        `0x...` literals are wrapped in double quotes (they stay strings and
        are parsed by the `_decode_*` helpers).
        """
        # `*` rather than `+` so that one-letter keys are quoted as well.
        result = re.sub(r'([{,])([a-zA-Z][a-zA-Z0-9]*)', '\\1"\\2"', broken_json)
        result = re.sub(r'\'', '"', result)
        result = re.sub(r'(0x[0-9a-zA-Z]+)', '"\\1"', result)

        return result

if __name__ == '__main__':
    rospy.init_node("swos_api")

    # Construct the API inside the try block as well: its constructor may call
    # rospy.sleep() while waiting for the switch, and rospy.sleep() raises
    # ROSInterruptException when the node is shut down in the meantime.
    try:
        api = SwOSAPI()
        api.run()
    except rospy.ROSInterruptException:
        pass
