#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
# Copyright (C) 2016 All Right Reserved, Southwest Research Institute® (SwRI®)
#

import fnmatch
import math
import os
import pyproj
import rospy
import yaml
from collections import namedtuple
from geometry_msgs.msg import PoseStamped
from gps_common.msg import GPSFix
from mapviz.srv import AddMapvizDisplay, AddMapvizDisplayRequest
from marti_common_msgs.msg import KeyValue

# Lat/lon bounding box of one .geo file, plus the path it was loaded from.
# "area" holds (max_lat - min_lat) * (max_lon - min_lon), i.e. square degrees.
GeoReference = namedtuple(
    "GeoReference",
    ["path", "min_lat", "min_lon", "max_lat", "max_lon", "area"],
)

# Most recent message handed to gps_callback; None until one arrives.
_gps_fix = None

def distance(lat1, lon1, lat2, lon2):
    """Return the great-circle central angle between two lat/lon points.

    Uses the spherical law of cosines.  All arguments are in degrees; the
    result is the angle in radians (i.e. the distance on a unit sphere), in
    the range [0, pi].
    """
    phi1 = math.radians(lat1)
    lam1 = math.radians(lon1)
    phi2 = math.radians(lat2)
    lam2 = math.radians(lon2)
    cos_angle = (math.sin(phi1) * math.sin(phi2) +
                 math.cos(phi1) * math.cos(phi2) * math.cos(lam1 - lam2))
    # Rounding can push the value a hair outside [-1, 1] (e.g. for identical
    # or antipodal points), which would make math.acos raise ValueError.
    cos_angle = max(-1.0, min(1.0, cos_angle))
    return math.acos(cos_angle)

def gps_callback(data):
    """Remember the latest GPS message in the module-level ``_gps_fix``."""
    global _gps_fix
    _gps_fix = data

def _find_geofiles(base_directory, max_search_depth):
    """Return the paths of all ``*.geo`` files found under base_directory.

    Directories more than max_search_depth levels below base_directory are
    not descended into.
    """
    geofiles = []
    initial_depth = base_directory.count(os.sep)
    for path, dirs, filenames in os.walk(base_directory):
        geofiles.extend(os.path.join(path, filename)
                        for filename in fnmatch.filter(filenames, '*.geo'))
        # Emptying dirs in place stops os.walk from descending any further.
        if path.count(os.sep) - initial_depth >= max_search_depth:
            dirs[:] = []
    return geofiles


def _parse_geofile(geofile):
    """Parse one geo-file and return its GeoReference, or None if rejected.

    A UTM file is rejected (with a warning) when its zone or band is missing
    or out of range, or when it has fewer than two tiepoints.
    """
    rospy.loginfo("Parsing %s...", geofile)
    with open(geofile) as f:
        geodata = yaml.safe_load(f)

    # Sentinel bounds; tightened below from the tiepoints of UTM files.
    min_lat = 90.0
    max_lat = -90.0
    min_lon = 180.0
    max_lon = -180.0

    # NOTE(review): files whose projection is not 'utm' are still accepted and
    # keep the sentinel bounds above (min > max), exactly as before -- confirm
    # whether other projections should be parsed or rejected instead.
    if str(geodata['projection']).lower() == 'utm':
        rospy.loginfo("  projection: utm")

        if 'utm_zone' not in geodata:
            rospy.logwarn("  no utm zone!")
            return None
        zone = int(geodata['utm_zone'])
        if not 1 <= zone <= 60:
            rospy.logwarn("  invalid utm zone!")
            return None
        rospy.loginfo("  utm zone: %d", zone)

        if 'utm_band' not in geodata:
            rospy.logwarn("  no utm band!")
            return None
        # Normalise case so lower-case band letters pick the right hemisphere.
        band = str(geodata['utm_band']).strip().upper()
        if not (len(band) == 1 and 'C' <= band <= 'X'):
            rospy.logwarn("  invalid utm band!")
            return None
        rospy.loginfo("  utm band: %s", band)

        # Bands before 'N' are treated as the southern hemisphere.
        utm_proj = pyproj.Proj(proj='utm', zone=zone, ellps='WGS84', south=(band < 'N'))

        if 'tiepoints' not in geodata:
            rospy.logwarn("  no tiepoints!")
            return None
        tiepoints = geodata['tiepoints']
        if len(tiepoints) <= 1:
            rospy.logwarn("  not enough tiepoints!")
            return None

        for point in tiepoints:
            # Elements 2 and 3 of each tiepoint are read as easting/northing.
            easting = point['point'][2]
            northing = point['point'][3]
            lon, lat = utm_proj(easting, northing, inverse=True)
            min_lon = min(lon, min_lon)
            max_lon = max(lon, max_lon)
            min_lat = min(lat, min_lat)
            max_lat = max(lat, max_lat)

    area = (max_lat - min_lat) * (max_lon - min_lon)
    rospy.loginfo("  lat/lon bounds: (%lf, %lf) - (%lf, %lf)", min_lat, min_lon, max_lat, max_lon)
    return GeoReference(geofile, min_lat, min_lon, max_lat, max_lon, area)


def _select_path(georeferences, lat, lon):
    """Return the path of the geo-file best matching the point (lat, lon).

    Among the geo-files whose bounds strictly contain the point, the one with
    the largest area wins.  If none contains it, the geo-file whose bounding
    box centre is closest is used.  Returns None when georeferences is empty.
    """
    path = None
    max_area = 0
    for georef in georeferences:
        contains = (georef.min_lat < lat < georef.max_lat and
                    georef.min_lon < lon < georef.max_lon)
        if contains and georef.area > max_area:
            path = georef.path
            max_area = georef.area

    if path is None:
        min_dist = float("inf")
        for georef in georeferences:
            center_lat = (georef.min_lat + georef.max_lat) / 2.0
            center_lon = (georef.min_lon + georef.max_lon) / 2.0
            dist = distance(center_lat, center_lon, lat, lon)
            if dist < min_dist:
                min_dist = dist
                path = georef.path
    return path


def _update_display(add_mapviz_display, path, display_name, draw_order):
    """Ask mapviz to show the tileset at path; return True on success."""
    rospy.loginfo("updating tileset: %s", path)
    request = AddMapvizDisplayRequest(type='mapviz_plugins/multires_image', draw_order=draw_order, name=display_name, visible=True)
    request.properties.append(KeyValue(key='path', value=path))
    response = add_mapviz_display(request)
    if not response.success:
        rospy.logwarn("failed to update tileset: %s", response.message)
    return response.success


def load_tiles():
    """Run the tile-loader node until ROS shuts down.

    Scans a directory tree for ``*.geo`` files, computes the lat/lon bounds of
    each, and then periodically asks mapviz (via the ``add_mapviz_display``
    service) to display the geo-file that best matches the current position.

    Private ROS parameters read:
      ~base_directory    root of the search (default: ``~/.ros``)
      ~max_search_depth  how many directory levels to descend (default: 1)
      ~rate              update rate in Hz, clamped to at least 0.1 (default: 1.0)
      ~display_name      name of the mapviz display (default: 'satellite')
      ~draw_order        draw order of the mapviz display (default: 1)
      ~use_local_xy      if true, position comes from one message on
                         ``/local_xy_origin`` instead of the ``gps`` topic
                         (default: False)
    """
    rospy.init_node('mapviz_tile_loader', anonymous=True)

    base_directory = rospy.get_param('~base_directory', os.path.expanduser('~') + "/.ros")
    max_search_depth = rospy.get_param('~max_search_depth', 1)
    rate = max(0.1, rospy.get_param('~rate', 1.0))
    display_name = rospy.get_param('~display_name', 'satellite')
    draw_order = rospy.get_param('~draw_order', 1)
    use_local_xy = rospy.get_param('~use_local_xy', False)

    gps_sub = rospy.Subscriber("gps", GPSFix, gps_callback)

    georeferences = []
    for geofile in _find_geofiles(base_directory, max_search_depth):
        georef = _parse_geofile(geofile)
        if georef is not None:
            georeferences.append(georef)

    rospy.loginfo("waiting for service: %s ...", rospy.resolve_name('add_mapviz_display'))
    rospy.wait_for_service('add_mapviz_display')
    add_mapviz_display = rospy.ServiceProxy('add_mapviz_display', AddMapvizDisplay)

    last_path = None
    while not rospy.is_shutdown():
        # (lat, lon) to match against this cycle, or None to do nothing.
        position = None
        if use_local_xy:
            # The origin is only looked up until a tileset has been loaded
            # successfully once.
            if last_path is None and georeferences:
                local_xy = rospy.wait_for_message("/local_xy_origin", PoseStamped)
                # The origin pose carries latitude in y and longitude in x.
                position = (local_xy.pose.position.y, local_xy.pose.position.x)
        else:
            # Snapshot the global once so latitude and longitude come from the
            # same message even if the callback replaces it meanwhile.
            fix = _gps_fix
            if fix is not None:
                position = (fix.latitude, fix.longitude)
            else:
                rospy.loginfo("waiting for gps message: %s", rospy.resolve_name('gps'))

        if position is not None:
            path = _select_path(georeferences, position[0], position[1])
            # Only call the service when the selected geo-file has changed.
            if path is not None and path != last_path:
                if _update_display(add_mapviz_display, path, display_name, draw_order):
                    last_path = path

        rospy.sleep(1.0 / rate)

# Script entry point: run the node and exit quietly when ROS interrupts it
# (e.g. on shutdown while sleeping).
if __name__ == '__main__':
    try:
        load_tiles()
    except rospy.ROSInterruptException:
        pass
