#!/bin/bash
# Polyglot launcher: bash executes the if-block below and re-runs this very
# file with Python. Python instead parses the block as an if-statement whose
# body is a throwaway string literal, so it never acts on the shell lines.
if "true" : '''\'
then
# python${ROS_PYTHON_VERSION:-} expands to e.g. python3 when the variable is
# set and to plain python otherwise. The arguments are forwarded as "$@" so
# that each one stays a single word even if it contains spaces or globs.
python${ROS_PYTHON_VERSION:-} "${BASH_SOURCE[0]}" "$@"
# A bare exit returns the status of the Python run to the caller.
exit
fi
'''
# flake8: noqa
import sys
import yaml
import argparse
from collections import defaultdict

def _replay_logs(logs):
    """Reconstruct per-state execution records from a sequence of log entries.

    An 'enter' event opens a record for its state path, operator and outcome
    entries update the currently open record of their path, and the matching
    'exit' event closes the record and derives its timing values.

    Returns a tuple (behavior, state_history, active_states, orphaned):
      behavior       value of the last 'flexbe.initialize' entry, else "unknown"
      state_history  closed records, in the order of their 'exit' events
      active_states  records that were entered but never exited, keyed by path
      orphaned       paths referenced by entries while no record was open
    """
    behavior = "unknown"
    state_history = list()
    active_states = dict()
    orphaned = list()

    def open_record(path):
        # Return the open record for path, or None (and remember the path)
        # if the state was never entered, e.g. because the log is truncated.
        state = active_states.get(path)
        if state is None and path not in orphaned:
            orphaned.append(path)
        return state

    for log in logs:
        logger = log['logger']

        if logger == 'flexbe.initialize':
            behavior = log['behavior']

        elif logger == 'flexbe.events' and log['event'] == 'enter':
            # defaultdict(int) lets all counters start at zero implicitly
            state = defaultdict(int)
            state['state_name'] = log['name']
            state['state_path'] = log['path']
            state['state_class'] = log['state']
            state['begin'] = log['time']
            state['visits'] = 1
            active_states[log['path']] = state

        elif logger == 'flexbe.events' and log['event'] == 'exit':
            state = open_record(log['path'])
            if state is None:
                continue
            del active_states[log['path']]
            state['time'] = log['time'] - state.pop('begin')
            if 'begin_wait' in state:
                # waiting is measured from the most recent operator request
                state['time_wait'] = log['time'] - state.pop('begin_wait')
            else:
                state['time_wait'] = 0
            state['time_run'] = state['time'] - state['time_wait']
            state_history.append(state)

        elif logger == 'flexbe.outcomes':
            state = open_record(log['path'])
            if state is not None:
                state['outcome'] = log['outcome']

        elif logger == 'flexbe.operator' and log['type'] == 'forced':
            state = open_record(log['path'])
            if state is not None:
                state['interact'] += 1
                if log['forced'] == log['requested']:
                    state['correct'] += 1
                else:
                    state['wrong'] += 1

        elif logger == 'flexbe.operator' and log['type'] == 'request':
            state = open_record(log['path'])
            if state is not None:
                state['count_wait'] += 1
                state['begin_wait'] = log['time']

    return behavior, state_history, active_states, orphaned


def main():
    """Command line entry point: evaluate a logged behavior execution.

    Reads the positional arguments 'logfile', 'key' and 'values' from the
    command line, loads the YAML log file, groups the reconstructed state
    executions by the selected key, sums the selected values per group and
    prints the result as a table sorted (descending) by the first value.

    Exits with status 1 if the log file cannot be loaded or is empty, or if
    the key or one of the values is not a known choice.
    """
    # The help texts contain explicit line breaks for the lists of choices;
    # RawTextHelpFormatter keeps them instead of re-wrapping into one paragraph.
    parser = argparse.ArgumentParser(description="Evaluate a logged behavior execution.",
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('logfile', type=str, help="The log file to be evaluated.")
    parser.add_argument('key', type=str,
                        help="Defines what should be used as key for grouping logs. Choices are:\n"
                             "[total] Determine a single value for the complete behavior execution.\n"
                             "[state_path] One value for each individual state.\n"
                             "[state_name] Group states with the same name.\n"
                             "[state_class] One value for each class, i.e., group same state types.\n"
                             "[transition] Each outcome is handled separately for each state.\n")
    parser.add_argument('values', type=str, nargs='+',
                        help="List of calculated evaluation values, result is sorted by first value specified. "
                             "Choices are:\n"
                             "[time] Execution time, i.e., time of being the active state.\n"
                             "[time_wait] Part of execution time of waiting for a requested outcome.\n"
                             "[time_run] Part of execution time of actually executing.\n"
                             "[visits] Number of state visits.\n"
                             "[interact] Number of interactions, e.g., requesting and forcing outcomes.\n"
                             "[count_wait] Counts how often the state needed to wait for the operator.\n"
                             "[correct] Number of requested outcomes that got approved by the operator.\n"
                             "[wrong] Number of forced outcomes that are different from requested ones.\n")
    args = parser.parse_args()

    # load logfile

    try:
        with open(args.logfile) as f:
            logs = yaml.safe_load(f)
    except Exception as e:
        print('Failed to load logfile "%s":\n%s' % (str(args.logfile), str(e)))
        sys.exit(1)

    if logs is None:
        print('Given logfile is empty, please evaluate a different one.')
        sys.exit(1)

    # validate the selected key and values before doing any work

    key_definitions = {
        'total': lambda s: "TOTAL",
        'state_path': lambda s: s['state_path'],
        'state_name': lambda s: s['state_name'],
        'state_class': lambda s: s['state_class'],
        'transition': lambda s: "%s > %s" % (s['state_path'], s['outcome'])
    }
    if args.key not in key_definitions:
        print('Invalid key selected: %s\nAvailable: %s' % (str(args.key), ', '.join(key_definitions.keys())))
        sys.exit(1)
    key_function = key_definitions[args.key]

    # State records are defaultdicts, so looking up an unknown value would
    # silently produce zeros; the names therefore have to be checked explicitly.
    value_names = ('time', 'time_wait', 'time_run', 'visits',
                   'interact', 'count_wait', 'correct', 'wrong')
    values = args.values
    invalid_values = [v for v in values if v not in value_names]
    if invalid_values:
        print('Invalid value selected: %s\nAvailable: %s' % (', '.join(invalid_values), ', '.join(value_names)))
        sys.exit(1)

    # process logs to construct state execution instances

    behavior, state_history, active_states, orphaned = _replay_logs(logs)

    if len(orphaned) > 0:
        print('Log file appears incomplete, ignoring entries of states that were never entered:\n%s'
              % ', '.join(str(path) for path in orphaned))
    if len(active_states) > 0:
        print('Log file appears incomplete, ignoring execution of states:\n%s' % ', '.join(active_states.keys()))

    # combine state execution instances according to selected key and values

    accumulated = dict()
    for state in state_history:
        totals = accumulated.setdefault(key_function(state), defaultdict(int))
        # iterate over distinct names so a value listed twice is not summed twice
        for value in set(values):
            totals[value] += state[value]

    # display evaluation result

    sorted_keys = sorted(accumulated.keys(), key=lambda k: accumulated[k][values[0]], reverse=True)

    print('Evaluation of behavior "%s"' % behavior)
    print(("{:^25} " + "{:^10} " * len(values)).format('key', *values))
    for key in sorted_keys:
        # keys longer than the column keep their tail, which is the most specific part
        label = str(key)
        line = "{:<25.25} ".format(label if len(label) <= 25 else '..' + label[-23:])
        for value in values:
            amount = accumulated[key][value]
            if isinstance(amount, float):
                line += "{:>10.3f} ".format(amount)
            else:
                line += "{:>10} ".format(amount)
        print(line)


# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
