/*******************************************************************************
 * Copyright (c) 2019 Nerian Vision GmbH
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *******************************************************************************/

#include <iostream>
#include <fstream>
#include <stdexcept>
#include <cstring>
#include <memory>
#include "visiontransfer/imagepair.h"

#ifdef _WIN32
#include <winsock2.h>
#else
#include <arpa/inet.h>
#endif

using namespace visiontransfer;

namespace visiontransfer {

/**
 * Creates an empty image pair: zero-sized images in FORMAT_8_BIT_MONO, no
 * image buffers, no Q-matrix and no reference counter.
 */
ImagePair::ImagePair()
    : width(0), height(0), qMatrix(nullptr), timeSec(0), timeMicrosec(0),
        seqNum(0), minDisparity(0), maxDisparity(0), subpixelFactor(16), disparityPair(false),
        referenceCounter(nullptr) {
    formats[0] = FORMAT_8_BIT_MONO;
    formats[1] = FORMAT_8_BIT_MONO;
    data[0] = nullptr;
    data[1] = nullptr;
    // Both images start without data, so both strides are zero.
    rowStride[0] = 0;
    rowStride[1] = 0;
}

/**
 * Shallow copy: shares the image buffers and Q-matrix of \c other and
 * increments the shared reference counter (if \c other has one).
 */
ImagePair::ImagePair(const ImagePair& other) {
    copyData(*this, other, true);
}

/**
 * Shallow assignment: drops this object's current share of its buffers, then
 * shares the buffers of \c other (incrementing its reference counter, if any).
 * Self-assignment is a no-op.
 */
ImagePair& ImagePair::operator= (ImagePair const& other) {
    if(&other != this) {
        decrementReference();
        copyData(*this, other, true);
    }
    return *this;
}

/**
 * Releases this object's share of the reference-counted buffers.
 */
ImagePair::~ImagePair() {
    decrementReference();
}

/**
 * Copies every member from \c src to \c dest. Image buffers and the Q-matrix
 * are copied by pointer (shared, not duplicated), as is the reference counter.
 *
 * @param dest      Object receiving the values. Its previous buffers are NOT
 *                  released here; callers do that first if needed.
 * @param src       Object to copy from.
 * @param countRef  If true and \c src has a reference counter, the counter is
 *                  incremented to account for the new sharer.
 */
void ImagePair::copyData(ImagePair& dest, const ImagePair& src, bool countRef) {
    dest.width = src.width;
    dest.height = src.height;

    for(int i=0; i<2; i++) {
        dest.rowStride[i] = src.rowStride[i];
        dest.formats[i] = src.formats[i];
        dest.data[i] = src.data[i];
    }

    dest.qMatrix = src.qMatrix;
    dest.timeSec = src.timeSec;
    dest.timeMicrosec = src.timeMicrosec;
    dest.seqNum = src.seqNum;
    dest.minDisparity = src.minDisparity;
    dest.maxDisparity = src.maxDisparity;
    dest.subpixelFactor = src.subpixelFactor;
    dest.disparityPair = src.disparityPair;
    dest.referenceCounter = src.referenceCounter;

    if(dest.referenceCounter != nullptr && countRef) {
        (*dest.referenceCounter)++;
    }
}

/**
 * Decrements the shared reference counter. When it reaches zero, the image
 * buffers, the Q-matrix and the counter itself are deleted and the pointers
 * reset to nullptr. Objects without a counter (referenceCounter == nullptr)
 * never delete their buffers; presumably those buffers are owned elsewhere.
 *
 * NOTE: the counter is a plain int, so sharing one ImagePair's buffers across
 * threads requires external synchronization.
 */
void ImagePair::decrementReference() {
    if(referenceCounter != nullptr && --(*referenceCounter) == 0) {
        delete []data[0];
        delete []data[1];
        delete []qMatrix;
        delete referenceCounter;

        data[0] = nullptr;
        data[1] = nullptr;
        qMatrix = nullptr;
        referenceCounter = nullptr;
    }
}

/**
 * Writes one image of the pair as a binary PGM (P5, mono formats) or PPM
 * (P6, RGB format) file.
 *
 * @param imageNumber  0 or 1, selecting the image to write.
 * @param fileName     Output path; an existing file is overwritten.
 * @throws std::runtime_error on an illegal image number, an unsupported pixel
 *         format, or if the file cannot be opened or written.
 */
void ImagePair::writePgmFile(int imageNumber, const char* fileName) const {
    if(imageNumber < 0 || imageNumber >1) {
        throw std::runtime_error("Illegal image number!");
    }

    // Determine the header parameters before creating the file, so that an
    // unsupported format does not leave an empty file behind.
    int type, maxVal, bytesPerChannel, channels;
    switch(formats[imageNumber]) {
        case FORMAT_8_BIT_MONO:
            type = 5;
            maxVal = 255;
            bytesPerChannel = 1;
            channels = 1;
            break;
        case FORMAT_12_BIT_MONO:
            type = 5;
            maxVal = 4095;
            bytesPerChannel = 2;
            channels = 1;
            break;
        case FORMAT_8_BIT_RGB:
            type = 6;
            maxVal = 255;
            bytesPerChannel = 1;
            channels = 3;
            break;
        default:
            throw std::runtime_error("Illegal pixel format!");
    }

    std::fstream strm(fileName, std::ios::out | std::ios::binary);
    if(!strm) {
        throw std::runtime_error("Unable to open image file for writing!");
    }

    // Write PGM / PPM header ('\n' rather than std::endl: no flush needed)
    strm << "P" << type << " " << width << " " << height << " " << maxVal << "\n";

    // Write image data row by row, honoring the source row stride
    const int samplesPerRow = width*channels;
    for(int y = 0; y < height; y++) {
        const unsigned char* row = &data[imageNumber][y*rowStride[imageNumber]];
        if(bytesPerChannel == 2) {
            for(int x = 0; x < samplesPerRow; x++) {
                // memcpy instead of a pointer cast: the sample may be
                // unaligned, and this avoids a strict-aliasing violation.
                unsigned short sample;
                memcpy(&sample, &row[x*bytesPerChannel], sizeof(sample));
                // Swap endianess: binary PGM stores 16-bit samples MSB first
                unsigned short swapped = htons(sample);
                strm.write(reinterpret_cast<const char*>(&swapped), sizeof(swapped));
            }
        } else {
            // 8-bit samples need no conversion; write the whole row at once
            strm.write(reinterpret_cast<const char*>(row), samplesPerRow);
        }
    }

    if(!strm) {
        throw std::runtime_error("Error while writing image file!");
    }
}

/**
 * Makes \c dest a deep copy of this image pair. \c dest receives its own,
 * tightly packed image buffers (row stride = width * bytes per pixel), its own
 * Q-matrix copy (or nullptr if this object has none), and a fresh reference
 * counter set to 1. Its previous share of buffers is released.
 *
 * All allocations and copies happen before \c dest is modified, so an
 * exception (e.g. std::bad_alloc or an invalid pixel format) leaves \c dest
 * unchanged and leaks nothing. This ordering also makes dest == *this safe.
 */
void ImagePair::copyTo(ImagePair& dest) {
    std::unique_ptr<float[]> newQMatrix;
    if(qMatrix != nullptr) {
        newQMatrix.reset(new float[16]);
        memcpy(newQMatrix.get(), qMatrix, sizeof(float)*16);
    }

    std::unique_ptr<unsigned char[]> newData[2];
    int newRowStride[2];
    for(int i=0; i<2; i++) {
        int bytesPixel = getBytesPerPixel(i);

        newRowStride[i] = width*bytesPixel;
        newData[i].reset(new unsigned char[height*newRowStride[i]]);

        // Convert possibly different row strides
        for(int y = 0; y < height; y++) {
            memcpy(newData[i].get() + y*newRowStride[i], data[i] + y*rowStride[i],
                newRowStride[i]);
        }
    }

    std::unique_ptr<int> newCounter(new int(1));

    // Commit: release dest's old share, copy the metadata, then hand over
    // ownership of the freshly allocated buffers.
    dest.decrementReference();
    copyData(dest, *this, false);

    dest.qMatrix = newQMatrix.release();
    for(int i=0; i<2; i++) {
        dest.rowStride[i] = newRowStride[i];
        dest.data[i] = newData[i].release();
    }
    dest.referenceCounter = newCounter.release();
}

/**
 * Returns the number of bytes used to store one pixel in the given format.
 *
 * @throws std::runtime_error for an unknown format.
 */
int ImagePair::getBytesPerPixel(ImageFormat format) {
    switch(format) {
        case FORMAT_8_BIT_MONO: return 1;
        case FORMAT_8_BIT_RGB: return 3;
        case FORMAT_12_BIT_MONO: return 2;
        default: throw std::runtime_error("Invalid image format!");
    }
}

} // namespace

