#!/usr/bin/env python
"""
Script that captures an ORB template (for later use in orb_track or capture).
"""
from ecto.opts import scheduler_options, run_plasm
from ecto_image_pipeline.io.source import add_camera_group, create_source
from ecto_opencv.calib import PoseDrawer, DepthValidDraw
from ecto_opencv.features2d import DrawKeypoints
from ecto_opencv.highgui import imshow, FPSDrawer, MatWriter, ImageSaver
from ecto_opencv.imgproc import cvtColor, Conversion
from ecto_opencv.rgbd import PlaneFinder, ComputeNormals
from object_recognition_capture.ecto_cells.capture import PlaneFilter, FeatureFinder
from object_recognition_capture.orb_capture import *
import argparse
import ecto
import os

def parse_args():
    """Parse the command line and make sure the output directory exists.

    Besides the options declared here, the scheduler options (placed in a
    'Scheduler' argument group) and the camera options are registered on the
    parser through scheduler_options and add_camera_group.

    Returns:
        The namespace produced by argparse. ``options.output`` is created as a
        directory (including parents) when the path does not exist yet.
    """
    parser = argparse.ArgumentParser(description='Save an ORB template that may be used as a '
                                                 'fiducial marker when pressing the key "s". "q" to quit')
    # The short and long flags have to be passed as two separate option
    # strings; a single '-o,--output' string registers one flag with that
    # literal name and '--output' is then rejected by argparse.
    parser.add_argument('-o', '--output', dest='output', type=str,
                        help='The output directory for this template. Default: %(default)s', default='./')
    parser.add_argument('-n_features', dest='n_features', type=int,
                        help='The number of features to detect for the template. Default: %(default)d',
                        default=5000)
    scheduler_options(parser.add_argument_group('Scheduler'))
    add_camera_group(parser)
    options = parser.parse_args()
    # The writers further down join their file names onto this directory, so
    # create it up front.
    if not os.path.exists(options.output):
        os.makedirs(options.output)
    return options


# Parse the command line (parse_args also creates the output directory) and
# create the plasm, i.e. the graph every cell below gets connected into.
options = parse_args()
plasm = ecto.Plasm()

# Set up the input source and the grayscale conversion. The source is created
# with the resolution and frame rate given on the command line and is asked for
# the outputs listed in outputs_list.
source = create_source('image_pipeline', 'OpenNISource', outputs_list=['K_depth', 'K_image', 'image', 'depth', 'mask_depth', 'points3d'], res=options.res, fps=options.fps)
rgb2gray = cvtColor (flag=Conversion.RGB2GRAY)
plasm.connect(source['image'] >> rgb2gray ['image'])

# Convenience handle on the grayscale image output; it is what gets fed to the
# ORB feature finder.
img_src = rgb2gray['image']

# Display the depth image through an imshow cell created with name='depth'.
plasm.connect(source['depth'] >> imshow(name='depth')[:],
              )

# Connect up the test ORB feature finder. The number of features comes from the
# command line; the cell receives the grayscale image, the 3D points and the
# depth mask coming from the source.
orb = FeatureFinder('ORB test', n_features=options.n_features, n_levels=3, scale_factor=1.2)
plasm.connect(img_src >> orb['image'],
              source['points3d'] >> orb['points3d'],
              source['mask_depth'] >> orb['points3d_mask']
              )


# Display the test ORB keypoints: the source image and depth mask go through
# DepthValidDraw, the keypoints are drawn on that result, and the drawn image
# is handed to the FPS drawer.
draw_kpts = DrawKeypoints()
fps = FPSDrawer()
# The display cell maps the 's' key to its 'save' output, which is used below
# to trigger the template writers. Its image input is connected later, once
# the plane pose has been drawn on top of the FPS drawer output.
orb_display = imshow('orb display', name='ORB', triggers=dict(save=ord('s')))
depth_valid_draw = DepthValidDraw()
plasm.connect(orb['keypoints'] >> draw_kpts['keypoints'],
              source['image', 'mask_depth'] >> depth_valid_draw['image', 'mask'],
              depth_valid_draw['image'] >> draw_kpts['image'],
              draw_kpts['image'] >> fps[:],
             )

# Plane estimation: compute normals from the depth data, find planes, keep one
# of them through PlaneFilter and draw its pose (R, T) on the displayed image.
# NOTE(review): min_size=10000 is presumably a minimum plane size threshold --
# confirm the unit against the PlaneFinder documentation.
plane_est = PlaneFinder(min_size=10000)
compute_normals = ComputeNormals()
pose_draw = PoseDrawer()
plane_filter = PlaneFilter(do_center=True)
plasm.connect(# find the normals
              source['K_depth', 'points3d'] >> compute_normals['K', 'points3d'],
              # find the planes
              compute_normals['normals'] >> plane_est['normals'],
              source['K_depth', 'points3d'] >> plane_est['K', 'points3d'],
              # only choose the most centered plane
              plane_est['planes', 'masks'] >> plane_filter['planes', 'masks'],
              # draw the pose of the plane
              plane_filter['R', 'T'] >> pose_draw['R', 'T'],
              source['K_image'] >> pose_draw['K'],
              fps[:] >> pose_draw['image'],
              pose_draw['output'] >> orb_display['image']
              )

# Training: writers for the template data. Each writer is wrapped in an ecto.If
# cell and targets a file inside the output directory given on the command
# line. Their '__test__' inputs are all driven by orb_display['save'] below, so
# the writers are presumably only executed when the 's' key is pressed in the
# ORB display window -- confirm against the ecto.If documentation.
points3d_writer = ecto.If("Points3d writer", cell=MatWriter(filename=os.path.join(options.output, 'points3d.yaml')))
points_writer = ecto.If("Points writer", cell=MatWriter(filename=os.path.join(options.output, 'points.yaml')))
descriptor_writer = ecto.If("Descriptor writer", cell=MatWriter(filename=os.path.join(options.output, 'descriptors.yaml')))
R_writer = ecto.If("R writer", cell=MatWriter(filename=os.path.join(options.output, 'R.yaml')))
T_writer = ecto.If("T writer", cell=MatWriter(filename=os.path.join(options.output, 'T.yaml')))
image_writer = ecto.If(cell=ImageSaver(filename_param=os.path.join(options.output, 'train.png')))

# Pair every matrix output (y) with the writer (x) that stores it: the ORB 3D
# points, descriptors and 2D points, plus the R and T outputs of the plane
# filter.
for y, x in (
            (orb['points3d'], points3d_writer),
            (orb['descriptors'], descriptor_writer),
            (orb['points'], points_writer),
            (plane_filter['R'], R_writer),
            (plane_filter['T'], T_writer)
            ):
    plasm.connect(orb_display['save'] >> x['__test__'],
                  y >> x['mat'],
              )
# The image writer gets the image straight from the source (not the annotated
# display image) and is triggered by the same 'save' output.
plasm.connect(orb_display['save'] >> image_writer['__test__'],
              source['image'] >> image_writer['image']
              )

# Run the graph with the scheduler options parsed from the command line; all
# module-level variables are handed over through locals.
run_plasm(options, plasm, locals=vars())
