/*******************************************************************************
 *  ORFaceLearningModule.cpp
 *
 *  (C) 2006 AG Aktives Sehen <agas@uni-koblenz.de>
 *           Universitaet Koblenz-Landau
 *
 * $Id: $
 *******************************************************************************/

#include <or_nodes/Modules/ORFaceLearningModule.h>

#include <robbie_architecture/Architecture/Config/Config.h>
#include <or_libs/ObjectRecognition/ObjectProperties.h>

#include <algorithm>
#include <cmath>
#include <fstream>
#include <sstream>

#include <cv_bridge/cv_bridge.h>
#include <ros/package.h>
#include <ptu/SetPanTilt.h>


/**
 * @brief Wires the face learning module into ROS and initialises its state machine.
 * @param nh                   node handle used to create all publishers and subscribers
 * @param objRecMatchingModule matching module that receives newly learned faces
 *                             (stored as a plain pointer; the destructor does not delete it)
 * @param objRecLoaderModule   loader module used to reload the saved face file
 *                             (stored as a plain pointer; the destructor does not delete it)
 */
ORFaceLearningModule::ORFaceLearningModule(ros::NodeHandle *nh, ORMatchingModule* objRecMatchingModule, ORLoaderModule* objRecLoaderModule)
{
    // collaborating modules
    m_ORLoaderModule   = objRecLoaderModule;
    m_ORMatchingModule = objRecMatchingModule;

    // outgoing topics: detection requests, learning result, pan-tilt commands
    m_DetectFacesPub          = nh->advertise<std_msgs::Bool>("/face_detection/detect_faces", 1);
    m_FaceLearningFinishedPub = nh->advertise<std_msgs::Bool>("/or/learn_face_finished", 1);
    m_ptu_pub                 = nh->advertise<ptu::SetPanTilt>("/ptu/set_pan_tilt",1);

    // incoming topics: learn requests and face detection results
    m_LearnFaceSub     = nh->subscribe("/or/learn_face", 1, &ORFaceLearningModule::learnFaceCallback, this);
    m_DetectedFacesSub = nh->subscribe("/face_detection/start_detection", 1, &ORFaceLearningModule::faceDetectionCallback, this);

    init();
}

void ORFaceLearningModule::init()
{
    ADD_MACHINE_STATE ( m_ModuleMachine, IDLE );
    ADD_MACHINE_STATE ( m_ModuleMachine, GRABBING_IMAGES );

    m_ModuleMachine.setName ( "ORFaceLearningModule State" );
    m_ModuleMachine.setState ( IDLE );

    m_ObjectProperties = 0;
    m_FaceAdditionalBorder = Config::getFloat("ObjectLearning.fFaceAdditionalBorder");

    m_DetectionAttempts = 0;
}


/**
 * @brief Starts a face learning run for the person named in @p msg.
 *
 * Ignored unless the module is IDLE. It does the following:
 *  - points the pan-tilt unit to pan 0 / tilt -0.2 (absolute) and sleeps 2 s;
 *  - resets the per-run state and switches to GRABBING_IMAGES;
 *  - creates a new ObjectProperties of type "FACE";
 *  - requests the first face detection.
 *
 * @param msg name of the object (person) to learn.
 */
void ORFaceLearningModule::learnFaceCallback(std_msgs::String::ConstPtr msg)
{
    ROS_INFO_STREAM ( "Learn Faces message obtained." );

    // A run is already in progress: ignore the request.
    if ( m_ModuleMachine.state() != IDLE )
    {
        return;
    }

    ptu::SetPanTilt pan_msg;
    pan_msg.panAngle = 0;
    pan_msg.tiltAngle = -0.2;
    pan_msg.absolute = true;
    m_ptu_pub.publish(pan_msg);
    ros::Duration(2).sleep(); // presumably gives the PTU time to reach the pose

    m_DetectionAttempts = 0;
    m_Faces.clear();
    m_ModuleMachine.setState ( GRABBING_IMAGES );
    m_ImageCount = 1; // TODO param should be in message:  msg->getImageCount;
    m_FilterFaces = true; // TODO get rid of this, should always be true

    // A previous run that gave up after too many failed detections may have
    // left its unfinished object here; free it instead of leaking it.
    // (After a successful run addImage() has already reset this to 0.)
    delete m_ObjectProperties;
    m_ObjectProperties = new ObjectProperties( msg->data );
    m_ObjectProperties->setType( "FACE" );

    // Ask the face detection node for a detection on the current image.
    getImage();
}

/**
 * @brief Handles a face detection result while a learning run is active.
 *
 * If no face was found, it requests another detection (after 1 s). After the
 * third failed attempt it gives up: it publishes `false` on the "finished"
 * topic, drops the unfinished object and returns to IDLE.
 *
 * If faces were found, every non-empty bounding polygon is converted into an
 * axis-aligned Box2D<int>, which is stored in m_Faces (replacing the boxes of
 * any earlier image). Then faces[0].face_image is converted to mono8 and
 * passed to addImage().
 *
 * @param msg detection result from the face detection node.
 */
void ORFaceLearningModule::faceDetectionCallback(face_detection::FaceDetectionResult::ConstPtr msg)
{
    if ( m_ModuleMachine.state() != GRABBING_IMAGES )
    {
        return;
    }

    ROS_INFO_STREAM( "Received Face Detection Result" );
    m_DetectionAttempts++;

    if ( msg->faces.empty() )
    {
        if ( m_DetectionAttempts >= 3 )
        {
            std_msgs::Bool finished_msg;
            finished_msg.data = false;
            m_FaceLearningFinishedPub.publish(finished_msg);

            // The run is aborted, so the unfinished object is no longer needed.
            delete m_ObjectProperties;
            m_ObjectProperties = 0;

            m_ModuleMachine.setState( IDLE );
        }
        else
        {
            // redetect
            ROS_INFO_STREAM("Requesting redetection in 1 second ...");
            sleep(1);
            std_msgs::Bool detect_faces_msg;
            detect_faces_msg.data = false;
            m_DetectFacesPub.publish(detect_faces_msg);
        }
        return;
    }

    // The boxes describe the image in this message only; boxes from images
    // handled earlier in the same run must not influence this image's mask.
    m_Faces.clear();

    for ( size_t i = 0; i < msg->faces.size(); ++i )
    {
        const geometry_msgs::Polygon& poly = msg->faces[i].bounding_box;
        if ( poly.points.empty() )
        {
            ROS_WARN_STREAM( "Ignoring face " << i << " with an empty bounding polygon." );
            continue;
        }

        // Seed the extent with the first vertex (instead of the magic numbers
        // 999 / 0), so images larger than 999 px are handled correctly.
        int minX = static_cast<int>( poly.points[0].x );
        int minY = static_cast<int>( poly.points[0].y );
        int maxX = minX;
        int maxY = minY;
        for ( size_t p = 1; p < poly.points.size(); ++p )
        {
            const int px = static_cast<int>( poly.points[p].x );
            const int py = static_cast<int>( poly.points[p].y );
            minX = std::min( minX, px );
            minY = std::min( minY, py );
            maxX = std::max( maxX, px );
            maxY = std::max( maxY, py );
        }
        m_Faces.push_back( Box2D< int >( minX, minY, maxX, maxY ) );
    }

    // NOTE(review): only the image attached to the first face is used, while
    // the boxes above are assumed to be in that image's coordinates — confirm.
    cv_bridge::CvImagePtr cv_ptr = cv_bridge::toCvCopy(msg->faces[0].face_image, "mono8"); //TODO check if this still works

    addImage( &cv_ptr->image, &cv_ptr->image ); // TODO one image should be enough
}



/**
 * @brief Adds one masked image to the object being learned and either
 *        requests the next image or finishes the run.
 *
 * When m_ImageCount reaches 0, the run is finished:
 *  - the object is saved to disk via saveObject();
 *  - the object pointer is handed to the matching module;
 *  - the saved file "auto_<name>" is loaded through the loader module;
 *  - `true` is published on the "finished" topic and the module returns to IDLE.
 *
 * @param image_gray  8-bit single channel image (accessed as unsigned char downstream)
 * @param image_color forwarded to makeImageProperties() (currently unused there)
 */
void ORFaceLearningModule::addImage(cv::Mat *image_gray, cv::Mat *image_color )
{
    ROS_INFO_STREAM("Adding images masked by faces.");
    ImagePropertiesCV* imageProp = makeImageProperties( image_gray, image_color );
    ROS_INFO_STREAM("Successfully added image properties.");

    m_ObjectProperties->addImageProperties( imageProp );

    // Exactly one decrement per added image. (The former second decrement
    // after getImage() made runs capture only about half of the requested images.)
    m_ImageCount--;

    if ( m_ImageCount > 0 )
    {
        usleep( 1000000 );
        getImage();
        return;
    }

    // saveObject() deletes m_ObjectProperties and replaces it with a
    // default-constructed instance, so remember the learned name before calling it;
    // otherwise the loader would be asked for "auto_" + <name of the new object>.
    const std::string objectName = m_ObjectProperties->getName();
    saveObject();

    ROS_WARN_STREAM("automatically adding face to object list");
    // NOTE(review): this hands over the object created by saveObject(), not the
    // one that was learned — confirm this is intended.
    m_ORMatchingModule->addObjectProperties(m_ObjectProperties);
    std::string filename = "auto_";  // TODO temp remove
    filename.append(objectName);  // TODO temp remove
    m_ORLoaderModule->loadObjectProperties(filename); // TODO temp remove
    ROS_WARN_STREAM("loading face to object list");
    m_ObjectProperties = 0;

    std_msgs::Bool finished_msg;
    finished_msg.data = true;
    m_FaceLearningFinishedPub.publish(finished_msg);
    m_ModuleMachine.setState( IDLE );
}

/**
 * @brief Requests a face detection on the next camera image.
 *
 * m_DetectFacesPub is advertised with std_msgs::Bool (see constructor).
 * Publishing a different message type on it fails roscpp's md5 check (an
 * assert in debug builds) and gives subscribers a payload they cannot
 * deserialize. Therefore the same Bool request as the redetection path in
 * faceDetectionCallback() is sent.
 */
void ORFaceLearningModule::getImage()
{
    ROS_INFO_STREAM( m_ImageCount << " images left." );
    std_msgs::Bool detect_faces_msg;
    detect_faces_msg.data = false; // NOTE(review): mirrors the redetect request; confirm the detector ignores the value
    m_DetectFacesPub.publish(detect_faces_msg);
}


/**
 * @brief Creates a mask covering the largest detected face box.
 *
 * Selects the box in m_Faces with the largest (width * height) and marks the
 * pixels inside it (inclusive bounds, clipped to the image) as
 * ImageMaskCV::VISIBLE. If m_Faces is empty, an error is logged and the freshly
 * constructed mask is returned unchanged.
 *
 * @param image_gray image whose size determines the mask size.
 * @return newly allocated mask; the caller deletes it.
 */
ImageMaskCV* ORFaceLearningModule::makeFaceMask( cv::Mat *image_gray )
{
    const unsigned width  = image_gray->cols;
    const unsigned height = image_gray->rows;

    ImageMaskCV* mask = new ImageMaskCV ( width, height );
    unsigned char* maskData = mask->getData();

    //If no faces were found, nothing can be learned !!!
    if ( m_Faces.empty() )
    {
        ROS_ERROR_STREAM("Cannot learn face, because no face was detected.");
        return mask;
    }

    // Find the biggest face. Start at the first box so the iterator is always
    // valid, even if every box has a non-positive area (previously it was left
    // uninitialized and dereferenced in that case).
    int sizeBiggestFace = 0;
    std::vector< Box2D<int> >::iterator biggestFace = m_Faces.begin();

    for ( std::vector< Box2D<int> >::iterator face = m_Faces.begin(); face != m_Faces.end(); ++face )
    {
        ROS_INFO_STREAM( face->minX()<<" "<<face->minY()<<" "<<face->maxX()<<" "<<face->maxY() );
        const int faceWidth  = face->maxX() - face->minX();
        const int faceHeight = face->maxY() - face->minY();
        const int faceSize   = faceWidth * faceHeight;
        ROS_INFO_STREAM("Checking face with size " << faceSize << ".("<<faceWidth<<"x"<<faceHeight<<")");
        if ( faceSize > sizeBiggestFace )
        {
            sizeBiggestFace = faceSize;
            biggestFace = face;
            ROS_INFO_STREAM("New face is bigger with size " << sizeBiggestFace << ". Setting this one.");
        }
    }

    // Height extension below the face. It is currently only logged; it was
    // previously added to maxY, which is now commented out.
    const int faceHeight = biggestFace->maxY() - biggestFace->minY();
    const int addHeightValue = faceHeight * Config::getInt( "FaceDetection.iFaceHeightMultiplier" );
    ROS_INFO_STREAM("Adding height " << addHeightValue);

    // Clip the (inclusive) box to the image in signed arithmetic, so negative
    // coordinates and an empty image (width or height 0) cannot wrap around.
    // TODO the box is no longer enlarged by m_FaceAdditionalBorder — remove that member?
    const int lastCol = static_cast<int>( width ) - 1;
    const int lastRow = static_cast<int>( height ) - 1;
    const int x0 = std::max( 0, static_cast<int>( biggestFace->minX() ) );
    const int x1 = std::min( static_cast<int>( biggestFace->maxX() ), lastCol );
    const int y0 = std::max( 0, static_cast<int>( biggestFace->minY() ) );
    const int y1 = std::min( static_cast<int>( biggestFace->maxY() ), lastRow );

    // Row-major traversal, matching the y*width+x layout of maskData.
    for ( int y = y0; y <= y1; ++y )
    {
        unsigned char* row = maskData + static_cast<size_t>( y ) * width;
        for ( int x = x0; x <= x1; ++x )
        {
            row[x] = ImageMaskCV::VISIBLE;
        }
    }

    return mask;
}

/**
 * @brief Creates an elliptical mask centred in the image.
 *
 * The semi-axes are image width / ObjectLearning.iEllipseA and image height /
 * ObjectLearning.iEllipseB. A pixel (x, y) is marked ImageMaskCV::VISIBLE if
 * ((x-midX)/a)^2 + ((y-midY)/b)^2 <= 1.
 *
 * @param image_gray image whose size determines the mask size.
 * @return newly allocated mask; the caller deletes it. If a configured divisor
 *         is not positive, an error is logged and the freshly constructed mask
 *         is returned unchanged.
 */
ImageMaskCV* ORFaceLearningModule::makeOvalMask( cv::Mat *image_gray )
{
    const unsigned width  = image_gray->cols;
    const unsigned height = image_gray->rows;

    ImageMaskCV* mask = new ImageMaskCV ( width, height );
    unsigned char* maskData = mask->getData();

    const int divisorA = Config::getInt( "ObjectLearning.iEllipseA" ); // formerly hard-coded 5.0
    const int divisorB = Config::getInt( "ObjectLearning.iEllipseB" ); // formerly hard-coded 3.0
    if ( divisorA <= 0 || divisorB <= 0 )
    {
        // A zero divisor would be an integer division by zero; a negative one
        // produces a meaningless ellipse.
        ROS_ERROR_STREAM( "ObjectLearning.iEllipseA/iEllipseB must be positive; oval mask left empty." );
        return mask;
    }

    // Float division: the previous unsigned/int division truncated the semi-axes.
    const float a = static_cast<float>( width ) / divisorA;
    const float b = static_cast<float>( height ) / divisorB;
    const float midX = width/2.0;
    const float midY = height/2.0;

    unsigned i = 0;
    for ( unsigned y = 0; y < height; y++ )
    {
        const float dy = static_cast<float>( y ) - midY;
        for ( unsigned x = 0; x < width; x++ )
        {
            const float dx = static_cast<float>( x ) - midX;
            //apply an oval face-shaped mask to the background image
            if ( dx*dx/(a*a) + dy*dy/(b*b) <= 1 )
            {
                maskData[i] = ImageMaskCV::VISIBLE;
            }
            ++i;
        }
    }

    return mask;
}


/**
 * @brief Builds the ImagePropertiesCV for one learned image.
 *
 * Steps:
 *  - masks the biggest detected face (or an oval, if m_FilterFaces is false);
 *  - crops image and mask to the mask's bounding box, expanded by 2 and
 *    clipped to the image;
 *  - creates an ImagePropertiesCV from the crop and calls calculateProperties().
 *
 * The "colour" input passed on is a 3-channel replicate of the gray crop.
 * @p image_color is currently unused (see the TODO below).
 *
 * @param image_gray  8-bit single channel source image.
 * @param image_color unused.
 * @return newly allocated properties object. The cropped images and mask are
 *         handed to it and not deleted here (presumably it takes ownership — confirm).
 */
ImagePropertiesCV* ORFaceLearningModule::makeImageProperties( cv::Mat *image_gray, cv::Mat *image_color )
{
    ImageMaskCV* mask;
    if(m_FilterFaces)
    {
        mask = makeFaceMask( image_gray );
    }
    else
    {
        mask = makeOvalMask( image_gray );
    }

    //crop image & mask
    Box2D<int> area = mask->getBoundingBox();
    area.expand( 2 );
    area.clip( Box2D<int>( 0, 0, image_gray->cols, image_gray->rows ) );

    const int newWidth = area.width();
    const int newHeight= area.height();
    const int minX = area.minX();
    const int minY = area.minY();

    ImageMaskCV *croppedMask = mask->subMask( area );

    // Fill both crops in one pass: Y is the gray crop; UV replicates the gray
    // value into all three channels.
    // TODO use image_color once a real colour image is passed in.
    cv::Mat *croppedImageY = new cv::Mat( newHeight, newWidth, CV_8UC1 );
    cv::Mat *croppedImageUV = new cv::Mat( newHeight, newWidth, CV_8UC3 );
    for ( int y=0; y<newHeight; y++ )
    {
        for ( int x=0; x<newWidth; x++ )
        {
            const unsigned char val = image_gray->at<unsigned char>(y+minY, x+minX);
            croppedImageY->at<unsigned char>(y,x) = val;
            croppedImageUV->at<cv::Vec3b>(y,x) = cv::Vec3b(val, val, val);
        }
    }

    // (A leftover debug cv::imwrite to a hard-coded developer home directory
    // was removed here; it ran on every learned face.)

    ImagePropertiesCV* imageProperties = new ImagePropertiesCV( "", croppedImageY, croppedImageUV, croppedMask );
    imageProperties->calculateProperties();

    delete mask;

    return imageProperties;
}


/**
 * @brief Serializes the object being learned and resets the member.
 *
 * The object is written with boost text serialization to
 * <or_nodes package>/objectProperties/auto_<name>.objprop. Afterwards
 * m_ObjectProperties is deleted and replaced by a default-constructed
 * ObjectProperties; addImage() relies on that replacement.
 *
 * If the file cannot be opened, an error is logged and nothing is written.
 * Previously boost threw an uncaught archive exception on the bad stream.
 */
void ORFaceLearningModule::saveObject ( )
{
    std::string path = ros::package::getPath("or_nodes");
    std::string filename = path + "/objectProperties/auto_" + m_ObjectProperties->getName() + ".objprop";

    ROS_INFO_STREAM("Saving object to: " << filename);

    bool saved = false;
    {
        // write objectProperties file
        std::ofstream out ( filename.c_str() );
        if ( !out )
        {
            ROS_ERROR_STREAM( "Could not open " << filename << " for writing; object not saved." );
        }
        else
        {
            // Scoped so the archive is completely written and the file closed
            // before the object is deleted and success is reported.
            boost::archive::text_oarchive oa(out);
            oa << m_ObjectProperties;
            saved = true;
        }
    }

    delete m_ObjectProperties;
    m_ObjectProperties = new ObjectProperties();

    if ( saved )
    {
        ROS_INFO_STREAM ( "Object saved to " << filename );
    }
}


/**
 * @brief Frees the object properties still held by the module, if any.
 *        The collaborating modules passed to the constructor are not deleted.
 */
ORFaceLearningModule::~ORFaceLearningModule()
{
    if ( m_ObjectProperties != 0 )
    {
        delete m_ObjectProperties;
        m_ObjectProperties = 0;
    }
}


/**
 * @brief Periodic idle hook; intentionally does nothing for this module.
 *        All work is driven by the learnFaceCallback / faceDetectionCallback subscriptions.
 */
void ORFaceLearningModule::idleProcess()
{
}


