import { Tensor } from "@tensorflow/tfjs";
import * as tf from "@tensorflow/tfjs";
import MnistData from "../../modules/mnist_data";
import p5 from "p5";
import { PIXEL_COUNT, OUTPUT_UNITS, LABEL_ARRAY } from "../../globals";
import { N_CLASSES } from "../../modules/mnist_data";

import { TRAINING_DATA_AMOUNT, TEST_DATA_AMOUNT } from "./constants";

import { grid } from "../../sketches/inputSketch";
import { changePredictedChar } from "../../userInteraction/interaction";
import { simpleDrawMiniWeightSketches } from "./weights";

/** Shared model; simpleAiSketch's setup adds its layer and compiles it. */
export let model: tf.Sequential = tf.sequential();
/** MNIST loader; training() and testing() read tensors from it after load(). */
let mnistData = new MnistData();
// Training inputs and targets, assigned by training() and passed to model.fit.
let x: tf.Tensor;
let y: tf.Tensor;

/** Index of the highest model output for the current grid; written by manualTesting() and read in the sketch's draw loop. */
export let predictionOutputIndex: number;
// Set to true once clearGPURam() has requested a page reload, so the reload is requested only once.
let reloaded = false;

/**
 * p5 instance-mode sketch that owns the lifecycle of the shared `model`.
 *
 * `setup` adds a single dense sigmoid layer (784 inputs -> OUTPUT_UNITS
 * units, no bias), compiles the model with SGD (learning rate 5) and mean
 * squared error, then loads the MNIST subset and starts `training()` once
 * loading resolves. Load/training-start failures are logged.
 *
 * `draw` classifies the current drawing grid via manualTesting(), passes the
 * predicted label to changePredictedChar(), and calls clearGPURam().
 *
 * @param p - the p5 instance supplied by the p5 constructor.
 */
export const simpleAiSketch = (p: p5) => {
  p.disableFriendlyErrors = true;

  p.setup = () => {
    // Single dense layer: every input value is connected to every output unit.
    model.add(
      tf.layers.dense({
        units: OUTPUT_UNITS,
        inputShape: [784],
        activation: "sigmoid",
        useBias: false,
      })
    );

    model.summary();

    model.compile({
      optimizer: tf.train.sgd(5),
      loss: "meanSquaredError",
    });

    mnistData
      .load(TRAINING_DATA_AMOUNT, TEST_DATA_AMOUNT)
      .then(() => training())
      // Surface a failed data load (or an exception thrown synchronously while
      // starting training) instead of leaving an unhandled promise rejection.
      .catch((error: unknown) => {
        console.error("could not load MNIST data or start training", error);
      });
  };

  p.draw = () => {
    manualTesting();
    // console.table(grid);
    //@ts-ignore
    changePredictedChar(LABEL_ARRAY[predictionOutputIndex]);
    clearGPURam();
  };
};

/**
 * Fits `model` on the loaded MNIST training split for one epoch (batch size 1).
 *
 * The images from `mnistData.getTrainData()` are flattened to
 * `[TRAINING_DATA_AMOUNT, PIXEL_COUNT]`. The labels are right-padded with
 * `OUTPUT_UNITS - N_CLASSES` zero columns so their width matches the model's
 * output layer. The results are stored in the module-level `x` / `y`. When
 * fitting resolves, the mini weight sketches are drawn. Failures are logged.
 *
 * @throws Error if OUTPUT_UNITS is smaller than N_CLASSES (labels would not fit).
 */
function training() {
  const paddingWidth = OUTPUT_UNITS - N_CLASSES;
  if (paddingWidth < 0) {
    throw new Error(
      `OUTPUT_UNITS (${OUTPUT_UNITS}) must not be smaller than N_CLASSES (${N_CLASSES})`
    );
  }

  const [images, labels] = mnistData.getTrainData();
  x = images.reshape([TRAINING_DATA_AMOUNT, PIXEL_COUNT]);

  const axis = 1;
  // tidy() releases the temporary zero block; only the concatenated result
  // escapes. No padding is needed when the widths already match.
  // NOTE(review): `labels` itself is not disposed because it is unclear whether
  // MnistData keeps its own reference to it — confirm before freeing it.
  y =
    paddingWidth === 0
      ? labels
      : tf.tidy(() =>
          tf.concat(
            [labels, tf.zeros([TRAINING_DATA_AMOUNT, paddingWidth])],
            axis
          )
        );

  model
    .fit(x, y, { batchSize: 1, epochs: 1 })
    .then(() => {
      console.log("traing done");
      simpleDrawMiniWeightSketches();
    })
    .catch((error: unknown) => {
      console.error("training or drawing the weight sketches failed", error);
    });
}

/**
 * Debug helper: runs the model on a few MNIST test samples and prints the
 * flattened inputs, the predictions, and the expected labels.
 *
 * Everything runs inside tf.tidy, so every tensor created here is released
 * once printing is done. print() writes synchronously, so the data is
 * already consumed before disposal.
 *
 * @param countOfTestSamples - number of test samples to fetch and print (default 10).
 */
function testing(countOfTestSamples = 10) {
  tf.tidy(() => {
    const [rawInput, expectedOutput] =
      mnistData.getTestData(countOfTestSamples);
    const input = rawInput.reshape([countOfTestSamples, PIXEL_COUNT]);
    input.print(true);

    const output = model.predict(input) as Tensor;
    output.print(true);

    expectedOutput.print(true);
  });
}

/**
 * Trains the model on the current drawing grid, labelled with `userInput`.
 *
 * The grid is rotated and then mirrored (the same transform manualTesting
 * applies before predicting) and flattened to a single [1, 784] sample. The
 * target is a [1, OUTPUT_UNITS] one-hot row whose hot index is the label's
 * position in LABEL_ARRAY. The model is then fitted on that one sample for
 * 10 epochs. Both tensors are printed for debugging and disposed once fit()
 * settles; a fit failure is logged.
 *
 * @param userInput - label typed by the user; must be an entry of LABEL_ARRAY.
 */
export function manualTraining(userInput: string) {
  const userInputLabelIndex = LABEL_ARRAY.indexOf(userInput);

  // indexOf reports a missing label as -1, and 0 is a valid index (the first
  // label), so the check must compare against -1 rather than test truthiness.
  if (userInputLabelIndex === -1) {
    console.log("please insert label to textbox to train Ai");
    return;
  }

  const mirroredGrid = mirro2DArray(rotate2DArray(grid));

  const gridTensor = tf.tensor(mirroredGrid);
  gridTensor.print(true);
  const input = gridTensor.reshape([1, 784]);
  // The reshaped tensor shares the data, so the un-reshaped handle can go.
  gridTensor.dispose();

  // ======= BUILD SOLUTION TENSOR =======
  // One-hot row: 1 at the label's index, 0 for every other output unit.
  const oneHot = Array.from({ length: OUTPUT_UNITS }, (_, unit) =>
    unit === userInputLabelIndex ? 1 : 0
  );
  const solutionAsTensor = tf.tensor2d([oneHot]);
  // ======= END BUILD SOLUTION TENSOR
  solutionAsTensor.print(true);

  model
    .fit(input, solutionAsTensor, { batchSize: 1, epochs: 10 })
    .catch((error: unknown) => {
      console.error("manual training failed", error);
    })
    .then(() => {
      // fit() is done with both tensors whether it succeeded or not.
      input.dispose();
      solutionAsTensor.dispose();
    });
}

/**
 * Classifies the current drawing grid and stores the index of the highest
 * model output in `predictionOutputIndex`.
 *
 * The grid goes through the rotate-then-mirror transform and is flattened to
 * a single [1, 784] sample. The argmax is read back synchronously with
 * dataSync(), and tf.tidy releases every intermediate tensor.
 */
function manualTesting() {
  const orientedGrid = mirro2DArray(rotate2DArray(grid));

  predictionOutputIndex = tf.tidy(() => {
    const sample = tf.tensor(orientedGrid).reshape([1, 784]);
    const scores = model.predict(sample) as Tensor;
    const winner = scores.argMax(1);
    return winner.dataSync()[0];
  });
}

/**
 * Rotates a 2-D array 90° clockwise, returning a new array of new rows.
 *
 * Output row j collects element j of every input row, taken from the last
 * input row up to the first; e.g. [[1, 2], [3, 4]] becomes [[3, 1], [4, 2]].
 * Ragged rows are allowed: a row only contributes to the output rows for
 * the positions it has. The input is not modified.
 *
 * @param array - array of rows (each row must support forEach).
 * @returns the rotated rows.
 */
export function rotate2DArray(array: any[]) {
  const columns: any[] = [];

  let rowIndex = array.length;
  while (rowIndex > 0) {
    rowIndex -= 1;

    let columnIndex = 0;
    array[rowIndex].forEach((cell: any) => {
      if (columnIndex >= columns.length) {
        columns.push([]);
      }
      columns[columnIndex].push(cell);
      columnIndex += 1;
    });
  }

  return columns;
}

/**
 * Returns a horizontally mirrored copy of a 2-D array: every row is reversed,
 * so element j of an output row is element (rowLength - 1 - j) of the input
 * row.
 *
 * Each row is mirrored using its own length, so ragged rows keep all of
 * their elements. The input and its rows are not modified.
 *
 * @param array - array of rows (arrays or other array-likes such as typed arrays).
 * @returns a new array of new, plain-array rows.
 */
export function mirro2DArray(array: ArrayLike<ArrayLike<any>>): any[] {
  return Array.from(array, (row) => Array.from(row).reverse());
}

/**
 * Reloads the page once when TensorFlow.js reports more than 145,000,000
 * bytes in GPU memory. The module-level `reloaded` flag makes sure the
 * reload is requested only once.
 *
 * NOTE(review): `numBytesInGPU` is not in the declared MemoryInfo type and is
 * presumably only reported by the WebGL backend. When it is absent the
 * comparison is false and nothing happens.
 */
const clearGPURam = () => {
  const gpuByteLimit = 145000000;
  //@ts-ignore -- numBytesInGPU is not part of the declared MemoryInfo type
  const gpuBytes = tf.memory().numBytesInGPU;

  if (reloaded || !(gpuBytes > gpuByteLimit)) {
    return;
  }

  reloaded = true;
  //@ts-ignore
  window.location.reload(true);
};
