#!/usr/bin/env python

import roslib; roslib.load_manifest('roslaunch')

import os
import sys
import curses

from roslib.msg import Log
import roslib.network
import rospy
import roslaunch.netapi as netapi
from roslaunch.netapi import NetProcess

## Name this tool passes to rospy.init_node() when it registers as a ROS node
NAME = 'roslaunch-console'

## Render one process record as a single padded screen line
## @param p: process record (NetProcess); reads p.is_alive, p.name and
##   p.respawn_count
## @return str: label run through pad(). Dead processes are prefixed with
##   "dead:"; live ones with respawn_count > 1 get a "(respawn N)" suffix
##   where N is respawn_count-1
def datastr(p):
    if not p.is_alive:
        label = "dead:%s"%p.name
    elif p.respawn_count > 1:
        label = "%s (respawn %s)"%(p.name, p.respawn_count-1)
    else:
        label = "%s"%p.name
    return pad(label)

## Canned process table; ROSLaunchConsole assigns it to its data when it is
## constructed with fake=True.
## Every row is handed positionally to NetProcess. Judging by the attributes
## datastr() reads, the fields are presumably
## (name, respawn_count, is_alive, roslaunch_uri) -- TODO confirm against
## roslaunch.netapi.NetProcess.
fake_data = list(NetProcess(*row) for row in (
    # launch at http://foo:1234
    ('node1', 1, True, 'http://foo:1234'),
    ('node2', 1, True, 'http://foo:1234'),
    ('node3', 1, True, 'http://foo:1234'),
    ('node4', 1, True, 'http://foo:1234'),
    ('node5', 1, False, 'http://foo:1234'),
    ('node6', 1, True, 'http://foo:1234'),
    ('deadnode7', 11234, False, 'http://foo:1234'),
    # deliberately wider than a typical terminal to exercise truncation
    ('/a/much/longer/name/I/must/persist/to/write/to/screen', 1, True, 'http://foo:1234'),

    # launch at http://m2:1234
    ('m2node1', 1, True, 'http://m2:1234'),
    ('m2node2', 1, True, 'http://m2:1234'),
    ('m2node3', 1, True, 'http://m2:1234'),
    ('m2node4', 4, True, 'http://m2:1234'),
    ('m2node5', 1, True, 'http://m2:1234'),
    ('m2node6', 1, True, 'http://m2:1234'),

    # launch at http://m3:1234
    ('m3node1', 1, True, 'http://m3:1234'),
    ('m3node2', 1, True, 'http://m3:1234'),
    ('m3node3', 1, True, 'http://m3:1234'),
    ('m3node4', 1, False, 'http://m3:1234'),
    ('m3node5', 1, False, 'http://m3:1234'),
    ('m3node6', 1, True, 'http://m3:1234'),

    # launch at http://m4:1234
    ('m4node1', 1, True, 'http://m4:1234'),
    ('m4node2', 1, True, 'http://m4:1234'),
    ('m4node3', 1, True, 'http://m4:1234'),
    ('m4node4', 1, True, 'http://m4:1234'),
    ('m4node5', 1, True, 'http://m4:1234'),
    ('m4node6', 1, True, 'http://m4:1234'),
))

## Fit a string to exactly one column less than the terminal width
## @param s: text to fit
## @return str: s truncated or right-padded with spaces to curses.COLS-1
##   characters
def pad(s):
    # curses.COLS only exists once curses.initscr() has been called
    width = curses.COLS - 1
    return s[:width].ljust(width)

class RLConsoleException(Exception):
    """Error raised by the console, e.g. when no launches can be found."""

class ROSLaunchConsole(object):
    """
    Curses front-end that lists the processes reported by roslaunch and
    lets the user move a highlight bar through the list.

    The cursor is tracked by two indices: draw_model_line is the index
    into draw_model of the first row currently on screen, and scr_line is
    the offset of the highlighted row within the visible rows.
    """

    def __init__(self, stdscr, roslaunch_uris=None, fake=False):
        """
        @param stdscr: curses window to draw on
        @param roslaunch_uris: URIs of the roslaunch instances of interest.
            When empty (and not in fake mode) they are looked up with
            netapi.get_roslaunch_uris().
        @param fake: if True, show the module-level fake_data table instead
            of querying netapi
        @raise RLConsoleException: if no URIs were given and none could be
            looked up
        """
        self.fake = fake
        if fake:
            roslaunch_uris = []
        elif not roslaunch_uris:
            roslaunch_uris = netapi.get_roslaunch_uris()
            if not roslaunch_uris:
                raise RLConsoleException("Cannot find any launches registered with the Parameter Server")
        # always stored, so the attribute exists whichever branch ran
        self.roslaunch_uris = roslaunch_uris

        self.data = []
        self.draw_model = []

        self.stdscr = stdscr
        self.scr_line = 0 #index of line on screen that is selected
        self.draw_model_line = 0 #first draw_model_line on screen

        self.reserved = 2 # total number of lines we reserve at top and bottom of screen
        self.reserved_top = 1 #number of lines we reserve at top of screen

        # level -> list of callables; filled in by add_logger()
        self.loggers = {
            Log.FATAL: [],
            Log.ERROR: [],
            Log.WARN: [],
            Log.INFO: [],
            Log.DEBUG: [],
        }

        self.update_data()

        #self.command_context_stack = [self.base]

    #-------------------------------------------------------------------------------
    # Logging methods, decoupled from ROS

    def add_logger(self, level, fn):
        """
        Register a sink for one log level.
        @param level: key of self.loggers (one of the Log level constants)
        @param fn: callable invoked as fn(msg, *args)
        """
        self.loggers[level].append(fn)

    def log(self, level, msg, *args):
        """
        Forward msg and args to every sink registered for level. Nothing
        happens if no sink is registered.
        """
        for sink in self.loggers[level]:
            sink(msg, *args)

    def logerr(self, msg, *args):
        self.log(Log.ERROR, msg, *args)

    def logwarn(self, msg, *args):
        self.log(Log.WARN, msg, *args)

    def logdebug(self, msg, *args):
        self.log(Log.DEBUG, msg, *args)

    def loginfo(self, msg, *args):
        self.log(Log.INFO, msg, *args)

    def logfatal(self, msg, *args):
        self.log(Log.FATAL, msg, *args)


    #-------------------------------------------------------------------------------
    # Console Data routines

    def update_data(self):
        """
        Reload self.data (and self.draw_model, which currently aliases it)
        from netapi, or from fake_data when the console is in fake mode.
        """
        # TODO: ordering of self.data must be preserved, so cannot actually replace list
        if self.fake:
            # fake mode must not touch the network
            self.data = fake_data
        else:
            # NOTE(review): self.roslaunch_uris is not passed on here, as
            # in the original call -- confirm what list_processes() expects
            self.data = netapi.list_processes()
        self.draw_model = self.data # for now, equal. need to decouple
        # Need to have two models: the draw_model and the actual data_model
        #self.next_data = next_data
        # on redraw, self.data = self.next_data

    #-------------------------------------------------------------------------------
    # Console Drawing API

    def get_selected(self):
        """
        @return: the highlighted draw_model entry, or None when the
            highlight does not point at an entry (e.g. the list is empty)
        """
        index = self.draw_model_line + self.scr_line
        if 0 <= index < len(self.draw_model):
            return self.draw_model[index]
        return None

    def _clamp_selection(self, my_lines):
        """
        Pull draw_model_line/scr_line back inside the model. This is
        needed after update_data() returned fewer entries than before.
        @param my_lines: number of screen rows available for entries
        """
        total = len(self.draw_model)
        if total == 0 or my_lines <= 0:
            self.draw_model_line = 0
            self.scr_line = 0
            return
        selected = min(self.draw_model_line + self.scr_line, total - 1)
        # never start the window so late that rows at the bottom stay empty
        last_start = max(total - my_lines, 0)
        self.draw_model_line = min(self.draw_model_line, last_start)
        self.scr_line = selected - self.draw_model_line

    def draw_data(self, d, row, selected):
        """
        Draw one entry: reverse video when selected, bold when dead.
        @param d: entry to draw; must provide is_alive plus whatever
            datastr() reads
        @param row: screen row to draw on
        @param selected: True if this is the highlighted entry
        """
        attr = 0
        if selected:
            attr |= curses.A_REVERSE
        if not d.is_alive:
            attr |= curses.A_BOLD
        if attr:
            self.stdscr.addstr(row, 0, datastr(d), attr)
        else:
            self.stdscr.addstr(row, 0, datastr(d))

    def redraw_scr(self):
        """
        Repaint the whole screen: status bar on the first line, the visible
        slice of draw_model below it, and the key help on the last line.
        """
        self.loginfo("redraw")

        model = self.draw_model
        my_lines = max(curses.LINES - self.reserved, 0)
        self._clamp_selection(my_lines)

        scr = self.stdscr
        scr.clear()
        first = self.draw_model_line
        visible = model[first:first + my_lines]
        for i, d in enumerate(visible):
            # offset draw row by the reserved_top
            self.draw_data(d, i + self.reserved_top, i == self.scr_line)

        dead_count = len([n for n in model if not n.is_alive])
        status = "%s of %s nodes, %s dead"%(len(visible), len(model), dead_count)
        scr.addstr(0, 0, pad(status), curses.A_REVERSE | curses.A_BOLD)
        #scr.addstr(curses.LINES-1, 0, pad('type h for help'), curses.A_BOLD)
        scr.addstr(curses.LINES-1, 0, pad('space=update, s=sort, r=relaunch, k=kill, q=quit'), curses.A_BOLD)
        # push the new frame to the terminal
        scr.refresh()

    def scroll_scr(self, inc):
        """
        Move the highlight by inc rows, scrolling the window by one row
        when the highlight would leave the screen. Moves past either end
        of the model are ignored.
        @param inc: row delta; the caller in this file passes -1 or +1
        """
        model = self.draw_model

        my_lines = curses.LINES - self.reserved
        next_line = self.scr_line + inc
        above_first = next_line < 0 and self.draw_model_line == 0
        below_last = next_line + self.draw_model_line >= len(model)
        if above_first or below_last:
            # ignore out-of-bounds
            return

        if next_line < 0:
            # highlight is on the top row: scroll the window up
            self.draw_model_line -= 1
            self.scr_line = 0
            self.redraw_scr()
        elif next_line >= my_lines:
            # highlight is on the bottom row: scroll the window down
            self.draw_model_line += 1
            self.scr_line = my_lines - 1
            self.redraw_scr()
        else:
            # both rows are already on screen: repaint just those two
            offset = self.reserved_top
            # unhighlight old line
            d = model[self.scr_line + self.draw_model_line]
            self.draw_data(d, offset+self.scr_line, False)

            # draw highlighted
            self.scr_line = next_line
            d = model[self.scr_line + self.draw_model_line]
            self.draw_data(d, offset+self.scr_line, True)
            self.stdscr.refresh()

    #-------------------------------------------------------------------------------
    # Console Commands

    def update_cmd(self):
        """Reload the process list and repaint."""
        self.loginfo("update cmd")
        self.update_data()
        self.redraw_scr()

    def help_cmd(self):
        """Write a key summary on the bottom line of the screen."""
        # highlighted line may be a node or a machine
        # TODO: compute actual width
        self.stdscr.addstr(curses.LINES-1, 0, pad('space=update, s=sort, k=kill'), curses.A_BOLD)
        self.stdscr.refresh()

    def relaunch_cmd(self):
        """Placeholder: only logs which entry is selected."""
        # highlighted line may be a node or a machine
        self.loginfo("relaunch cmd")

        selected = self.get_selected()
        self.loginfo("selected is %s", selected)

    def kill_cmd(self):
        """Placeholder: only logs that the command was requested."""
        # highlighted line may be a node or a machine
        self.loginfo("kill cmd")

    def sort_cmd(self):
        """Placeholder: only logs that the command was requested."""
        # input: sort by status, name, machine
        self.loginfo("sort cmd")
    

## Put the terminal into curses mode, configured for the console
## @return scr: initialized and configured curses screen
def _init_curses():
    stdscr = curses.initscr()
    # interpret special keys as curses values (e.g. KEY_LEFT)
    stdscr.keypad(1)

    # react immediately to keypress
    curses.cbreak()
    # hide the cursor; curs_set() raises curses.error on terminals that
    # cannot change cursor visibility, which is cosmetic and must not
    # abort startup
    try:
        curses.curs_set(0)
    except curses.error:
        pass
    # don't echo keypresses to screen
    curses.noecho()
    return stdscr
    
## Reset terminal settings (teardown curses)
## @param stdscr: screen to tear down
def _deinit_curses(stdscr):
    curses.curs_set(1)
    curses.nocbreak()
    if stdscr is not None:
        stdscr.keypad(0)
    curses.echo()
    curses.endwin()
    
## Run the interactive console: set up curses and the ROS node, then
## dispatch key presses to the ROSLaunchConsole until 'q' is pressed.
## The terminal is restored on the way out, including on errors.
## @raise RLConsoleException: propagated from ROSLaunchConsole()
def roslaunch_console():
    stdscr = None
    try:
        stdscr = _init_curses()

        rospy.init_node(NAME, anonymous=True)

        console = ROSLaunchConsole(stdscr)

        console.add_logger(Log.INFO, rospy.loginfo)
        # without a WARN sink, console.logwarn() output is silently dropped
        console.add_logger(Log.WARN, rospy.logwarn)
        console.add_logger(Log.FATAL, rospy.logerr)
        console.add_logger(Log.ERROR, rospy.logerr)
        console.add_logger(Log.DEBUG, rospy.logdebug)

        console.redraw_scr()
        while True:
            ch = stdscr.getch()
            if ch == ord('q'):
                break
            elif ch == curses.KEY_UP:
                console.scroll_scr(-1)
            elif ch == curses.KEY_DOWN:
                console.scroll_scr(1)
            elif ch == ord('k'):
                console.kill_cmd()
            elif ch == ord('r'):
                # advertised in the footer; there is nothing to relaunch
                # while the list is empty
                if console.draw_model:
                    console.relaunch_cmd()
            elif ch == ord(' '):
                console.update_cmd()
            elif ch == ord('s'):
                console.sort_cmd()

    finally:
        _deinit_curses(stdscr)

## Command-line entry point: run the console, print console errors to
## stderr, and exit quietly on Ctrl-C
def roslaunch_console_main():
    try:
        roslaunch_console()
    except RLConsoleException:
        # sys.exc_info() rather than "except X, e" / "except X as e", and
        # write() rather than "print >>", so the same source parses on
        # every Python 2.x as well as on Python 3
        sys.stderr.write("%s\n" % (sys.exc_info()[1],))
    except KeyboardInterrupt:
        pass

if __name__ == '__main__':
    roslaunch_console_main()

