#!/bin/bash
if "true" : '''\'
then
# Polyglot launcher: bash executes this branch and re-runs the very same file with
# Python, while Python reads everything up to the closing triple quote as a string
# literal and skips it. ROS_PYTHON_VERSION (may be unset) selects the interpreter suffix.
# "$@" forwards every argument as its own word, so a log file path containing spaces
# or glob characters reaches the Python part unchanged.
python${ROS_PYTHON_VERSION:-} "${BASH_SOURCE[0]}" "$@"
exit
fi
'''
# flake8: noqa
import sys
import yaml
import argparse
from collections import defaultdict


if __name__ == '__main__':

    # Command line interface: the log file, one grouping key and one or more values to report.
    # RawTextHelpFormatter keeps the explicit line breaks of the help texts below; the default
    # formatter re-wraps help strings and would merge the per-choice lines into one paragraph.
    parser = argparse.ArgumentParser(description="Evaluate a logged behavior execution.",
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('logfile', type=str, help="The log file to be evaluated.")
    parser.add_argument('key', type=str,
                        help="Defines what should be used as key for grouping logs. Choices are:\n"
                             "[total] Determine a single value for the complete behavior execution.\n"
                             "[state_path] One value for each individual state.\n"
                             "[state_name] Group states with the same name.\n"
                             "[state_class] One value for each class, i.e., group same state types.\n"
                             "[transition] Each outcome is handled separately for each state.\n")
    # At least one value is required (nargs='+'); the first one determines the sort order of the result.
    parser.add_argument('values', type=str, nargs='+',
                        help="List of calculated evaluation values, result is sorted by first value specified. "
                             "Choices are:\n"
                             "[time] Execution time, i.e., time of being the active state.\n"
                             "[time_wait] Part of execution time of waiting for a requested outcome.\n"
                             "[time_run] Part of execution time of actually executing.\n"
                             "[visits] Number of state visits.\n"
                             "[interact] Number of interactions, e.g., requesting and forcing outcomes.\n"
                             "[count_wait] Counts how often the state needed to wait for the operator.\n"
                             "[correct] Number of requested outcomes that got approved by the operator.\n"
                             "[wrong] Number of forced outcomes that are different from requested ones.\n")
    args = parser.parse_args()

    # load logfile
    #
    # The file is parsed with yaml.safe_load. Any failure to open or parse it, an empty
    # document, or a document of the wrong shape ends the script with exit status 1.

    try:
        with open(args.logfile) as f:
            logs = yaml.safe_load(f)
    except Exception as e:
        print('Failed to load logfile "%s":\n%s' % (str(args.logfile), str(e)))
        sys.exit(1)

    if logs is None:
        print('Given logfile is empty, please evaluate a different one.')
        sys.exit(1)

    # The evaluation iterates over the entries and indexes each one by field name, so
    # anything other than a list of mappings would otherwise fail later with a traceback.
    if not isinstance(logs, list) or not all(isinstance(entry, dict) for entry in logs):
        print('Given logfile does not contain a list of log entries, please evaluate a different one.')
        sys.exit(1)

    behavior = "unknown"
    execution_time = None  # time of the 'flexbe.initialize' entry; recorded but not evaluated below
    state_history = list()  # finished state executions, appended when their 'exit' event is seen
    active_states = dict()  # state path -> record of a state that was entered but has not exited yet
    orphaned_paths = list()  # paths that received events without a preceding 'enter'

    # process logs to construct state execution instances
    #
    # Each record is a defaultdict(int), so counters that were never touched read as 0.
    # A record carries state_name / state_path / state_class, visits, the optional outcome,
    # the operator counters (interact, correct, wrong, count_wait) and, once the state
    # has exited, time, time_wait and time_run.

    for log in logs:
        logger = log['logger']

        if logger == 'flexbe.initialize':
            behavior = log['behavior']
            execution_time = log['time']
            continue

        # classify the entry; everything else is irrelevant for the evaluation
        if logger == 'flexbe.events' and log['event'] in ('enter', 'exit'):
            kind = log['event']
        elif logger == 'flexbe.outcomes':
            kind = 'outcome'
        elif logger == 'flexbe.operator' and log['type'] in ('forced', 'request'):
            kind = log['type']
        else:
            continue

        path = log['path']

        if kind == 'enter':
            active_states[path] = defaultdict(int, {
                'state_name': log['name'],
                'state_path': path,
                'state_class': log['state'],
                'begin': log['time'],
                'visits': 1,
            })
            continue

        # All remaining kinds refer to a state that must currently be active. If the log
        # misses the 'enter' event, skip the entry instead of failing with a KeyError.
        state = active_states.get(path)
        if state is None:
            if path not in orphaned_paths:
                orphaned_paths.append(path)
            continue

        if kind == 'exit':
            state['time'] = log['time'] - state.pop('begin')
            if 'begin_wait' in state:
                # waiting is measured from the most recent operator request until exit
                state['time_wait'] = log['time'] - state.pop('begin_wait')
            else:
                state['time_wait'] = 0
            state['time_run'] = state['time'] - state['time_wait']
            state_history.append(state)
            del active_states[path]

        elif kind == 'outcome':
            state['outcome'] = log['outcome']

        elif kind == 'forced':
            state['interact'] += 1
            if log['forced'] == log['requested']:
                state['correct'] += 1
            else:
                state['wrong'] += 1

        else:  # kind == 'request'
            state['count_wait'] += 1
            # a later request overwrites the start of the waiting period
            state['begin_wait'] = log['time']

    if len(active_states) > 0:
        print('Log file appears incomplete, ignoring execution of states:\n%s' % ', '.join(active_states.keys()))

    if len(orphaned_paths) > 0:
        print('Log file appears incomplete, ignoring events of states that were never entered:\n%s'
              % ', '.join(str(path) for path in orphaned_paths))

    # combine state execution instances according to selected key and values

    # Every supported grouping maps a state execution record to the label under which
    # its values are accumulated; an unsupported key ends the script with exit status 1.
    key_definitions = {
        'total': lambda record: "TOTAL",
        'state_path': lambda record: record['state_path'],
        'state_name': lambda record: record['state_name'],
        'state_class': lambda record: record['state_class'],
        'transition': lambda record: "%s > %s" % (record['state_path'], record['outcome'])
    }
    if args.key not in key_definitions:
        print('Invalid key selected: %s\nAvailable: %s' % (str(args.key), ', '.join(key_definitions.keys())))
        sys.exit(1)
    key_function = key_definitions[args.key]

    values = args.values

    # The state records are defaultdicts, so looking up an unknown value name never
    # raises; it would silently produce a column of zeros. Validate the names up front.
    available_values = ('time', 'time_wait', 'time_run', 'visits',
                        'interact', 'count_wait', 'correct', 'wrong')
    invalid_values = [value for value in values if value not in available_values]
    if invalid_values:
        print('Invalid value selected: %s\nAvailable: %s' % (', '.join(invalid_values), ', '.join(available_values)))
        sys.exit(1)

    # Sum every requested value over all state executions that share the same key.
    accumulated = dict()
    for state in state_history:
        key = key_function(state)
        if key not in accumulated:
            accumulated[key] = defaultdict(int)
        for value in values:
            # .get avoids inserting missing counters into the state record
            accumulated[key][value] += state.get(value, 0)

    # display evaluation result
    #
    # Rows are ordered by the first requested value, largest first. The key column is
    # 25 characters wide; longer keys keep their last 23 characters behind a '..' marker.
    # Float values are printed with three decimals, everything else as is.

    primary_value = values[0]
    sorted_keys = list(accumulated.keys())
    sorted_keys.sort(key=lambda name: accumulated[name][primary_value], reverse=True)

    print('Evaluation of behavior "%s"' % behavior)

    header_format = "{:^25} " + "{:^10} " * len(values)
    print(header_format.format('key', *values))

    for key in sorted_keys:
        label = key if len(key) <= 25 else '..'+key[-23:]
        cells = ["{:<25.25} ".format(label)]
        for name in values:
            amount = accumulated[key][name]
            cell_format = "{:>10.3f} " if isinstance(amount, float) else "{:>10} "
            cells.append(cell_format.format(amount))
        print(''.join(cells))
