#!/usr/bin/env python
from ecto_image_pipeline.base import RescaledRegisteredDepth
from ecto_opencv import calib, highgui, imgproc
from object_recognition_core.db import Documents, ObjectDb, tools as dbtools, models
from object_recognition_core.db.cells import ObservationReader
from object_recognition_reconstruction import MatToPointCloudXYZRGB, PointCloudTransform, PointCloudAccumulator
from tempfile import NamedTemporaryFile
import argparse
import couchdb
import ecto
import ecto_image_pipeline.base
import ecto_pcl
import object_recognition_core
import os
import shutil
import subprocess
import sys
import tempfile

# Visualization cells, created once at import time. Each is sliced with [:] so
# it can be used directly as the right-hand side of a ``>>`` connection (see the
# ``args.visualize`` branch of simple_mesh_session).
cloud_view = ecto_pcl.CloudViewer("Cloud Display", window_name="PCD Viewer")[:]

# One image window per stream, keyed by the stream it displays.
imshows = {
    'image': highgui.imshow('Image Display', name='image')[:],
    'mask': highgui.imshow('Mask Display', name='mask')[:],
    'depth': highgui.imshow('Depth Display', name='depth')[:],
}
def upload_mesh(db, session, cloud_file, mesh_file=None):
    r = models.find_model_for_object(db, session.object_id, 'mesh')
    m = None
    for model in r:
        m = models.Model.load(db, model)
        print "updating model:", model
        break
    if not m:
        m = models.Model(object_id=session.object_id, method='mesh')
        print "creating new model."
    m.store(db)
    with open(cloud_file, 'r') as mesh:
        db.put_attachment(m, mesh, filename='cloud.ply')
    if mesh_file:
        with open(mesh_file, 'r') as mesh:
            db.put_attachment(m, mesh, filename='mesh.stl', content_type='application/octet-stream')

# MeshLab filter script (XML): compute and smooth point-set normals, then run
# Poisson surface reconstruction.
# NOTE(review): not referenced in the visible part of this file; meshlab()
# writes FILTER_SCRIPT_PIVOTING instead -- confirm whether this is still needed.
FILTER_SCRIPT= '''
<!DOCTYPE FilterScript>
<FilterScript>
 <filter name="Compute normals for point sets">
  <Param type="RichInt" value="10" name="K"/>
  <Param type="RichBool" value="false" name="flipFlag"/>
  <Param x="0" y="0" z="0" type="RichPoint3f" name="viewPos"/>
 </filter>
 <filter name="Smooths normals on a point sets">
  <Param type="RichInt" value="10" name="K"/>
  <Param type="RichBool" value="false" name="useDist"/>
 </filter>
 <filter name="Surface Reconstruction: Poisson">
  <Param type="RichInt" value="6" name="OctDepth"/>
  <Param type="RichInt" value="6" name="SolverDivide"/>
  <Param type="RichFloat" value="1" name="SamplesPerNode"/>
  <Param type="RichFloat" value="1" name="Offset"/>
 </filter>
</FilterScript>
'''
# MeshLab filter script (XML): Ball Pivoting surface reconstruction followed by
# one step of Laplacian smoothing. This is the script meshlab() passes to
# meshlabserver.
FILTER_SCRIPT_PIVOTING='''
<!DOCTYPE FilterScript>
<FilterScript>
 <filter name="Surface Reconstruction: Ball Pivoting">
  <Param type="RichAbsPerc" value="0.01" min="0" name="BallRadius" max="0.122839"/>
  <Param type="RichFloat" value="20" name="Clustering"/>
  <Param type="RichFloat" value="90" name="CreaseThr"/>
  <Param type="RichBool" value="false" name="DeleteFaces"/>
 </filter>
 <filter name="Laplacian Smooth">
  <Param type="RichInt" value="1" name="stepSmoothNum"/>
  <Param type="RichBool" value="true" name="Boundary"/>
  <Param type="RichBool" value="true" name="cotangentWeight"/>
  <Param type="RichBool" value="false" name="Selected"/>
 </filter>
</FilterScript>
'''
# Example of an equivalent manual invocation:
#   meshlabserver -i cloud_44ed68c2b66cc8aefc7df45fd63c4ac8_00000.ply -o mug.stl -s meshlab.xml.mlx

def meshlab(filename_in, filename_out):
    """Run ``meshlabserver`` with the ball-pivoting filter script.

    FILTER_SCRIPT_PIVOTING is written to a temporary script file, which is
    passed to ``meshlabserver`` via ``-s`` and removed afterwards, even when
    launching the process fails.

    filename_in  -- path of the input file (``-i``).
    filename_out -- path of the output file (``-o``).

    The exit status of ``meshlabserver`` is not checked. An ``OSError`` is
    raised if the executable cannot be started.
    """
    # delete=False: the file must survive close() so that the external
    # process can open it by name; it is unlinked explicitly below.
    script_file = NamedTemporaryFile(mode='w', delete=False)
    script = script_file.name
    try:
        try:
            script_file.write(FILTER_SCRIPT_PIVOTING)
        finally:
            script_file.close()
        # Work on a copy so the LC_NUMERIC override only affects the child
        # process and does not leak into this process' environment.
        env = dict(os.environ)
        env['LC_NUMERIC'] = 'C'
        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed through verbatim.
        command = ['meshlabserver', '-i', filename_in, '-o', filename_out, '-s', script]
        subprocess.Popen(command, env=env, cwd=os.getcwd()).wait()
    finally:
        os.unlink(script)

def simple_mesh_session(dbs, session, args):
    """Reconstruct a point cloud and mesh for one capture session.

    Builds an ecto plasm that reads every observation of the session,
    converts the masked depth image to a coloured point cloud, transforms it
    with the observation's R/T, and accumulates the views into one filtered
    cloud. After the graph has run, the cloud is written as a PLY file,
    meshed with meshlab(), and, when ``args.commit`` is set, uploaded with
    upload_mesh().

    dbs     -- database handle forwarded to upload_mesh().
    session -- session object; ``session.id`` selects the observations and
               names the output files.
    args    -- parsed command-line arguments; read for the database
               parameters, ``visualize``, ``commit`` and by ``run_plasm``.

    Raises RuntimeError when the session has no observations.
    """
    db_params = dbtools.args_to_db_params(args)
    obs_ids = models.find_all_observations_for_session(db_params, session.id)
    if len(obs_ids) == 0:
        raise RuntimeError("There are no observations available.")
    db_reader = ObservationReader('Database Source')

    #observation dealer will deal out each observation id.
    observation_dealer = ecto.Dealer(tendril=db_reader.inputs.at('document'), iterable=Documents(ObjectDb(db_params), list(obs_ids)))
    depthTo3d = calib.DepthTo3d('Depth ~> 3D')
    erode = imgproc.Erode('Mask Erosion', kernel=3) #-> 7x7
    rescale_depth = RescaledRegisteredDepth('Depth scaling') #this is for SXGA mode scale handling.
    point_cloud_transform = PointCloudTransform('Object Space Transform')
    point_cloud_converter = MatToPointCloudXYZRGB('To Point Cloud')
    pre_voxel = ecto_pcl.VoxelGrid("Pre Voxel Decimation", leaf_size=0.005)
    to_ecto_pcl = ecto_pcl.PointCloudT2PointCloud('converter', format=ecto_pcl.XYZRGB)
    K_converter = calib.KConverter()
    plasm = ecto.Plasm()
    # Per-observation pipeline: rescale depth -> 3D points -> coloured cloud
    # (masked by the eroded mask) -> voxel decimation -> R/T transform.
    plasm.connect(
      observation_dealer[:] >> db_reader['document'],
      db_reader['depth', 'image', 'mask'] >> rescale_depth['depth', 'image', 'mask'],
      db_reader['K'] >> depthTo3d['K'],
      rescale_depth['depth'] >> depthTo3d['depth'],
      depthTo3d['points3d'] >> point_cloud_converter['points'],
      db_reader['image'] >> point_cloud_converter['image'],
      erode['image'] >> point_cloud_converter['mask'],
      rescale_depth['mask'] >> erode['image'],
      db_reader['R', 'T'] >> point_cloud_transform['R', 'T'],
      point_cloud_converter['point_cloud'] >> to_ecto_pcl[:],
      to_ecto_pcl[:] >> pre_voxel['input'],
      pre_voxel['output'] >> point_cloud_transform['cloud']
    )

    accum = PointCloudAccumulator('accumulator')
    voxel_grid = ecto_pcl.VoxelGrid("voxel_grid", leaf_size=0.005)
    outlier_removal = ecto_pcl.StatisticalOutlierRemoval('Outlier Removal', mean_k=2, stddev=2)
    # The entangled pair feeds the filtered cloud (sink) back into the
    # accumulator's 'previous' input (source).
    source, sink = ecto.EntangledPair(value=accum.inputs.at('view'),
                                      source_name='Feedback Cloud',
                                      sink_name='Feedback Cloud')
    ply_writer = ecto.If('PlyWriter', cell=ecto_pcl.PLYWriter('PLY Saver',
                                      filename_format='cloud_%s_%%05d.ply' % str(session.id)))
    mls = ecto_pcl.MovingLeastSquares('MLS', search_radius=0.0075)
    # Keep the conditional writer switched off while the graph runs; it is
    # switched on and processed once after run_plasm() returns.
    ply_writer.inputs.__test__ = False
    plasm.connect(source[:] >> accum['previous'],
                  point_cloud_transform['view'] >> accum['view'],
                  accum[:] >> voxel_grid[:],
                  voxel_grid[:] >> outlier_removal[:],
                  outlier_removal[:] >> (mls[:], sink[:]),
                  mls[:] >> ply_writer['input'],
    )

    if args.visualize:
        global cloud_view
        plasm.connect(
          mls[:] >> cloud_view,
          db_reader['image'] >> imshows['image'],
          db_reader['depth'] >> imshows['depth'],
          erode['image'] >> imshows['mask'],
          )

    from ecto.opts import run_plasm
    # NOTE: the local names above are handed to run_plasm via vars(); do not
    # rename them without checking how run_plasm uses ``locals``.
    run_plasm(args, plasm, locals=vars())

    #finally write the ply to disk.
    ply_writer.inputs.__test__ = True
    ply_writer.process()
    # The writer's filename_format has no directory component, so the cloud is
    # addressed by its bare name. Assumes the writer numbers its first file 0
    # -- TODO confirm.
    ply_name = 'cloud_%s_%05d.ply' % (str(session.id), 0)
    # The mesh itself goes into a fresh temporary directory.
    dir_path = tempfile.mkdtemp()
    mesh_name = os.path.join(dir_path, 'cloud_%s.stl' % str(session.id))
    meshlab(ply_name, mesh_name)
    if args.commit:
        try:
            upload_mesh(dbs, session, ply_name, mesh_name)
        finally:
            # Remove the temporary directory even when the upload fails.
            shutil.rmtree(dir_path)
    # NOTE(review): without --commit the temporary directory (and the mesh in
    # it) is left on disk, as before -- confirm this is intended.

###################################################################################################################
def parse_args():
    parser = argparse.ArgumentParser(description='Computes a surface mesh of an object in the database')
    parser.add_argument('-s', '--session_id', metavar='SESSION_ID', dest='session_id', type=str, default='',
                       help='The session id to reconstruct.')
    parser.add_argument('--all', dest='compute_all', action='store_const',
                        const=True, default=False,
                        help='Compute meshes for all possible sessions.')
    parser.add_argument('--visualize', dest='visualize', action='store_const',
                        const=True, default=False,
                        help='Turn on visualization')
    dbtools.add_db_arguments(parser)

    sched_group = parser.add_argument_group('Scheduler Options')
    from ecto.opts import scheduler_options
    scheduler_options(sched_group)

    args = parser.parse_args()
    if args.compute_all == False and args.session_id == '':
        parser.print_usage()
        print "You must either supply a session id, or --all."
        sys.exit(1)
    return args

if "__main__" == __name__:
    args = parse_args()
    couch = couchdb.Server(args.db_root)
    dbs = dbtools.init_object_databases(couch)
    sessions = dbs
    if args.compute_all:
        models.sync_models(dbs)
        results = models.Session.all(sessions)
        for session in results:
            simple_mesh_session(dbs, session, args)
    else:
        session = models.Session.load(sessions, args.session_id)
        if session == None or session.id == None:
            print "Could not load session with id:", args.session_id
            sys.exit(1)
        simple_mesh_session(dbs, session, args)
    # do it again to create the view for meshes
    dbs = dbtools.init_object_databases(couch)
