import React, { Component } from "react";
import Consts from "./Consts";
import Styles from "./Styles";
import HighDiv from "./HighDiv";
import * as tf from "@tensorflow/tfjs";
import Utils from "./Utils";

// Toolbar layout (pixels): one row of `nButtons` equally wide buttons, each
// separated by `marginSide`. The available width is derived from the board
// width in cells plus 8 extra cells (presumably to line up with the board
// area rendered elsewhere — TODO confirm against the board layout).
const marginSide = 3;
const nButtons = 8;
const buttonWidth =
  (Consts.cellSize * (Consts.boardW + 8) + marginSide * 3) / nButtons -
  marginSide;
const buttonHeight = 30;

// Keras layer `class_name` (as found in a tfjs model.json topology) -> the
// short node-type code used by this editor. A Map is used so lookups are
// strict and cannot accidentally hit Object.prototype members.
const modelClassToType = new Map([
  ["Conv2D", "co"],
  ["Dense", "de"],
  ["LeakyReLU", "ac"], // activation
  ["Add", "de"], // FIXME
  ["Flatten", "de"], // FIXME
  ["BatchNormalization", "ac"], // also take batch normalization
]);

/**
 * Translates a Keras layer class name into this editor's node-type code.
 *
 * @param {string} className - `class_name` of a layer from the model topology.
 * @returns {string|undefined} "co" (conv), "de" (dense) or "ac"
 *   (activation-like, folded into its parent by the importer); `undefined`
 *   (after logging an error) for classes the editor does not know.
 */
function ParseTypeFromModelClass(className) {
  const type = modelClassToType.get(className);
  if (type === undefined) {
    console.log("ERROR!!!! UNKNOWN CLASS", className);
  }
  return type;
}

/**
 * Chains `count` freshly created layers: the first is applied to
 * `parentNodes`, each following one to the previous layer's output.
 *
 * @param {Function} makeLayer - creates a tfjs layer from a config object.
 * @param {Object} layerParams - layer config (without the block's `layers` count).
 * @param {number} count - number of layers in the block.
 * @param {Array} parentNodes - symbolic inputs of the first layer.
 * @returns {Array} the output of every layer, in order.
 */
function applyLayerStack(makeLayer, layerParams, count, parentNodes) {
  const tfNodes = [];
  for (let i = 0; i < count; i++) {
    const layerInput = i === 0 ? parentNodes : tfNodes[tfNodes.length - 1];
    // A fresh config copy per layer, so layers never share one args object.
    tfNodes.push(makeLayer({ ...layerParams }).apply(layerInput));
  }
  return tfNodes;
}

/**
 * Converts one editor graph node into tfjs symbolic tensor(s).
 *
 * @param {Array} parentNodes - tfjs outputs of the node's parents (ignored for inputs).
 * @param {{type: string, params: Object}} node - editor node; `params` is
 *   forwarded as the tfjs layer config.
 * @returns {Array|null|undefined} an array of created outputs (one per layer;
 *   block nodes "db"/"cb" return one entry per stacked layer), `null` for
 *   output nodes ("ou", which are not real tfjs layers), or `undefined`
 *   (after logging) for an unknown type.
 */
function GetTfNodesFromNode(parentNodes, node) {
  const type = node.type;
  if (type === "in") return [tf.layers.input({ ...node.params })];
  if (type === "ou") return null; // for outputs we don't care since its not a real layer

  if (type === "co") {
    // FIXME: need to determine if to have here 2d or 3d
    return [tf.layers.conv2d({ ...node.params }).apply(parentNodes)];
  }
  if (type === "de") {
    return [tf.layers.dense({ ...node.params }).apply(parentNodes)];
  }
  if (type === "db" || type === "cb") {
    // For blocks, "layers" is the repeat count of the block, not a tfjs layer
    // option, so it is removed from the config instead of being forwarded.
    const { layers: count, ...layerParams } = node.params;
    const makeLayer =
      type === "db"
        ? (args) => tf.layers.dense(args)
        : (args) => tf.layers.conv2d(args);
    return applyLayerStack(makeLayer, layerParams, count, parentNodes);
  }
  console.log("ERROR!!!! UNKNOWN TYPE", type);
  return undefined;
}

export default class ModelMenu extends Component {
  constructor(props) {
    super(props);
    this.importRef = React.createRef();
    this.loadRef = React.createRef();
  }

  import() {
    if (this.props.locked) {
      this.props.displayLog(
        "Model is locked. To make changes unlock the model first."
      );
      console.log("model is locked");
      return;
    }
    // this will import tensorflowjs model
    // FIXME: model should be locked when loaded!!
    console.log("load called");
    this.importRef.current.click();
  }

  importModel() {
    if (this.props.locked) {
      this.props.displayLog(
        "Model is locked. To make changes unlock the model first."
      );
      console.log("model is locked");
      return;
    }
    let json = undefined;
    let bins = [];
    for (let file of this.importRef.current.files) {
      if (file.name.endsWith(".json")) json = file;
      else if (file.name.endsWith(".bin")) bins.push(file);
    }

    if (!json) {
      this.props.displayLog("Error. Please select the tfjs model.json file.");
      return;
    }

    if (bins.length == 0) {
      // FIXME: this message gets lost, we need to persist it longer.
      this.props.displayLog(
        "Warning, importing tfjs model without weights. Please select bin files to load weights."
      );
    }

    console.log(json.name, bins.length > 0 ? bins[0].name : "");
    let fileReader = new FileReader();
    fileReader.onload = (fileLoadedEvent) => {
      let res = fileLoadedEvent.target.result;
      let jsonData = JSON.parse(res);
      this.parseJsonModel(jsonData);
    };

    fileReader.readAsText(json, "UTF-8");

    tf.loadLayersModel(tf.io.browserFiles([json, ...bins])).then((res) => {
      //tf.loadGraphModel(tf.io.browserFiles([json, ...bins])).then((res) => {
      console.log(res);
      this.props.onTfModel(res);
    });
  }

  parseJsonModel(jsonModel) {
    //console.log(jsonModel);
    /*
      modelTopology -> model_config -> config -> [input_layers: [], output_layers: [], layers: []]
    */
    let model = { nodes: {}, edges: [] };
    let config = jsonModel.modelTopology.model_config.config;
    let inputs = {};
    let outputs = {};
    let nameMap = {};
    for (let input of config.input_layers) inputs[input[0]] = true;
    for (let output of config.output_layers) outputs[output[0]] = true;
    for (let layer of config.layers) {
      let type = undefined;
      let isInput = !!inputs[layer.name];
      if (isInput) {
        type = "in";
      }
      if (!type) type = ParseTypeFromModelClass(layer.class_name);
      if (!type) {
        console.log("warning: unknown class", layer.class_name);
        continue;
      }

      if (type == "ac") {
        // activation type can have at most one parent, and we do not explicitely create
        // ac layer, but rather add "ac type" into the properties of the parent layer
        if (
          !layer.inbound_nodes ||
          layer.inbound_nodes.length != 1 ||
          layer.inbound_nodes[0].length != 1 ||
          layer.inbound_nodes[0][0].length == 0
        ) {
          console.log("warning, strange activation node inputs", layer);
        } else {
          // FIXME: what if nodes are ordered wrong?
          nameMap[layer.name] = nameMap[layer.inbound_nodes[0][0][0]];
        }
        continue;
      }

      nameMap[layer.name] = layer.name;

      // TODO: read input_shape
      // FIXME: parse node params!!!!!
      // TODO: merge layers into blocks!!!
      // TODO: when merging make sure that the ndoe does not have any fanouts???
      model.nodes[layer.name] = type;
      if (!isInput) {
        if (!layer.inbound_nodes) {
          console.log("warning, strange node inputs", layer);
        } else {
          for (let inNode of layer.inbound_nodes[0]) {
            //console.log(inNode);
            model.edges.push([nameMap[inNode[0]], layer.name]);
          }
        }
      }
      // NOTE: since output nodes are not explicitely present in the graph
      // we have to add them
      if (outputs[layer.name]) {
        let newName = "output_" + layer.name;
        model.nodes[newName] = "ou";
        model.edges.push([nameMap[layer.name], newName]);
      }
    }

    // there is no concept of conv/dense blocks in keras, but in our visual application
    // blocks make so much more sense.

    this.props.loadModel(model);
  }

  save() {
    // NOTE: there is no need to check for lock here
    // FIXME: maybe also save the positions in ui?
    let model = {
      edges: Object.keys(this.props.graph.edges).map((key) => {
        return [
          this.props.graph.edges[key].start,
          this.props.graph.edges[key].end,
        ];
      }),
    };
    let nodes = {};
    Object.keys(this.props.graph.nodes).map((key) => {
      nodes[key] = this.props.graph.nodes[key].type;
    });
    model.nodes = nodes;
    let params = {};
    Object.keys(this.props.graph.nodes).map((key) => {
      params[key] = this.props.graph.nodes[key].params;
    });
    model.params = params;
    model.ui = { text: this.props.name };
    let modelJson = JSON.stringify(model);
    let fileName = this.props.name + "_template.json";
    // downloading
    let a = document.createElement("a");
    let file = new Blob([modelJson], { type: "text/plain" });
    a.href = URL.createObjectURL(file);
    a.download = fileName;
    a.click();
    this.props.displayLog(
      "Model template is saved to the file in downloads: " + fileName + "."
    );
  }

  load() {
    // this will load the model in our format
    if (this.props.locked) {
      this.props.displayLog(
        "Model is locked. To make changes unlock the model first."
      );
      console.log("model is locked");
      return;
    }
    // this will import template
    console.log("load called");
    this.loadRef.current.click();
  }

  loadModel() {
    let json = undefined;
    console.log(this.loadRef.current);
    for (let file of this.loadRef.current.files) {
      if (file.name.endsWith(".json")) json = file;
    }
    console.log(json.name);
    let fileReader = new FileReader();
    fileReader.onload = (fileLoadedEvent) => {
      let res = fileLoadedEvent.target.result;
      let jsonData = JSON.parse(res);
      this.parseJsonTemplate(jsonData);
    };
    fileReader.readAsText(json, "UTF-8");
  }

  parseJsonTemplate(jsonTemplate) {
    console.log(jsonTemplate);
    this.props.loadModel(jsonTemplate);
  }

  fitGame() {
    if (this.props.locked) {
      this.props.displayLog(
        "Model is locked. To make changes unlock the model first."
      );
      console.log("model is locked");
      return;
    }
    // FIXME: need to check for lock here!
    // this function will set the dimentions of the inputs and the outputs to
    // fit the current game.
    // model has to be rebuilt afterwards
    // FIXME: for now only support a single input and 2 outputs
    if (
      !this.props.graph.inputs ||
      Object.keys(this.props.graph.inputs).length != 1 ||
      !this.props.graph.outputs ||
      Object.keys(this.props.graph.outputs).length != 2
    ) {
      this.props.displayLog(
        "Model should contain 1 input and 2 outputs. Fitting aborted."
      );
      console.log("bad number of inputs and outputs");
      return;
    }
    let input = Object.keys(this.props.graph.inputs)[0];
    let output1 = Object.keys(this.props.graph.outputs)[0];
    let output2 = Object.keys(this.props.graph.outputs)[1];
    //console.log(this.props.graph.inputs, this.props.graph.outputs);
    this.props.graph.onNodeUpdate(input, { shape: this.props.gameInputDims });
    this.props.graph.onNodeUpdate(output1, { units: 1 });
    // FIXME: we are preparing a model for the currently active player
    this.props.graph.onNodeUpdate(output2, { units: this.props.gameNMoves });
    this.props.displayLog(
      "Model has been fit to the current game/player pair."
    );
  }

  build() {
    if (this.props.locked) {
      this.props.displayLog(
        "Model is locked. To make changes unlock the model first."
      );
      console.log("model is locked");
      return;
    }
    // FIXME: maybe there is a better place for this?
    let tfNodes = [];
    let inputs = [];
    let outputs = [];
    let mapp = {};

    this.props.graph.traverseFromInputs((id, node) => {
      //console.log(id, JSON.stringify(node));
      let parents =
        node.type == "in"
          ? []
          : Object.keys(node.in).map((edge) => {
              return Utils.getStartEndFromEdgeHash(edge)[0];
            });
      let parentNodes = parents.map((id) => {
        return tfNodes[mapp[id]];
      });

      let converted = GetTfNodesFromNode(parentNodes, node);
      if (node.type == "in") {
        inputs.push(converted[0]);
        tfNodes.push(converted[0]);
        mapp[id] = tfNodes.length - 1;
      } else {
        if (node.type != "ou") {
          for (let el of converted) tfNodes.push(el);
          mapp[id] = tfNodes.length - 1;
        } else {
          // NOTE: for outputs there is no explicit layer in tfjs,
          // so we apply its parent as an output
          // FIXME: outputs from the same node will be reduced!!!!!
          // need to figure out a way to keep them if we want...
          // Although why would we want....
          // NOTE: creating a flatten + dense layer to map prev node to the output shape
          // FIXME: for now only supporting 1-dim outputs
          // FIXME: do not add flatten layer if input is already 1dim!!!
          if (parents.length != 1) {
            // FIXME: add automatic concat layer when needed!
            console.log(
              "ERROR!! number of parents coming in the output node is not 1"
            );
          }
          if (
            // FIXME: inShape is not being populated yet
            node.info.inShape.length == 1 &&
            this.props.graph.nodes[parents[0]].info.inShape[0] ==
              node.params.units
          ) {
            // If node is already in an appropriate format we don't need any layers
            outputs.push(tfNodes[mapp[parents[0]]]);
          } else {
            let tfNodeFlatten = tfNodes[mapp[parents[0]]];
            if (node.info.inShape.length != 1) {
              tfNodeFlatten = tf.layers.flatten().apply(tfNodeFlatten);
            }
            //console.log(node);
            // FIXME: activation here?
            let tfNodeDense = tf.layers
              .dense({
                units: node.params.units,
                activation: node.params.activation,
              })
              .apply(tfNodeFlatten);
            outputs.push(tfNodeDense);
          }
        }
      }
    });

    let tfmodel = tf.model({ inputs, outputs });
    //  console.log(tfmodel);
    // model.add(tf.layers
    //iterate through this.props.graph.nodes; and build
    // build model from currently loaded model
    // NOTE: importModel should automatically build its own model
    tfmodel.name = this.props.name;
    this.props.onTfModel(tfmodel);
  }

  export() {
    // this will export model into tensorflow js format
    // NOTE: there is no need to check for lock here
    this.props.exportActiveTfModel();
  }

  lock(locked) {
    this.props.setLocked(locked);
  }

  render() {
    return (
      <div style={Styles.dummyStyle}>
        <div
          style={{
            position: "absolute",
            top: 0,
            left: "50%",
            width: buttonWidth * nButtons + (nButtons + 1) * marginSide,
            transform: "translate(-50%)",
            //backgroundColor: "rgb(125, 125, 125)",
          }}
        >
          <input
            multiple
            ref={this.importRef}
            type="file"
            style={{ display: "none" }}
            onChange={() => {
              this.importModel();
            }}
          />
          <input
            ref={this.loadRef}
            type="file"
            style={{ display: "none" }}
            onChange={() => {
              this.loadModel();
            }}
          />
          <div
            style={{
              marginTop: marginSide,
              marginBottom: marginSide,
              marginLeft: marginSide,
              marginRight: marginSide,
            }}
          >
            <div
              style={{
                position: "relative",
                width: "100%",
                display: "flex",
                justifyContent: "space-between",
                //backgroundColor: "rgb(125, 125, 125)",
              }}
            >
              <HighDiv
                text="IMPORT"
                style={{
                  width: buttonWidth,
                  height: buttonHeight,
                  backgroundColor: "rgb(200, 200, 200)",
                  display: "flex",
                  justifyContent: "center",
                  alignItems: "center",
                  cursor: "grab",
                }}
                clickable={true}
                onClick={() => this.import()}
              />
              <HighDiv
                text="EXPORT"
                style={{
                  width: buttonWidth,
                  height: buttonHeight,
                  backgroundColor: "rgb(200, 200, 200)",
                  display: "flex",
                  justifyContent: "center",
                  alignItems: "center",
                  cursor: "grab",
                }}
                clickable={true}
                onClick={() => this.export()}
              />
              <HighDiv
                text="SAVE"
                style={{
                  width: buttonWidth,
                  height: buttonHeight,
                  backgroundColor: "rgb(200, 200, 200)",
                  display: "flex",
                  justifyContent: "center",
                  alignItems: "center",
                  cursor: "grab",
                }}
                clickable={true}
                onClick={() => this.save()}
              />
              <HighDiv
                text="LOAD"
                style={{
                  width: buttonWidth,
                  height: buttonHeight,
                  backgroundColor: "rgb(200, 200, 200)",
                  display: "flex",
                  justifyContent: "center",
                  alignItems: "center",
                  cursor: "grab",
                }}
                clickable={true}
                onClick={() => this.load()}
              />
              <HighDiv
                text="BUILD"
                style={{
                  width: buttonWidth,
                  height: buttonHeight,
                  backgroundColor: "rgb(200, 200, 200)",
                  display: "flex",
                  justifyContent: "center",
                  alignItems: "center",
                  cursor: "grab",
                }}
                clickable={true}
                onClick={() => this.build()}
              />
              <HighDiv
                text="FIT"
                style={{
                  width: buttonWidth,
                  height: buttonHeight,
                  backgroundColor: "rgb(200, 200, 200)",
                  display: "flex",
                  justifyContent: "center",
                  alignItems: "center",
                  cursor: "grab",
                }}
                clickable={true}
                onClick={() => this.fitGame()}
              />
              <HighDiv
                text={this.props.locked ? "UNLOCK" : "LOCK"}
                style={{
                  width: buttonWidth,
                  height: buttonHeight,
                  backgroundColor: "rgb(200, 200, 200)",
                  display: "flex",
                  justifyContent: "center",
                  alignItems: "center",
                  cursor: "grab",
                }}
                clickable={true}
                onClick={() => this.lock(!this.props.locked)}
              />
              <HighDiv
                text={"PLACE"}
                style={{
                  width: buttonWidth,
                  height: buttonHeight,
                  backgroundColor: "rgb(200, 200, 200)",
                  display: "flex",
                  justifyContent: "center",
                  alignItems: "center",
                  cursor: "grab",
                }}
                clickable={true}
                onClick={() => {
                  if (this.props.locked) {
                    this.props.displayLog(
                      "Model is locked. To make changes unlock the model first."
                    );
                    return;
                  }
                  this.props.autoPlace();
                }}
              />
            </div>
          </div>
        </div>
      </div>
    );
  }
}
