#!/usr/bin/env python
# Copyright (c) 2016 The UUV Simulator Authors.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import rospy
from sensor_msgs.msg import PointCloud
from visualization_msgs.msg import MarkerArray
from geometry_msgs.msg import Point, TwistStamped
from threading import Lock
from uuv_plume_model import Plume
import numpy as np
from uuv_plume_msgs.srv import *
import os
import yaml


class PlumeSimulatorServer(object):
    """
    The plume server allows a plume entity to be created and configured in the
    ongoing simulation using ROS services. This class also contains timer
    callback functions to update the visual markers and plume point clouds
    that can be visualized in RViz.

    Examples
    --------

    To start this plume server, be sure that Gazebo has been initialized and
    then run::

        $ roslaunch uuv_plume_simulator start_plume_server.launch

    Check the files *scripts/set_demo_spheroid_plume* and
    *scripts/set_demo_turbulent_plume* to see examples on how to use the
    service calls to create plumes.

    .. note::

        Check the chemical particle concentration sensor plugin (*CPCROSPlugin*)
        to generate simulated sensor data for the particle concentration around
        the vehicle.
    """
    def __init__(self):
        """
        Register the plume services, publishers, the current velocity
        subscriber and the periodic update timer.

        The private ROS parameter ``~update_rate`` (in Hz) overrides the
        default update rate of 5 Hz; provided values are clamped to a
        minimum of 0.05 Hz.
        """
        # Plume model (to be defined by a service call)
        self._model = None

        # Current time stamp and time step
        self._t = rospy.get_time()
        self._dt = 0.0

        # Lock guarding the plume particles while they are being loaded from
        # a file or updated by the timer callback. It is created before any
        # callback is registered so that it always exists when one fires.
        self._loading_plume = Lock()

        # Update rate for the current plume state
        self._update_rate = 5.0
        if rospy.has_param('~update_rate'):
            self._update_rate = float(rospy.get_param('~update_rate'))
            self._update_rate = max(0.05, self._update_rate)
        rospy.loginfo('Update rate [Hz]: %.3f', self._update_rate)

        # Definition of service callbacks as
        # (service name, service type, handler), registered in this order
        service_table = [
            # Create a static plume in the form of a spheroid
            ('create_spheroid_plume',
             CreateSpheroidPlume,
             self.create_spheroid_plume),
            # Create a dynamic passive scalar turbulent plume
            ('create_passive_scalar_turbulent_plume',
             CreatePassiveScalarTurbulentPlume,
             self.create_passive_scalar_turbulent_plume),
            # Configure the bounding box limiting the plume source and
            # particles
            ('set_plume_limits',
             SetPlumeLimits,
             self.set_plume_limits),
            # Change the maximum number of particles to be generated by a
            # single plume and, in the case of the dynamic plume, the maximum
            # number of particles generated at each iteration
            ('set_plume_config',
             SetPlumeConfiguration,
             self.set_plume_configuration),
            # Return general plume configuration parameters, such as maximum
            # number of particles, model name, etc.
            ('get_plume_config',
             GetPlumeConfiguration,
             self.get_plume_configuration),
            # Delete the plume model, after this is called the output plume
            # point cloud will be empty
            ('delete_plume',
             DeletePlume,
             self.delete_plume),
            # Set the position of the plume source
            ('set_plume_source_position',
             SetPlumeSourcePosition,
             self.set_plume_source_position),
            # Get the position of the plume source
            ('get_plume_source_position',
             GetPlumeSourcePosition,
             self.get_plume_source_position),
            # Get the number of particles reported by the plume model
            ('get_num_particles',
             GetNumParticles,
             self.get_num_particles),
            # Store the current plume particles in a YAML file
            ('store_plume_state',
             StorePlumeState,
             self.store_plume_state),
            # Load plume particles from a YAML file
            ('load_plume_particles',
             LoadPlumeParticles,
             self.load_plume_particles),
        ]

        self._services = dict()
        for srv_name, srv_type, srv_handler in service_table:
            self._services[srv_name] = rospy.Service(
                srv_name, srv_type, srv_handler)

        # Publisher for the plume visual markers
        self._plume_marker_publisher = rospy.Publisher(
            'markers',
            MarkerArray,
            queue_size=1)

        # Publisher for the plume particles point cloud
        self._plume_point_cloud_publisher = rospy.Publisher(
            'particles',
            PointCloud,
            queue_size=1)

        # Subscriber to the current velocity
        self._current_vel_subscriber = rospy.Subscriber(
            'current_vel',
            TwistStamped,
            self.current_vel_callback)

        # Timer called to update both the particle point cloud and visual
        # markers
        self._update_plume_timer = rospy.Timer(
            rospy.Duration(1 / self._update_rate),
            self.update_plume)

    def current_vel_callback(self, msg):
        """
        Subscriber callback function for the current velocity vector update.

        The linear part of the received twist is forwarded to the plume
        model. The message is ignored while no plume model exists.
        """
        if self._model is None:
            return

        linear = msg.twist.linear
        self._model.update_current_vel([linear.x, linear.y, linear.z])

    def delete_plume(self, request):
        """
        Service function callback to delete the plume model. All markers and
        point cloud will be published with empty topics.

        Always returns a successful response, even if no model existed.
        """
        # Rebinding to None is enough to drop the model; a preceding `del`
        # would leave the attribute missing for a moment.
        self._model = None
        rospy.loginfo('Plume deleted')
        return DeletePlumeResponse(True)

    def set_plume_source_position(self, request):
        """
        Service function callback that sets a new position for the plume
        source wrt the ENU frame.

        Returns a failure response if no plume model has been created.
        """
        if self._model is None:
            rospy.logwarn('No plume model has been created')
            return SetPlumeSourcePositionResponse(False)

        source = request.source
        rospy.loginfo('Plume source position set to:')
        rospy.loginfo(
            '\t(X, Y, Z) [m] wrt the ENU frame: (%.2f, %.2f, %.2f)',
            source.x, source.y, source.z)

        self._model.source_pos = [source.x, source.y, source.z]
        return SetPlumeSourcePositionResponse(True)

    def get_num_particles(self, request):
        """
        Service function callback that returns the number of particles
        reported by the plume model, or 0 if no model has been created.
        """
        if self._model is None:
            rospy.logwarn('No plume model has been created')
            return GetNumParticlesResponse(0)

        n_particles = self._model.num_particles
        rospy.loginfo('Number of plume particles requested: %d', n_particles)
        return GetNumParticlesResponse(int(n_particles))

    def get_plume_source_position(self, request):
        """
        Service function callback that returns the position wrt to the ENU
        frame for the plume source.

        If no plume model has been created, the point (-1, -1, -1) is
        returned instead.
        """
        if self._model is None:
            rospy.logwarn('No plume model has been created')
            return GetPlumeSourcePositionResponse(Point(-1, -1, -1))

        pos = self._model.source_pos
        rospy.loginfo('Current plume source position:')
        rospy.loginfo(
            '\t(X, Y, Z) [m] wrt the ENU frame: (%.2f, %.2f, %.2f)',
            pos[0], pos[1], pos[2])

        return GetPlumeSourcePositionResponse(Point(pos[0], pos[1], pos[2]))

    def get_plume_configuration(self, request):
        """
        Service function callback to return the configuration parameters for
        the current plume model being used.

        The response carries, in this order: the model label, the number of
        points, the maximum number of particles per iteration (0 for the
        spheroid model), the source position and the lower/upper limits
        along X, Y and Z. An empty label and zeros are returned if no plume
        model has been created.
        """
        if self._model is None:
            rospy.logwarn('No plume model has been created')
            return GetPlumeConfigurationResponse(
                '',
                0,
                0,
                Point(0, 0, 0),
                0, 0, 0, 0, 0, 0)

        model = self._model
        # The spheroid model is reported with 0 particles per iteration
        if model.LABEL == 'spheroid':
            max_particles_per_iter = 0
        else:
            max_particles_per_iter = model.max_particles_per_iter

        return GetPlumeConfigurationResponse(
            model.LABEL,
            model.n_points,
            max_particles_per_iter,
            Point(model.source_pos[0],
                  model.source_pos[1],
                  model.source_pos[2]),
            model.x_lim[0],
            model.x_lim[1],
            model.y_lim[0],
            model.y_lim[1],
            model.z_lim[0],
            model.z_lim[1])

    def set_plume_configuration(self, request):
        """
        Service function callback to set general plume configuration
        parameters.

        The number of points is always updated. The maximum number of
        particles per iteration is only updated if the requested value is
        positive and the current model is the passive scalar turbulent plume.
        Returns a failure response if no plume model exists, if the model
        rejects a value or if an exception is raised.
        """
        if self._model is None:
            rospy.logwarn('No plume model has been created')
            return SetPlumeConfigurationResponse(False)
        try:
            if not self._model.set_n_points(request.n_points):
                rospy.logerr(
                    'Error occurred while setting the number of particles, '
                    'value=%d', int(request.n_points))
                # A service handler must answer with the response type, not
                # with the request type
                return SetPlumeConfigurationResponse(False)

            rospy.loginfo('Change in the plume configuration:')
            rospy.loginfo('\t- # particles=%d', request.n_points)

            if request.max_particles_per_iter > 0 and \
                    self._model.LABEL == 'passive_scalar_turbulence':
                if not self._model.set_max_particles_per_iter(
                        request.max_particles_per_iter):
                    rospy.logerr('Error ocurred while setting the maximum number of particles per iteration, value=%d', int(request.max_particles_per_iter))
                    return SetPlumeConfigurationResponse(False)
                rospy.loginfo(
                    '\t- Max. number of particles per iteration=%d',
                    request.max_particles_per_iter)
            return SetPlumeConfigurationResponse(True)
        except Exception as e:
            rospy.logerr(
                'Error setting the plume configuration, message=%s', str(e))
            return SetPlumeConfigurationResponse(False)

    def set_plume_limits(self, request):
        """
        Service function callback to set the plume bounding box limits.

        Returns a failure response if no plume model exists or if the model
        raises an exception while the limits are being set.
        """
        if self._model is None:
            rospy.logwarn('No plume model has been created')
            return SetPlumeLimitsResponse(False)

        try:
            self._model.set_x_lim(request.x_min, request.x_max)
            self._model.set_y_lim(request.y_min, request.y_max)
            self._model.set_z_lim(request.z_min, request.z_max)
            return SetPlumeLimitsResponse(True)
        except Exception as e:
            rospy.logerr('Error setting the plume limits, message=%s', str(e))
            return SetPlumeLimitsResponse(False)

    def create_spheroid_plume(self, request):
        """
        Service function callback to create a static spheroid plume model.

        The request is rejected if ``a``, ``c`` or ``n_points`` is not
        positive. If the model creation raises an exception, the current
        plume model is reset to ``None`` and a failure response is returned.
        """
        if request.a <= 0 or request.c <= 0:
            return CreateSpheroidPlumeResponse(False)
        if request.n_points <= 0:
            return CreateSpheroidPlumeResponse(False)

        try:
            # Orientation quaternion in (x, y, z, w) order
            q = np.array([request.orientation.x,
                          request.orientation.y,
                          request.orientation.z,
                          request.orientation.w])
            self._model = Plume.create_plume_model(
                'spheroid',
                request.a,
                request.c,
                q,
                [request.source.x, request.source.y, request.source.z],
                request.n_points,
                rospy.get_time())

            self._model.set_x_lim(request.x_min, request.x_max)
            self._model.set_y_lim(request.y_min, request.y_max)
            self._model.set_z_lim(request.z_min, request.z_max)
            rospy.loginfo('Spheroid plume created!')
            rospy.loginfo(request)
            return CreateSpheroidPlumeResponse(True)
        except Exception as e:
            rospy.logerr(
                'Error creating spheroid plume model, message=%s', str(e))
            self._model = None
            return CreateSpheroidPlumeResponse(False)

    def create_passive_scalar_turbulent_plume(self, request):
        """
        Service function callback to create a passive scalar turbulent plume
        model.

        If the model creation raises an exception, the current plume model is
        reset to ``None`` and a failure response is returned.
        """
        try:
            coeffs = request.turbulent_diffusion_coefficients
            source = request.source
            self._model = Plume.create_plume_model(
                'passive_scalar_turbulence',
                [coeffs.x, coeffs.y, coeffs.z],
                request.buoyancy_flux,
                request.stability_param,
                [source.x, source.y, source.z],
                request.n_points,
                rospy.get_time(),
                request.max_particles_per_iter,
                request.max_life_time)

            self._model.set_x_lim(request.x_min, request.x_max)
            self._model.set_y_lim(request.y_min, request.y_max)
            self._model.set_z_lim(request.z_min, request.z_max)
            rospy.loginfo('Passive turbulent plume created!')
            rospy.loginfo(
                'Turbulent diffusion coefficients: (%.3f, %.3f, %.3f)',
                coeffs.x, coeffs.y, coeffs.z)
            rospy.loginfo('Buoyancy flux: %.4f', request.buoyancy_flux)
            rospy.loginfo('Stability parameters: %.4f',
                          request.stability_param)
            rospy.loginfo(
                'Source position wrt ENU frame: (%.2f %.2f, %.2f)',
                source.x, source.y, source.z)
            rospy.loginfo('Max. number of particles: %d', request.n_points)
            rospy.loginfo('Max. number of particles per iteration: %d',
                          request.max_particles_per_iter)
            rospy.loginfo('Particle max. life time: %.2f',
                          request.max_life_time)

            return CreatePassiveScalarTurbulentPlumeResponse(True)
        except Exception as e:
            rospy.logerr(
                'Error creating passive turbulent plume model, message=%s',
                str(e))
            self._model = None
            return CreatePassiveScalarTurbulentPlumeResponse(False)

    @staticmethod
    def _resolve_output_file(output_dir, filename):
        """
        Return the path of the file used to store the plume state.

        ``/tmp`` replaces ``output_dir`` if it is not an existing directory
        and ``plume.yaml`` replaces ``filename`` if it contains neither
        ``.yaml`` nor ``.yml``.
        """
        if not os.path.isdir(output_dir):
            output_dir = '/tmp'
        if '.yaml' not in filename and '.yml' not in filename:
            filename = 'plume.yaml'
        return os.path.join(output_dir, filename)

    def store_plume_state(self, request):
        """
        Service function callback to store the particle positions and
        creation times of the current plume model in a YAML file.

        The response holds the success flag and the name of the file that
        was written. A failure response with an empty file name is returned
        if no plume model exists or if the state cannot be written.
        """
        rospy.loginfo('Store plume state service called')
        if self._model is None:
            rospy.logerr('No plume model available!')
            return StorePlumeStateResponse(False, '')

        abs_filename = self._resolve_output_file(
            request.output_dir, request.filename)

        try:
            data = dict(
                x=self._model.x.flatten().tolist(),
                y=self._model.y.flatten().tolist(),
                z=self._model.z.flatten().tolist(),
                time_creation=self._model.time_of_creation.flatten().tolist())

            with open(abs_filename, 'w') as yaml_file:
                yaml_file.write(yaml.safe_dump(data))
        except Exception as e:
            # Report the failure through the service response like the other
            # handlers do, instead of letting the exception escape
            rospy.logerr(
                'Error while storing plume state, message=%s', str(e))
            return StorePlumeStateResponse(False, '')

        rospy.loginfo('Plume particles stored in=%s', abs_filename)
        return StorePlumeStateResponse(True, abs_filename)

    def load_plume_particles(self, request):
        """
        Service function callback to load plume particles from a YAML file
        holding the keys ``x``, ``y``, ``z`` and ``time_creation``.

        Returns a failure response if no plume model exists, if the file does
        not exist or if an exception is raised while loading it.
        """
        rospy.loginfo('Load plume particles service called')
        if self._model is None:
            rospy.logerr('No plume model available!')
            return LoadPlumeParticlesResponse(False)

        if not os.path.isfile(request.plume_file):
            rospy.logerr('Plume file provided is invalid, filename=%s',
                         request.plume_file)
            return LoadPlumeParticlesResponse(False)

        try:
            # The lock is held as a context manager so that it is released
            # even when loading fails; otherwise the update timer would block
            # forever on its next acquire.
            with self._loading_plume:
                with open(request.plume_file, 'r') as plume_file:
                    # safe_load replaces yaml.load: the file name comes from
                    # a service request, and store_plume_state writes these
                    # files with safe_dump, so plain YAML types suffice.
                    plume_particles = yaml.safe_load(plume_file)
                rospy.loginfo('# particles loaded=%d',
                              len(plume_particles['x']))
                self._model.set_plume_particles(
                    rospy.get_time(),
                    plume_particles['x'],
                    plume_particles['y'],
                    plume_particles['z'],
                    plume_particles['time_creation'])
        except Exception as e:
            rospy.logerr('Error while loading plume file, message=%s', str(e))
            return LoadPlumeParticlesResponse(False)

        return LoadPlumeParticlesResponse(True)

    def update_plume(self, event):
        """
        Callback function for the plume timer which contains the update for
        the plume particle point cloud and visual markers for RViz.

        An empty marker array and an empty point cloud are published while no
        plume model exists. Always returns ``True``.
        """
        # NOTE(review): self._t is only assigned in __init__, so self._dt is
        # the time elapsed since construction and not since the previous
        # update - confirm whether self._t should be refreshed here.
        self._dt = rospy.get_time() - self._t
        # Checked before taking the lock so that an early return can never
        # leave the lock held
        if self._dt <= 0.0:
            return True

        with self._loading_plume:
            # Work on a local reference: delete_plume may reset self._model
            # while this callback is still running
            model = self._model

            if model is None:
                self._plume_marker_publisher.publish(MarkerArray())
                pc_msg = PointCloud()
                pc_msg.header.stamp = rospy.Time.now()
                pc_msg.header.frame_id = 'world'
                self._plume_point_cloud_publisher.publish(pc_msg)
                return True

            if not model.update(rospy.get_time()):
                rospy.logerr(
                    'Error while updating the plume particles positions')
                return True

            marker = model.get_markers()
            if marker is not None:
                self._plume_marker_publisher.publish(marker)

            pc_msg = model.get_point_cloud_as_msg()
            if pc_msg is not None:
                self._plume_point_cloud_publisher.publish(pc_msg)
        return True


# Script entry point: start the node, create the plume server and process
# callbacks until the node is shut down.
if __name__ == '__main__':
    # print() with a single string argument gives the same output with the
    # Python 2 print statement and the Python 3 print function
    print('Plume simulator server')
    rospy.init_node('plume_simulator_server')

    try:
        # The reference is kept so that the server object stays alive while
        # the node is spinning
        server = PlumeSimulatorServer()
        rospy.spin()
    except rospy.ROSInterruptException:
        print('caught exception')
