import * as tf from "@tensorflow/tfjs";
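// ml = model loader
// getModelScoreBatch(batch) => [{ id, output }], where output = [value, policy]
//   (presumably wrapped by an OutputAdapter downstream)
// batch = [{ id, input }], where input.input is the raw model input
//   (presumably produced by an InputAdapter(state))

// this.nEvalCalls - number of times getModelScoreBatchAsync is called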

/**
 * Batches model-evaluation requests and runs them through a TensorFlow.js
 * model.
 *
 * Each request is an array of `{ id, input }` items, where `input.input` is
 * the raw nested array fed to `tf.tensor4d` (presumably an InputAdapter
 * output — confirm against callers). Results are delivered as
 * `[{ id, output: [value, policy] }]`, where `value` and `policy` are the
 * item's rows of the model's first and second outputs.
 *
 * `nEvalCalls` counts calls to `getModelScoreBatchAsync`.
 */
export default class TfModelLoader {
  /**
   * @param {object} tfmodel - TF.js model whose `predict` returns an array
   *   with the value tensor at index 0 and the policy tensor at index 1.
   * @param {number} [batchSize] - -1 (or undefined/null) = evaluate each call
   *   immediately; 0 = flush only after `flushDelayMs` without new inputs;
   *   n > 0 = flush as soon as n inputs are queued (partial batches are still
   *   flushed by the inactivity timer).
   * @param {number} [flushDelayMs=20] - inactivity delay, in milliseconds,
   *   before a partial queue is flushed.
   */
  constructor(tfmodel, batchSize, flushDelayMs = 20) {
    this.batchSize = batchSize;
    this.flushDelayMs = flushDelayMs;
    this.tfmodel = tfmodel;
    this.inputs = []; // queued raw inputs, in batch order
    this.cbMap = []; // queued requests: { ids, inputIds, cb }, in arrival order
    this.flushing = false;
    this.flushTimeout = null;
    this.nEvalCalls = 0;
    console.log(
      "constructed tfmodel loader with model",
      tfmodel.name,
      "batchSize:",
      this.batchSize
    );
  }

  /**
   * Evaluates `toEval` and passes the results to `cb`, either immediately or
   * via the batching queue depending on `batchSize`.
   *
   * @param {Array<{id: *, input: {input: Array}}>} toEval
   * @param {function(Array<{id: *, output: Array}>): void} cb
   */
  getModelScoreBatchAsync(toEval, cb) {
    this.nEvalCalls++;
    if (this.batchSize == null || this.batchSize === -1) {
      this.getModelScoreBatchAsyncImmediate(toEval, cb);
      return;
    }
    this.addToQueue(toEval, cb);
  }

  /**
   * Appends a request to the queue. Flushes right away once `batchSize`
   * inputs are queued; otherwise (re)starts the inactivity timer.
   *
   * @param {Array<{id: *, input: {input: Array}}>} toEval
   * @param {function(Array<{id: *, output: Array}>): void} cb
   */
  addToQueue(toEval, cb) {
    if (toEval.length === 0) {
      console.log("ERROR!! toEval is empty!");
      cb([]);
      return;
    }
    const cbItem = { ids: [], inputIds: [], cb };
    for (const item of toEval) {
      cbItem.ids.push(item.id);
      // Remember which row of the batched tensor belongs to this item.
      cbItem.inputIds.push(this.inputs.length);
      this.inputs.push(item.input.input);
    }
    this.cbMap.push(cbItem);

    if (this.batchSize !== 0 && this.inputs.length >= this.batchSize) {
      this.flushQueue();
      return;
    }

    // Debounce: restart the timer so a partial batch is flushed once no new
    // inputs have arrived for flushDelayMs. The queue length is captured now
    // (not when the timer fires) so flushIfNotChanged can detect changes.
    if (this.flushTimeout) clearTimeout(this.flushTimeout);
    const queuedLength = this.inputs.length;
    this.flushTimeout = setTimeout(() => {
      this.flushIfNotChanged(queuedLength);
    }, this.flushDelayMs);
  }

  /**
   * Timer callback: flushes the queue if it still has the length it had when
   * the timer was scheduled. Any add that changes the length either flushes
   * or schedules its own timer, so skipping here cannot stall the queue.
   *
   * @param {number} prevLength - queue length when the timer was scheduled.
   */
  flushIfNotChanged(prevLength) {
    this.flushTimeout = null;
    if (this.inputs.length === prevLength) this.flushQueue();
  }

  /**
   * Runs every queued input through the model in one batch and schedules each
   * request's callback with its slice of the results.
   *
   * If the model throws, the error propagates to the caller and the detached
   * batch's callbacks are not invoked, but the loader stays usable.
   */
  flushQueue() {
    if (this.flushing) {
      console.log("q is already flushing, ignore");
      return;
    }
    if (this.inputs.length === 0) {
      console.log("inputs empty");
      return;
    }

    // A pending inactivity timer is redundant now; cancel it so it does not
    // fire later against an empty or unrelated queue.
    if (this.flushTimeout) {
      clearTimeout(this.flushTimeout);
      this.flushTimeout = null;
    }

    // Detach the pending work before running the model so a failure does not
    // leave the same (possibly bad) batch queued forever.
    const inputs = this.inputs;
    const pending = this.cbMap;
    this.inputs = [];
    this.cbMap = [];

    // NOTE: Although now this function is sync, we still do this for future.
    this.flushing = true;
    let values;
    let policies;
    try {
      ({ values, policies } = this.runModel(inputs));
    } finally {
      // Reset even on failure; otherwise every later flush would be ignored.
      this.flushing = false;
    }

    for (const { ids, inputIds, cb } of pending) {
      const ret = ids.map((id, i) => ({
        id,
        output: [values[inputIds[i]], policies[inputIds[i]]],
      }));
      // NOTE: async call to avoid nested calls to flushQ
      setTimeout(() => cb(ret), 0);
    }
  }

  /**
   * Runs the model on a batch of raw inputs and returns its first two outputs
   * as plain JS arrays. Every tensor created here is disposed before
   * returning, because TF.js does not free tensor memory automatically.
   *
   * @param {Array} inputs - nested arrays accepted by tf.tensor4d.
   * @returns {{values: Array, policies: Array}}
   */
  runModel(inputs) {
    const tensor = tf.tensor4d(inputs);
    let pred;
    try {
      pred = this.tfmodel.predict(tensor);
      return { values: pred[0].arraySync(), policies: pred[1].arraySync() };
    } finally {
      tensor.dispose();
      if (pred) tf.dispose(pred);
    }
  }

  /**
   * Synchronously evaluates `toEval` in a single model call.
   *
   * @param {Array<{id: *, input: {input: Array}}>} toEval
   * @returns {Array<{id: *, output: Array}>} one `{ id, output: [value, policy] }`
   *   per item, in input order; `[]` for empty input.
   */
  getModelScoreBatch(toEval) {
    // tf.tensor4d throws on an empty list when no explicit shape is given.
    if (toEval.length === 0) return [];

    // FIXME: the input rank is hard-coded to 4; ideally derive the shape from
    // the model (feedInputShapes) or from the game.
    const { values, policies } = this.runModel(
      toEval.map((item) => item.input.input)
    );
    return toEval.map((item, i) => ({
      id: item.id,
      output: [values[i], policies[i]],
    }));
  }

  /**
   * Evaluates `toEval` right away and invokes `cb` synchronously with the
   * results.
   *
   * @param {Array<{id: *, input: {input: Array}}>} toEval
   * @param {function(Array<{id: *, output: Array}>): void} cb
   */
  getModelScoreBatchAsyncImmediate(toEval, cb) {
    cb(this.getModelScoreBatch(toEval));
  }
}
