import React, { Component } from "react";
import InfoField from "../InfoField";
import HighDiv from "../modeling/HighDiv";
import ParallelController from "./ParallelController";
import ToggleSwitch from "./../inference/ToggleSwitch";
import * as tf from "@tensorflow/tfjs";
import "@tensorflow/tfjs-backend-cpu";
//import "@tensorflow/tfjs-backend-wasm";
import PlayerModels from "../inference/PlayerModels";

import { default as Game2048Engine } from "./../games/2048/GameEngine";
import { default as Connect4Engine } from "./../games/connect4/GameEngine";

// Layout size of the training panel, in pixels.
const totalWidth = 570;
const totalHeight = 1000;

/**
 * Default inference/training parameters.
 *
 * Each component instance copies this object into its own state and renders
 * every key as an editable text field. Values are parsed with parseFloat
 * before use. The object is frozen so that no instance can accidentally
 * change the defaults shared by all the others.
 */
const inferenceProps = Object.freeze({
  nGames: 100,
  nSim: 50,
  randFrac: 1,
  maxTime: 0, // ms
  extraWinningScore: 0,
  evalBatchSize: -1,
  maxTrainDataSize: 1000,
  storeDataFraction: 1,
  epochs: 100,
  trainingBatchSize: 256,
  learningRate: 0.001,
  optimizer: "adam",
  trainingLoops: 1,
});

/**
 * Current wall-clock time.
 * @returns {number} milliseconds since the Unix epoch
 */
function GetTime() {
  return Date.now();
}

/**
 * Numerically stable softmax.
 *
 * The maximum is subtracted before exponentiating. This does not change the
 * result mathematically, but it prevents Math.exp from overflowing: inputs
 * above ~709 would give Infinity, and Infinity / Infinity is NaN. The
 * normalising sum is computed once instead of once per element.
 *
 * @param {number[]} arr - raw scores
 * @returns {number[]} probabilities of the same length summing to 1
 *   (an empty array for empty input)
 */
function Softmax(arr) {
  if (arr.length === 0) return [];
  const max = arr.reduce((m, v) => Math.max(m, v), -Infinity);
  const exps = arr.map((value) => Math.exp(value - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

/**
 * One-hot encode the argmax of `arr`.
 *
 * Ties resolve to the first maximal index. NaN entries are ignored. An empty
 * input, or an input with no comparable values, yields an all-zero array.
 * An explicit loop is used instead of `Math.max(...arr)`, which can overflow
 * the call stack on very large arrays.
 *
 * @param {number[]} arr - scores to encode
 * @returns {number[]} array of arr.length zeros with a single 1 at the argmax
 */
function Onehot(arr) {
  const ret = new Array(arr.length).fill(0);
  let best = -1;
  for (let i = 0; i < arr.length; i++) {
    if (Number.isNaN(arr[i])) continue;
    if (best === -1 || arr[i] > arr[best]) best = i;
  }
  if (best !== -1) ret[best] = 1;
  return ret;
}

// Open questions / ideas:
// - play silent games, or games with the UI?
// - play a game against a real player?
// - implement parallel play for faster inference?
// - add a time limit per move? per game?
// - store some games for replay in the UI?

export default class TrainingUi extends Component {
  constructor(props) {
    super(props);
    this.gameIndex = 0;
    // NOTE: every time new game is selected object is destroyed
    this.engine =
      this.props.selectedGame == "2048"
        ? new Game2048Engine()
        : new Connect4Engine();
    this.resetTrainingData();
  }

  resetTrainingData() {
    this.trainingData = [];
    for (let i = 0; i < this.engine.nPlayers(); i++) {
      this.trainingData.push([]);
    }
  }

  state = {
    infoText: ["Logging is enabled."],
    nGames: 1,
    inferenceProps: { ...inferenceProps },
    tfBackend: "webgl",
    playerTypes: [],
    selectedModels: [],
    trainModel: undefined,
  };

  componentDidMount() {
    let playerTypes = [];
    let selectedModels = [];
    for (let i = 0; i < this.engine.nPlayers(); i++) {
      playerTypes.push("COMPUTER");
      let models = [];
      for (let j = 0; j < this.engine.nPlayers(); j++) {
        models.push("RANDOM");
      }
      selectedModels.push(models);
    }
    this.setState({
      nGames: this.state.inferenceProps.nGames,
      selectedModels,
      playerTypes,
    });
    this.onBackendToggled(this.state.tfBackend);
  }

  onBackendToggled(val) {
    this.setState({ tfBackend: val });
    tf.setBackend(val);
  }

  stop() {
    console.log("trying to stop the games");
    this.controllerProps.enabled = false;
  }

  playBatch(nGames, finishedCb) {
    nGames = nGames || 1;
    this.startTime = GetTime();
    this.activeGames = nGames;
    this.totalMoves = 0;
    this.playing = true;
    this.setState({ playing: true });
    let sanitizedProps = {};
    Object.keys(this.state.inferenceProps).map(
      (key) =>
        (sanitizedProps[key] = parseFloat(this.state.inferenceProps[key]))
    );
    this.controllerProps = {
      engine: this.engine,
      enabled: true,
      models: this.props.modelingRef.current.loadedModels,
      playerTypes: this.state.playerTypes,
      selectedModels: this.state.selectedModels,
      inferenceProps: sanitizedProps,
    };
    this.batchController = new ParallelController(this.controllerProps);
    this.printPeriodicSummary();
    // for every finished game callback is fired
    this.batchController.playBatch(
      nGames,
      (finalState) => {
        this.activeGames--;
        this.totalMoves += finalState.moveId;
        this.printSummary();
        // FIXME: parse stats and scores;
        if (this.activeGames == 0) {
          this.playing = false;
          this.setState({ playing: false });
          finishedCb();
        }
      },
      (playerId, inputData, simData) => {
        this.onTrainingData(playerId, inputData, simData);
      }
    );
  }

  onTrainingData(playerId, inputData, simData) {
    // NOTE: state is cloned here so we can use it directly
    // simdata is uniquely built so also can be jsut used directly
    // randomly store into storage
    if (
      Math.random() < parseFloat(this.state.inferenceProps.storeDataFraction) &&
      this.trainingData[playerId].length <
        this.state.inferenceProps.maxTrainDataSize
    ) {
      this.trainingData[playerId].push({ inputData, simData });
    }
  }

  checkModel() {
    if (
      !this.state.trainModel ||
      !this.props.modelingRef.current.loadedModels[this.state.trainModel]
    ) {
      this.addToLog("Model is unavailable.");
      console.log("model unavailable", this.state.trainModel);
      return false;
    }

    // FIXME: need to add lock per model??
    if (this.props.modelingRef.current.state.locked) {
      this.addToLog("Model is locked. Unlock it in the modeling tab first.");
      console.log("Model is locked");
      return false;
    }

    return true;
  }

  train() {
    console.log("train called");
    if (!this.checkModel) return;
    this.trainAsync(() => {});
  }

  loop() {
    if (!this.checkModel) return;
    console.log("loop starting");
    this.trainLoopRec(0, parseFloat(this.state.inferenceProps.trainingLoops));
  }

  trainLoopRec(current, total) {
    if (current >= total) {
      this.addToLog("Finished training.");
      return;
    }
    let loopCb = () => this.trainLoopRec(current + 1, total);
    this.addToLog("   ------- LOOP " + current + " -------");
    this.resetTrainingData();
    let playFinishedCb = () => this.trainAsync(loopCb);
    // NOTE: calling async to avoid nesting
    setTimeout(() => this.playBatch(this.state.nGames, playFinishedCb), 5);
  }

  trainAsync(cb) {
    let model = this.props.modelingRef.current.loadedModels[
      this.state.trainModel
    ];
    let inputs = [];
    let values = [];
    let policies = [];

    // iterate through the data and filter items that belong to the currently selected model
    for (let player = 0; player < this.state.playerTypes.length; player++) {
      if (this.state.selectedModels[player][player] != this.state.trainModel)
        continue;
      if (this.trainingData[player].length == 0) continue;
      for (let record of this.trainingData[player]) {
        inputs.push(record.inputData.input);
        let policy = [];
        for (let i = 0; i < this.engine.nMoves(player); i++) {
          policy.push(
            record.simData.childrenN[i] == undefined
              ? 0
              : record.simData.childrenN[i]
          );
        }
        // FIXME: overflow in softmax for large numbers > 1000?
        //policy = Softmax(policy);
        // FIXME: onehot is needed for crossentropy?
        policy = Onehot(policy);
        policies.push(policy);
        values.push(record.simData.qs[player]);
      }
    }
    //console.log(JSON.stringify(inputs));
    //console.log(JSON.stringify(values));
    //console.log(JSON.stringify(policies));

    let inputTensor = tf.tensor4d(inputs);
    let valueTensor = tf.tensor1d(values);
    let policyTensor = tf.tensor2d(policies);

    // FIXME: check if model has already been compiled
    model.compile({
      optimizer:
        this.state.inferenceProps.optimizer == "adam"
          ? tf.train.adam(parseFloat(this.state.inferenceProps.learningRate))
          : tf.train.sgd(parseFloat(this.state.inferenceProps.learningRate)),
      loss: [
        "meanSquaredError" /* for value */,
        "categoricalCrossentropy" /* for policy*/, // NOTE: crossEntropy requires 1-hot labels
      ],
      metrics: ["accuracy"],
    });
    model
      .fit(inputTensor, [valueTensor, policyTensor], {
        epochs: parseFloat(this.state.inferenceProps.epochs),
        batchSize: parseFloat(this.state.inferenceProps.trainingBatchSize),
        shuffle: true,
        //yieldEvery: "auto", // useful to free main thread
        //validationSplit: 0.1,
        //classWeight: [],
        //sampleWeight: same as input to increase weight of certain inputs (perhaps losing positions)
        callbacks: {
          onEpochEnd: (epoch, logs) => {
            this.addToLog("Loss after epoch " + epoch + " is " + logs.loss);
            console.log("Loss after epoch", epoch, "is", logs.loss);
          },
        },
      })
      .then((info) => {
        console.log("Final loss", info.history.loss);
        cb();
      });
    // will reset training data on PLAY button press
    //this.resetTrainingData();
  }

  addToLog(item) {
    //console.log(item);
    this.state.infoText.unshift(item);
    if (this.state.infoText.length > 200) {
      this.state.infoText.splice(100);
    }
    this.setState({ infoText: this.state.infoText });
  }

  printSummary() {
    let evalCalls = 0;
    for (let playerMls of this.batchController.mls) {
      for (let ml of playerMls) {
        evalCalls += ml.nEvalCalls;
      }
    }
    this.addToLog(
      "Active games: " +
        this.activeGames +
        " Moves: " +
        this.totalMoves +
        " Evals: " +
        evalCalls +
        " Trains: " +
        this.trainingData.reduce((a, c) => a + c.length, 0) +
        " Time: " +
        (GetTime() - this.startTime)
    );
  }

  printPeriodicSummary() {
    if (!this.playing) return;
    this.printSummary();
    setTimeout(() => {
      this.printPeriodicSummary();
    }, 5000);
  }

  onClick(id) {
    this.state.playerTypes[id] =
      this.state.playerTypes[id] == "HUMAN"
        ? "MONKEY"
        : this.state.playerTypes[id] == "MONKEY"
        ? "COMPUTER"
        : this.state.playerTypes[id] == "COMPUTER"
        ? "REMOTE"
        : "HUMAN";
    this.setState({
      playerTypes: this.state.playerTypes,
    });
  }

  render() {
    let selectedModelKeys = {};
    this.state.selectedModels.map((models) => {
      models.map((name) => {
        if (name != "RANDOM") selectedModelKeys[name] = true;
      });
    });
    return (
      <div
        style={{
          position: "absolute",
          width: totalWidth,
          height: "100%",
          left: "50%",
          top: 0,
          //top: "50%",
          transform: "translateX(-50%)",
        }}
      >
        <div
          style={{
            marginTop: 20,
          }}
        >
          {"CURRENT GAME: " + (this.props.selectedGame || "")}
        </div>
        <div
          style={{
            bottom: 45,
            position: "absolute",
            width: "100%",
            left: 0,
            height: 600,
          }}
        >
          <InfoField
            text={
              <div>
                {this.state.infoText.map((text, id) => {
                  return <div key={id}>{text}</div>;
                })}
              </div>
            }
            onClear={() => {
              this.setState({ infoText: ["Logging is enabled."] });
            }}
          />
        </div>
        <div
          style={{
            position: "relative",
            width: 200,
            height: 50,
            display: "flex",
            justifyContent: "space-between",
            alignItems: "center",
            //border: "1px solid black",
            //display: "flex",
            //justifyContent: "center",
            //alignItems: "center",
          }}
        >
          {"TF BACKEND:"}
          <ToggleSwitch
            values={["webgl", "cpu" /*wasm*/]}
            selected={this.state.tfBackend}
            onSwitchValue={(val) => {
              this.onBackendToggled(val);
            }}
          />
        </div>
        <PlayerModels
          onUpdate={() => {
            this.setState({ selectedModels: this.state.selectedModels });
            this.state.selectedModels.map((models) => {
              models.map((name) => {
                if (name != "RANDOM") {
                  this.setState({ trainModel: name });
                  return;
                }
              });
            });
          }}
          playerTypes={this.state.playerTypes}
          modelList={this.props.modelList}
          selectedModels={this.state.selectedModels}
        />
        <div>
          <HighDiv
            text={this.state.playing ? "STOP" : "PLAY"}
            style={{
              marginTop: 10,
              width: 200,
              height: 50,
              backgroundColor: "rgb(200, 200, 200)",
              display: "flex",
              justifyContent: "center",
              alignItems: "center",
              cursor: "grab",
              marginBottom: 10,
            }}
            clickable={true}
            onClick={() => {
              if (!this.playing) {
                this.resetTrainingData();
                this.playBatch(this.state.nGames, () => {});
              } else {
                this.stop();
              }
            }}
          />
          <HighDiv
            text="LOOP"
            style={{
              //marginTop: 20,
              position: "relative",
              left: 210,
              width: 200,
              height: 50,
              backgroundColor: "rgb(200, 200, 200)",
              display: "flex",
              justifyContent: "center",
              alignItems: "center",
              cursor: "grab",
              marginBottom: 20,
            }}
            clickable={true}
            onClick={() => this.loop()}
          />
          <div
            style={{
              position: "absolute",
              top: 158,
              left: 0,
              width: 200,
              height: 50,
              marginBottom: 20,
            }}
          >
            <HighDiv
              text="TRAIN"
              style={{
                //marginTop: 20,
                position: "absolute",
                top: 0,
                left: 0,
                width: "calc(100% - 100px)",
                height: "100%",
                backgroundColor: "rgb(200, 200, 200)",
                display: "flex",
                justifyContent: "center",
                alignItems: "center",
                cursor: "grab",
                paddingRight: 100,
              }}
              clickable={true}
              onClick={() => this.train()}
            />
            <div
              style={{
                position: "absolute",
                right: 0,
                top: "25%",
                width: 110,
                height: "50%",
                display: "flex",
                justifyContent: "center",
                alignItems: "center",
                //backgroundColor: "yellow",
              }}
            >
              {Object.keys(selectedModelKeys).length == 0 ? (
                <div
                  style={{
                    position: "absolute",
                    outline: "2px solid rgb(158, 158, 158)",
                    right: 7,
                    height: 35,
                    width: 100,
                    display: "flex",
                    justifyContent: "center",
                    alignItems: "center",
                    backgroundColor: "rgb(255, 255, 255)",
                  }}
                >
                  {"No models"}
                </div>
              ) : (
                <select
                  style={{
                    width: 100,
                    height: 35,
                    border: undefined,
                    outline: "1px solid rgb(158, 158, 158)",
                  }}
                  onChange={(e) => {
                    this.setState({ trainModel: e.target.value });
                  }}
                  value={this.state.trainModel}
                >
                  {Object.keys(selectedModelKeys).map((option) => {
                    return <option key={"option_" + option}>{option}</option>;
                  })}
                </select>
              )}
            </div>
          </div>
          {Object.keys(this.state.inferenceProps).map((key) => {
            //let prop = updatableProps[key];
            return (
              <div
                key={key}
                style={{
                  width: "100%",
                  display: "flex",
                  justifyContent: "space-between",
                }}
              >
                {key + ": " + this.state.inferenceProps[key]}
                <input
                  type="text"
                  name="name"
                  value={this.state.inferenceProps[key]}
                  onChange={(event) => {
                    if (key == "nGames") {
                      this.setState({ nGames: parseInt(event.target.value) });
                    }
                    this.state.inferenceProps[key] = event.target.value;
                    this.setState({
                      inferenceProps: this.state.inferenceProps,
                    });
                  }}
                  style={{ width: 50, marginRight: 5 }}
                />
              </div>
            );
          })}
        </div>
      </div>
    );
  }
}
