// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "softmax_inst.h"
#include "primitive_type_base.h"
#include "json_object.h"
#include <cassert>
#include <sstream>
#include <string>

namespace cldnn {
// Returns the unique type descriptor for softmax primitives. The descriptor is a
// function-local static, so it is constructed on first use and every call returns
// the same address.
primitive_type_id softmax::type_id() {
    static primitive_type_base<softmax> softmax_type;
    primitive_type_id id = &softmax_type;
    return id;
}

// Computes the output layout of a softmax node.
// The output layout is a copy of the input's output layout. When the node has
// fused primitives, the data type is taken from the fused output layout instead.
// Forcing an explicit output data type on the primitive is not supported; this
// is checked only via assert (i.e. in debug builds).
layout softmax_inst::calc_output_layout(softmax_node const& node) {
    assert(!static_cast<bool>(node.get_primitive()->output_data_type) &&
           "Output data type forcing is not supported for softmax_node!");

    layout result = node.input().get_output_layout();

    if (node.has_fused_primitives()) {
        // Fused post-ops determine the final element type written by this node.
        result.data_type = node.get_fused_output_layout().data_type;
    }

    return result;
}

// Produces a textual description of a softmax node.
// The node's descriptor is converted to JSON via desc_to_json() and dumped into a
// string stream, whose contents are returned.
std::string softmax_inst::to_string(softmax_node const& node) {
    auto node_info = node.desc_to_json();

    std::stringstream primitive_description;
    node_info->dump(primitive_description);

    return primitive_description.str();
}

// Constructs the runtime instance for a softmax node.
// Performs no softmax-specific validation; all construction work is delegated to
// the parent class constructor.
softmax_inst::typed_primitive_inst(network& network, softmax_node const& node) : parent(network, node) {}
}  // namespace cldnn
