#pragma once

#include <algorithm>
#include <array>
#include <chrono>
#include <iostream>
#include <limits>
#include <math.h>
#include <numeric>
#include <optional>
#include <tuple>

#include <ruckig/calculator.hpp>
#include <ruckig/input_parameter.hpp>
#include <ruckig/output_parameter.hpp>
#include <ruckig/trajectory.hpp>


namespace ruckig {

//! Main class for the Ruckig algorithm.
template<size_t DOFs = 0, template<class, size_t> class CustomVector = StandardVector, bool throw_error = false>
class Ruckig {
    //! Current input, only for comparison for recalculation
    InputParameter<DOFs, CustomVector> current_input;

    //! Flag that indicates if the current_input was properly initialized
    bool current_input_initialized {false};

    //! Returns v0 + a0^2 / (2 j).
    //! With j = +jMax for a0 > 0 (or j = -jMax for a0 < 0) this is the velocity reached when the
    //! acceleration a0 is ramped down to zero with maximal jerk; callers choose the sign of j.
    inline static double v_at_a_zero(double v0, double a0, double j) {
        return v0 + (a0 * a0)/(2 * j);
    }

public:
    //! Calculator for new trajectories
    Calculator<DOFs, CustomVector> calculator;

    //! Max number of intermediate waypoints
    const size_t max_number_of_waypoints;

    //! Degrees of freedom
    const size_t degrees_of_freedom;

    //! Time step between updates (cycle time) in [s]
    double delta_time {0.0};

    //! Compile-time DoFs, no cycle time (delta_time is set to -1.0, so update() is not meaningful)
    template <size_t D = DOFs, typename std::enable_if<D >= 1, int>::type = 0>
    explicit Ruckig(): max_number_of_waypoints(0), degrees_of_freedom(DOFs), delta_time(-1.0) {
    }

    //! Compile-time DoFs with the given cycle time in [s]
    template <size_t D = DOFs, typename std::enable_if<D >= 1, int>::type = 0>
    explicit Ruckig(double delta_time): max_number_of_waypoints(0), degrees_of_freedom(DOFs), delta_time(delta_time) {
    }

#if defined WITH_ONLINE_CLIENT
    //! Compile-time DoFs with the given cycle time and capacity for intermediate waypoints
    template <size_t D = DOFs, typename std::enable_if<D >= 1, int>::type = 0>
    explicit Ruckig(double delta_time, size_t max_number_of_waypoints): current_input(InputParameter<DOFs, CustomVector>(max_number_of_waypoints)), calculator(Calculator<DOFs, CustomVector>(max_number_of_waypoints)), max_number_of_waypoints(max_number_of_waypoints), degrees_of_freedom(DOFs), delta_time(delta_time) {
    }
#endif

    //! Runtime DoFs (DOFs == 0), no cycle time (delta_time is set to -1.0)
    template <size_t D = DOFs, typename std::enable_if<D == 0, int>::type = 0>
    explicit Ruckig(size_t dofs): current_input(InputParameter<DOFs, CustomVector>(dofs)), calculator(Calculator<DOFs, CustomVector>(dofs)), max_number_of_waypoints(0), degrees_of_freedom(dofs), delta_time(-1.0) {
    }

    //! Runtime DoFs (DOFs == 0) with the given cycle time in [s]
    template <size_t D = DOFs, typename std::enable_if<D == 0, int>::type = 0>
    explicit Ruckig(size_t dofs, double delta_time): current_input(InputParameter<DOFs, CustomVector>(dofs)), calculator(Calculator<DOFs, CustomVector>(dofs)), max_number_of_waypoints(0), degrees_of_freedom(dofs), delta_time(delta_time) {
    }

#if defined WITH_ONLINE_CLIENT
    //! Runtime DoFs (DOFs == 0) with the given cycle time and capacity for intermediate waypoints
    template <size_t D = DOFs, typename std::enable_if<D == 0, int>::type = 0>
    explicit Ruckig(size_t dofs, double delta_time, size_t max_number_of_waypoints): current_input(InputParameter<DOFs, CustomVector>(dofs, max_number_of_waypoints)), calculator(Calculator<DOFs, CustomVector>(dofs, max_number_of_waypoints)), max_number_of_waypoints(max_number_of_waypoints), degrees_of_freedom(dofs), delta_time(delta_time) {
    }
#endif

    //! Reset the instance (e.g. to force a new calculation in the next update)
    void reset() {
        current_input_initialized = false;
    }

    //! Filter intermediate positions based on a threshold distance for each DoF
    //!
    //! Greedily drops waypoints that lie within threshold_distance (per DoF, simultaneously for one
    //! common line parameter t in [0, 1]) of the straight segment between the last kept point and a
    //! later point. The polyline through the kept points therefore still passes within the threshold
    //! of every removed waypoint. Returns the kept intermediate positions in their original order.
    template<class T> using Vector = CustomVector<T, DOFs>;
    std::vector<Vector<double>> filter_intermediate_positions(const InputParameter<DOFs, CustomVector>& input, const Vector<double>& threshold_distance) const {
        if (input.intermediate_positions.empty()) {
            return input.intermediate_positions;
        }

        const size_t n_waypoints = input.intermediate_positions.size();
        std::vector<bool> is_active(n_waypoints, true);

        // Point indices: 0 is the current position, 1..n_waypoints are the intermediate positions
        // and n_waypoints+1 is the target position.
        size_t start = 0;
        size_t end = start + 2;
        for (; end < n_waypoints + 2; ++end) {
            const auto& pos_start = (start == 0) ? input.current_position : input.intermediate_positions[start - 1];
            const auto& pos_end = (end == n_waypoints + 1) ? input.target_position : input.intermediate_positions[end - 1];

            // Check for all intermediate positions
            bool are_all_below {true};
            for (size_t current = start + 1; current < end; ++current) {
                const auto& pos_current = input.intermediate_positions[current - 1];

                // Is there a point t on the line that holds the threshold?
                double t_start_max = 0.0;
                double t_end_min = 1.0;
                for (size_t dof = 0; dof < degrees_of_freedom; ++dof) {
                    const double delta = pos_end[dof] - pos_start[dof];
                    const double deviation = pos_current[dof] - pos_start[dof];

                    // Interval [t_start, t_end] of line parameters for which this DoF is within its threshold
                    const double h0 = deviation / delta;
                    const double t_tolerance = threshold_distance[dof] / std::abs(delta);
                    const double t_start = h0 - t_tolerance;
                    const double t_end = h0 + t_tolerance;

                    if (!std::isfinite(t_start) || !std::isfinite(t_end)) {
                        // The segment is (numerically) constant in this DoF, e.g. the DoF does not move.
                        // The division above then yields inf/NaN, and NaN would silently pass (and poison)
                        // the interval intersection. For a constant line, the threshold holds for every t
                        // exactly if the deviation itself is within the threshold.
                        if (std::abs(deviation) > threshold_distance[dof]) {
                            are_all_below = false;
                            break;
                        }
                        continue;
                    }

                    t_start_max = std::max(t_start, t_start_max);
                    t_end_min = std::min(t_end, t_end_min);

                    if (t_start_max > t_end_min) {
                        are_all_below = false;
                        break;
                    }
                }
                if (!are_all_below) {
                    break;
                }
            }

            // The waypoint right before `end` is only needed if the shortcut start -> end is not valid;
            // in that case it becomes the new start of the next shortcut.
            is_active[end - 2] = !are_all_below;
            if (!are_all_below) {
                start = end - 1;
            }
        }

        std::vector<Vector<double>> filtered_positions;
        filtered_positions.reserve(n_waypoints);
        for (size_t i = 0; i < n_waypoints; ++i) {
            if (is_active[i]) {
                filtered_positions.push_back(input.intermediate_positions[i]);
            }
        }

        return filtered_positions;
    }

    //! Validate the input for trajectory calculation and kinematic limits
    //!
    //! Returns false if any limit is NaN or not strictly positive (max) / negative (min), if any state
    //! value is NaN, or if waypoint / discretization options are combined in an unsupported way.
    //! Optionally also checks that the current and/or target state lie within the limits, including
    //! the velocity that is unavoidably reached while bringing the acceleration to zero with max jerk.
    bool validate_input(const InputParameter<DOFs, CustomVector>& input, bool check_current_state_within_limits=false, bool check_target_state_within_limits=true) const {
        for (size_t dof = 0; dof < degrees_of_freedom; ++dof) {
            const double jMax = input.max_jerk[dof];
            if (std::isnan(jMax) || jMax <= std::numeric_limits<double>::min()) {
                return false;
            }

            const double aMax = input.max_acceleration[dof];
            if (std::isnan(aMax) || aMax <= std::numeric_limits<double>::min()) {
                return false;
            }

            const double aMin = input.min_acceleration ? input.min_acceleration.value()[dof] : -input.max_acceleration[dof];
            if (std::isnan(aMin) || aMin >= -std::numeric_limits<double>::min()) {
                return false;
            }

            const double a0 = input.current_acceleration[dof];
            const double af = input.target_acceleration[dof];
            if (std::isnan(a0) || std::isnan(af)) {
                return false;
            }

            if (check_current_state_within_limits && (a0 > aMax || a0 < aMin)) {
                return false;
            }
            if (check_target_state_within_limits && (af > aMax || af < aMin)) {
                return false;
            }

            const double v0 = input.current_velocity[dof];
            const double vf = input.target_velocity[dof];
            if (std::isnan(v0) || std::isnan(vf)) {
                return false;
            }

            auto control_interface = input.per_dof_control_interface ? input.per_dof_control_interface.value()[dof] : input.control_interface;
            if (control_interface == ControlInterface::Position) {
                const double p0 = input.current_position[dof];
                const double pf = input.target_position[dof];
                if (std::isnan(p0) || std::isnan(pf)) {
                    return false;
                }

                const double vMax = input.max_velocity[dof];
                const double vMin = input.min_velocity ? input.min_velocity.value()[dof] : -input.max_velocity[dof];

                if (std::isnan(vMax) || vMax <= std::numeric_limits<double>::min()) {
                    return false;
                }

                if (std::isnan(vMin) || vMin >= -std::numeric_limits<double>::min()) {
                    return false;
                }

                if (check_current_state_within_limits && (v0 > vMax || v0 < vMin)) {
                    return false;
                }
                if (check_target_state_within_limits && (vf > vMax || vf < vMin)) {
                    return false;
                }

                // Current state: the velocity keeps changing in the direction of a0 until a0 is ramped to zero.
                if (check_current_state_within_limits && ((a0 > 0 && v_at_a_zero(v0, a0, jMax) > vMax) || (a0 < 0 && v_at_a_zero(v0, a0, -jMax) < vMin))) {
                    return false;
                }
                // Target state: backwards in time, the velocity before reaching af was beyond vf (opposite sign convention).
                if (check_target_state_within_limits && ((af < 0 && v_at_a_zero(vf, af, jMax) > vMax) || (af > 0 && v_at_a_zero(vf, af, -jMax) < vMin))) {
                    return false;
                }
            }
        }

        if (!input.intermediate_positions.empty() && input.control_interface == ControlInterface::Position) {
            if (input.intermediate_positions.size() > max_number_of_waypoints) {
                return false;
            }

            // Waypoint trajectories do not support these options
            if (input.minimum_duration || input.duration_discretization != DurationDiscretization::Continuous) {
                return false;
            }

            if (input.per_dof_control_interface || input.per_dof_synchronization) {
                return false;
            }
        }

        // Discrete durations are multiples of the cycle time, which therefore must be set
        if (delta_time <= 0.0 && input.duration_discretization != DurationDiscretization::Continuous) {
            return false;
        }

        return true;
    }

    //! Calculate a new trajectory for the given input
    Result calculate(const InputParameter<DOFs, CustomVector>& input, Trajectory<DOFs, CustomVector>& trajectory) {
        bool was_interrupted {false};
        return calculate(input, trajectory, was_interrupted);
    }

    //! Calculate a new trajectory for the given input and check for interruption
    //!
    //! Returns Result::ErrorInvalidInput if validate_input (target state checked, current state not)
    //! fails; otherwise forwards to the calculator.
    Result calculate(const InputParameter<DOFs, CustomVector>& input, Trajectory<DOFs, CustomVector>& trajectory, bool& was_interrupted) {
        if (!validate_input(input, false, true)) {
            return Result::ErrorInvalidInput;
        }

        return calculator.template calculate<throw_error>(input, trajectory, delta_time, was_interrupted);
    }

    //! Get the next output state (with step delta_time) along the calculated trajectory for the given input
    //!
    //! Recalculates only if the input changed since the last successful calculation (or after reset()).
    //! Returns Result::Finished once the advanced time exceeds the trajectory duration.
    Result update(const InputParameter<DOFs, CustomVector>& input, OutputParameter<DOFs, CustomVector>& output) {
        const auto start = std::chrono::steady_clock::now();

        if constexpr (DOFs == 0 && throw_error) {
            if (degrees_of_freedom != input.degrees_of_freedom || degrees_of_freedom != output.degrees_of_freedom) {
                throw std::runtime_error("[ruckig] mismatch in degrees of freedom (vector size).");
            }
        }

        output.new_calculation = false;

        Result result {Result::Working};
        if (input != current_input || !current_input_initialized) {
            result = calculate(input, output.trajectory, output.was_calculation_interrupted);
            // A trajectory violating positional limits is still usable, so it is not treated as fatal here
            if (result != Result::Working && result != Result::ErrorPositionalLimits) {
                return result;
            }

            current_input = input;
            current_input_initialized = true;
            output.time = 0.0;
            output.new_calculation = true;
        }

        const size_t old_section = output.new_section;
        output.time += delta_time;
        output.trajectory.at_time(output.time, output.new_position, output.new_velocity, output.new_acceleration, output.new_section);
        output.did_section_change = (output.new_section != old_section);

        const auto stop = std::chrono::steady_clock::now();
        // Stored in microseconds
        output.calculation_duration = std::chrono::duration_cast<std::chrono::nanoseconds>(stop - start).count() / 1000.0;

        output.pass_to_input(current_input);

        if (output.time > output.trajectory.get_duration()) {
            return Result::Finished;
        }

        return result;
    }
};


//! Ruckig variant that is instantiated with throw_error = true.
template<size_t NumberOfDOFs, template<class, size_t> typename VectorType = StandardVector>
using RuckigThrow = Ruckig<NumberOfDOFs, VectorType, true>;


} // namespace ruckig
