#!/usr/bin/env python

import urllib3
import requests

from functools import partial

import rospy
from std_msgs.msg import Header
from ubnt_airos_tools.msg import Status, Station, Stations, Interface, Interfaces


def encode_multipart_formdata(fields):
    """Serialize a mapping of form fields as a multipart/form-data payload.

    :param fields: Mapping of field name to field value. Both are inserted
                   into the payload using ``%s`` formatting.
    :return: Tuple ``(body, content_type)`` where ``body`` is the encoded
             payload and ``content_type`` is the matching Content-Type header
             value (it carries the fixed boundary string used in the body).
    """
    boundary = "---------------------------13049778497274285981867756397"
    delimiter = "--" + boundary

    chunks = []
    for name, content in fields.items():
        # Every part is: delimiter line, disposition header, blank line, value.
        chunks.append(delimiter + "\r\n")
        chunks.append("Content-Disposition: form-data; name=\"%s\"\r\n" % (name,))
        chunks.append("\r\n")
        chunks.append("%s\r\n" % (content,))

    # The closing delimiter carries two extra trailing dashes.
    chunks.append(delimiter + "--\r\n")

    body = "".join(chunks)
    content_type = "multipart/form-data; boundary=" + boundary

    return body, content_type


class AirosAPI(object):
    """Poll an airOS device over HTTP and publish its state on ROS topics.

    The node logs into the device's web interface, periodically downloads
    ``status.cgi``, ``iflist.cgi`` and ``sta.cgi`` and republishes selected
    fields on the private topics ``~status``, ``~interfaces`` and
    ``~stations``. Each endpoint is only queried while its topic has at
    least one connection.
    """

    def __init__(self):
        """Read the node's private parameters and create publishers and the HTTP session.

        Parameters read: ``~username``, ``~password``, ``~address``,
        ``~frame_id``, ``~publish_rate`` (Hz) and ``~compute_rates``.
        """
        self._username = rospy.get_param("~username", "ubnt")
        self._password = rospy.get_param("~password", "ubntubnt")
        self._address = rospy.get_param("~address", "http://192.168.1.1")

        # by default, we read frame_id from the device's hostname
        self._frame_id = rospy.get_param("~frame_id", None)
        self._rate = rospy.Rate(rospy.get_param("~publish_rate", 0.1))
        self._compute_rates = rospy.get_param("~compute_rates", True)

        self._status_pub = rospy.Publisher("~status", Status, queue_size=1)
        self._interfaces_pub = rospy.Publisher("~interfaces", Interfaces, queue_size=1)
        self._stations_pub = rospy.Publisher("~stations", Stations, queue_size=1)

        # Counter values seen in the previous cycle, used by _get_rate().
        self._last_interface_values = dict()
        self._last_station_values = dict()

        self._session = requests.Session()

        self._logged_in = False

    def _login(self):
        """Authenticate the HTTP session against the device's ``login.cgi``.

        :raises requests.exceptions.RequestException: on a transport error,
            an HTTP error status, or when the reply to the login form is not
            reported as ``application/json``.
        """
        self._logged_in = False
        self._session.cookies.clear()

        rospy.loginfo("Started login sequence")

        login_url = "%s/login.cgi" % self._address

        # An initial empty POST precedes the real login; presumably it makes
        # the device issue the session cookie - TODO confirm.
        expect = self._session.post(login_url,
                                    verify=False, allow_redirects=False,
                                    headers={'expect': ''})
        expect.raise_for_status()

        formdata = {
            'uri': '/status.cgi',
            'username': self._username,
            'password': self._password
        }

        body, content_type = encode_multipart_formdata(formdata)

        headers = {
            'content-type': content_type,
        }

        login = self._session.post(
            login_url, verify=False, allow_redirects=True,
            data=body, headers=headers
        )
        login.raise_for_status()

        # The form asks to be redirected to /status.cgi, so a JSON reply is
        # treated as a successful login and anything else as a failure.
        if login.headers.get('content-type') == 'application/json':
            self._logged_in = True
            rospy.loginfo("Login successful")
        else:
            rospy.logerr("Error logging in")
            raise requests.exceptions.HTTPError("Error logging in")

    def _get_rate(self, values, name, key, new_value):
        """Return the per-second change of a counter since the previous cycle.

        The new sample always replaces the stored one.

        :param values: Dict holding the previous samples; updated in place.
        :param name: Name of the measured entity (interface name, station MAC).
        :param key: Name of the counter, e.g. ``'rx_bytes'``.
        :param new_value: Current value of the counter.
        :return: ``(new_value - previous value) / publish period`` as a float,
                 or ``None`` when there is no previous sample.
        """
        sample_id = (name, key)
        has_previous = sample_id in values
        old_value = values[sample_id] if has_previous else None
        values[sample_id] = new_value

        if not has_previous:
            return None

        # Use the full period as a float. Duration.secs holds only the whole
        # seconds, which is 0 for any publish rate above 1 Hz.
        # NOTE(review): the result is now always a float - confirm the rate
        # fields of the Interface/Station messages are floating point.
        period = self._rate.sleep_dur.to_sec()
        return (new_value - old_value) / period

    @staticmethod
    def _login_required(response):
        """Tell whether a response indicates that the session is not logged in.

        That is the case for HTTP 403, or for a body that is neither reported
        as JSON nor starts with ``[`` (returned arrays are reported as
        text/html).
        """
        if response.status_code == 403:
            return True
        if response.headers.get('content-type') == 'application/json':
            return False
        # Slicing (instead of indexing) tolerates an empty body and compares
        # correctly with both str (Python 2) and bytes (Python 3) content.
        return response.content[:1] != b'['

    def request(self, api_path):
        """GET ``<address>/<api_path>``, logging in again once if the session expired.

        :param api_path: Path relative to the device address, e.g. ``'status.cgi'``.
        :return: The ``requests`` response of the successful request.
        :raises requests.exceptions.RequestException: on a transport error, an
            HTTP error status, a failed login, or when the request is still
            rejected after re-login.
        """
        if not self._logged_in:
            self._login()

        url = "%s/%s" % (self._address, api_path)

        response = None
        # if re-login is needed, we need to repeat the request
        for attempt in (1, 2):
            rospy.logdebug("Requesting %s" % url)
            # verify=False matches the login requests, so that an https device
            # with a self-signed certificate can be queried as well.
            response = self._session.get(url, verify=False)

            if not self._login_required(response):
                response.raise_for_status()
                return response

            self._logged_in = False
            if attempt == 1:
                rospy.loginfo("Re-login needed")
                self._login()

        # Both attempts were rejected; report it as a request error so that
        # callers' RequestException handlers deal with it.
        response.raise_for_status()
        raise requests.exceptions.HTTPError(
            "Request to %s was rejected even after re-login" % url, response=response)

    def _publish_status(self, header, status=None):
        """Publish a Status message if anybody is connected to ``~status``.

        :param header: Header to put into the message.
        :param status: Already parsed ``status.cgi`` reply; downloaded here
                       when ``None``.
        """
        if self._status_pub.get_num_connections() > 0:
            try:
                if status is None:
                    status = self.request('status.cgi').json()
                msg = Status()
                msg.header = header
                msg.tx_rate = float(status['wireless']['txrate'])
                msg.rx_rate = float(status['wireless']['rxrate'])
                msg.ccq = status['wireless']['ccq'] / 10.0
                msg.rssi = status['wireless']['rssi']
                msg.signal = status['wireless']['signal']
                msg.cpu_load = status['host']['cpuload']
                self._status_pub.publish(msg)
            except requests.exceptions.RequestException as e:
                rospy.logerr("Network error: " + str(e))

    def _publish_interfaces(self, header):
        """Publish an Interfaces message if anybody is connected to ``~interfaces``.

        With ``~compute_rates`` enabled, the message is skipped as long as a
        rate is missing for any interface (i.e. in the first cycle). While
        nobody is connected, the stored counter samples are dropped.

        :param header: Header to put into the message.
        """
        if self._interfaces_pub.get_num_connections() > 0:
            try:
                iflist = self.request('iflist.cgi').json()
                msg = Interfaces()
                msg.header = header
                some_rates_missing = False
                for interface in iflist['interfaces']:
                    stats = interface['stats']

                    interface_msg = Interface()
                    interface_msg.name = interface['ifname'].encode('ascii')

                    interface_msg.rx_bytes = int(stats['rx_bytes'])
                    interface_msg.rx_packets = int(stats['rx_packets'])
                    interface_msg.rx_errors = int(stats['rx_errors'])
                    interface_msg.rx_multicast = int(stats['rx_multicast'])
                    interface_msg.rx_dropped = int(stats['rx_dropped'])

                    interface_msg.tx_bytes = int(stats['tx_bytes'])
                    interface_msg.tx_packets = int(stats['tx_packets'])
                    interface_msg.tx_errors = int(stats['tx_errors'])
                    interface_msg.tx_dropped = int(stats['tx_dropped'])

                    if self._compute_rates:
                        get_rate = partial(self._get_rate, self._last_interface_values, interface_msg.name)

                        interface_msg.rx_rate_bytes = get_rate('rx_bytes', interface_msg.rx_bytes)
                        interface_msg.tx_rate_bytes = get_rate('tx_bytes', interface_msg.tx_bytes)
                        interface_msg.rx_rate_packets = get_rate('rx_packets', interface_msg.rx_packets)
                        interface_msg.tx_rate_packets = get_rate('tx_packets', interface_msg.tx_packets)

                        if interface_msg.rx_rate_bytes is None or interface_msg.tx_rate_bytes is None \
                                or interface_msg.tx_rate_packets is None or interface_msg.rx_rate_packets is None:
                            some_rates_missing = True

                    msg.interfaces.append(interface_msg)

                if not some_rates_missing:
                    msg.interfaces = sorted(msg.interfaces, key=lambda i: i.name)
                    self._interfaces_pub.publish(msg)

            except requests.exceptions.RequestException as e:
                rospy.logerr("Network error: " + str(e))
        else:
            self._last_interface_values.clear()

    def _publish_stations(self, header):
        """Publish a Stations message if anybody is connected to ``~stations``.

        With ``~compute_rates`` enabled, a station without a previous sample
        is left out of the message. While nobody is connected, the stored
        counter samples are dropped.

        :param header: Header to put into the message.
        """
        if self._stations_pub.get_num_connections() > 0:
            try:
                stalist = self.request('sta.cgi').json()
                msg = Stations()
                msg.header = header
                for station in stalist:
                    stats = station['stats']

                    station_msg = Station()
                    station_msg.mac = station['mac'].encode('ascii')
                    station_msg.name = station['name'].encode('ascii')
                    station_msg.last_ip = station['lastip'].encode('ascii')
                    station_msg.signal = station['signal']
                    station_msg.rssi = station['rssi']
                    station_msg.ccq = station['ccq']
                    station_msg.rx_packets = stats['rx_data']
                    station_msg.rx_bytes = stats['rx_bytes']
                    station_msg.rx_rate_packets = stats['rx_pps']
                    station_msg.tx_packets = stats['tx_data']
                    station_msg.tx_bytes = stats['tx_bytes']
                    station_msg.tx_rate_packets = stats['tx_pps']

                    if self._compute_rates:
                        get_rate = partial(self._get_rate, self._last_station_values, station_msg.mac)

                        station_msg.rx_rate_bytes = get_rate('rx_bytes', station_msg.rx_bytes)
                        station_msg.tx_rate_bytes = get_rate('tx_bytes', station_msg.tx_bytes)

                        # in stations, we don't want to stop publishing the whole message just because
                        # a new station popped up, so we just skip the one station
                        if station_msg.rx_rate_bytes is None or station_msg.tx_rate_bytes is None:
                            continue

                    msg.stations.append(station_msg)
                msg.stations = sorted(msg.stations, key=lambda s: s.mac)
                self._stations_pub.publish(msg)
            except requests.exceptions.RequestException as e:
                rospy.logerr("Network error: " + str(e))
        else:
            self._last_station_values.clear()

    def publish(self):
        """Poll the device and publish all messages until ROS shuts down.

        In the first cycle, when ``~frame_id`` was not configured, the frame
        id is taken from the device's hostname; if that fails, the device
        address is used instead.
        """
        while not rospy.is_shutdown():
            status = None
            if self._frame_id is None:
                try:
                    status = self.request('status.cgi').json()
                    self._frame_id = status['host']['hostname'].encode('ascii')
                    rospy.loginfo("Setting frame_id to " + self._frame_id)
                except Exception as e:
                    rospy.logerr("Could not detect hostname of the device, using address")
                    rospy.logdebug("Hostname detection failed: %s" % e)
                    self._frame_id = self._address

            header = Header()
            header.stamp = rospy.Time.now()
            header.frame_id = self._frame_id

            # The status downloaded for hostname detection (if any) is reused.
            self._publish_status(header, status)
            self._publish_interfaces(header)
            self._publish_stations(header)

            try:
                self._rate.sleep()
            except rospy.ROSInterruptException:
                # Shutdown was requested during the sleep; finish quietly.
                break


if __name__ == "__main__":
    rospy.init_node('airos_api')

    # The device is contacted with certificate verification turned off;
    # keep urllib3 from printing a warning about it.
    urllib3.disable_warnings()

    # publish() keeps polling the device until ROS shuts down.
    api = AirosAPI()
    api.publish()