import { Tensor } from "@tensorflow/tfjs";
import * as tf from "@tensorflow/tfjs";
import MnistData from "../../modules/mnist_data";
import p5 from "p5";
import {
  OUTPUT_UNITS,
  LABEL_ARRAY,
  GRID_HEIGHT,
  GRID_WIDTH,
} from "../../globals";

import {
  TRAINING_DATA_AMOUNT,
  TEST_DATA_AMOUNT,
  BATCH_SIZE,
  EPOCHS,
} from "./constants";
import { N_CLASSES } from "../../modules/mnist_data";

import { grid } from "../../sketches/inputSketch";
import { changePredictedChar } from "../../userInteraction/interaction";

/** The network used for both batch training and live per-frame prediction. */
export const model: tf.Sequential = createConvModel();

/** MNIST loader; `load()` is called from the sketch's setup before training. */
const mnistData = new MnistData();
// Training inputs and (padded) labels, assigned by training().
let x: tf.Tensor;
let y: tf.Tensor;

/** Index of the most probable output unit for the current drawing; written by manualTesting(). */
export let predictionOutputIndex: number;

/** Set once clearGPURam() has triggered a page reload, so it only fires once. */
let reloaded = false;

/**
 * p5 instance-mode sketch that drives the model.
 *
 * - `setup` compiles the model, loads the MNIST data and then starts training.
 * - `draw` (10 fps) predicts the label of the current drawing, forwards it to
 *   the UI and checks GPU memory usage.
 *
 * Loading is asynchronous, so early frames predict with the not-yet-trained
 * model.
 */
export const cnnAiSketch = (p: p5) => {
  p.disableFriendlyErrors = true;

  p.setup = () => {
    p.frameRate(10);

    model.summary();
    // console.log("WEIGHTS: ", model.getWeights());
    // console.log(model.getWeights()[0].transpose().arraySync());

    model.compile({
      optimizer: "rmsprop",
      loss: "categoricalCrossentropy",
      metrics: ["accuracy"],
    });

    // Report failures from load() or from the synchronous part of training()
    // instead of leaving an unhandled promise rejection.
    mnistData
      .load(TRAINING_DATA_AMOUNT, TEST_DATA_AMOUNT)
      .then(() => training())
      .catch((err: unknown) => {
        console.error("MNIST load/training failed:", err);
      });
  };

  p.draw = () => {
    manualTesting();
    // console.table(grid);
    // NOTE(review): the reason for this suppression is not visible here
    // (LABEL_ARRAY's element type vs. changePredictedChar's parameter?) — confirm.
    //@ts-ignore
    changePredictedChar(LABEL_ARRAY[predictionOutputIndex]);
    clearGPURam();
  };
};

/**
 * Builds a small fully-connected classifier for a single-channel
 * GRID_HEIGHT x GRID_WIDTH input. The layers are flatten, then dense 42
 * (ReLU), then dense OUTPUT_UNITS (softmax).
 *
 * Alternative to the CNN; not referenced elsewhere in this file.
 *
 * @returns an uncompiled `tf.Sequential` model.
 */
function createDenseModel() {
  const model = tf.sequential();

  model.add(tf.layers.flatten({ inputShape: [GRID_HEIGHT, GRID_WIDTH, 1] }));

  // Only the first layer's inputShape is used. Later layers take their input
  // size from the previous layer (here GRID_HEIGHT * GRID_WIDTH), so no
  // hard-coded 784 is needed.
  model.add(tf.layers.dense({ units: 42, activation: "relu" }));

  model.add(tf.layers.dense({ units: OUTPUT_UNITS, activation: "softmax" }));

  console.log("Dense model created!");

  return model;
}

/**
 * Builds a minimal baseline model: flatten, then dense OUTPUT_UNITS (softmax).
 * This is a softmax regression over the raw pixels with no hidden layer.
 *
 * Not referenced elsewhere in this file.
 *
 * @returns an uncompiled `tf.Sequential` model.
 */
function createTestModel() {
  const model = tf.sequential();

  model.add(tf.layers.flatten({ inputShape: [GRID_HEIGHT, GRID_WIDTH, 1] }));

  // Intentionally no hidden layer; createDenseModel adds a 42-unit ReLU layer here.

  model.add(tf.layers.dense({ units: OUTPUT_UNITS, activation: "softmax" }));

  // Distinct from createDenseModel's message so the log shows which factory ran.
  console.log("Test model created!");

  return model;
}

/**
 * Builds the convolutional classifier used as the app's main model.
 *
 * Layer order:
 * - conv 3x3 / 16 filters, then max-pool 2
 * - conv 3x3 / 32 filters, then max-pool 2
 * - conv 3x3 / 32 filters
 * - flatten, then dense 64 (ReLU), then dense OUTPUT_UNITS (softmax)
 *
 * Input is a single-channel GRID_HEIGHT x GRID_WIDTH image.
 *
 * @returns an uncompiled `tf.Sequential` model.
 */
function createConvModel() {
  const convNet = tf.sequential();

  const layerStack = [
    // Feature extraction: three 3x3 convolutions with 2x downsampling between them.
    tf.layers.conv2d({
      inputShape: [GRID_HEIGHT, GRID_WIDTH, 1],
      kernelSize: 3,
      filters: 16,
      activation: "relu",
    }),
    tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }),
    tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: "relu" }),
    tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }),
    tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: "relu" }),

    // Classification head.
    tf.layers.flatten({}),
    tf.layers.dense({ units: 64, activation: "relu" }),
    tf.layers.dense({ units: OUTPUT_UNITS, activation: "softmax" }),
  ];

  // Layers are added strictly in the order listed above.
  for (const layer of layerStack) {
    convNet.add(layer);
  }

  console.log("CNN model created!");

  return convNet;
}

/**
 * Trains `model` on the training split held by `mnistData`, then calls
 * `testing()` once fitting has finished.
 *
 * The model has OUTPUT_UNITS outputs. The labels from getTrainData() are
 * presumably N_CLASSES columns wide (the padding below assumes it), so they
 * are right-padded with zero columns to OUTPUT_UNITS.
 *
 * @throws Error if OUTPUT_UNITS < N_CLASSES (the labels cannot be padded).
 */
function training() {
  console.log("Training on: " + TRAINING_DATA_AMOUNT);

  [x, y] = mnistData.getTrainData();
  // x = x.reshape([TRAINING_DATA_AMOUNT, PIXEL_COUNT]);

  const extraUnits = OUTPUT_UNITS - N_CLASSES;
  if (extraUnits < 0) {
    throw new Error(
      `OUTPUT_UNITS (${OUTPUT_UNITS}) must be >= N_CLASSES (${N_CLASSES})`
    );
  }

  if (extraUnits > 0) {
    const labels = y;
    // The padding height comes from the labels actually returned, not from
    // TRAINING_DATA_AMOUNT, so concat cannot fail on a row-count mismatch.
    // tidy() disposes the temporary zeros tensor.
    // NOTE(review): `labels` itself is not disposed because ownership of the
    // tensors returned by getTrainData() is not visible here — confirm.
    y = tf.tidy(() =>
      tf.concat([labels, tf.zeros([labels.shape[0], extraUnits])], 1)
    );
  }

  model
    .fit(x, y, {
      batchSize: BATCH_SIZE,
      epochs: EPOCHS,
    })
    .then(() => {
      console.log("training done!");
      testing();
      // drawMiniWeightSketches();
    })
    .catch((err: unknown) => {
      console.error("training failed:", err);
    });
}

/**
 * Evaluates `model` on up to 500 test samples from `mnistData`. It logs the
 * number of correct predictions and the accuracy as a percentage rounded to
 * two decimals.
 */
function testing() {
  const countOfTestSamples = Math.min(500, TEST_DATA_AMOUNT);

  console.log("Testing on: " + countOfTestSamples);

  if (countOfTestSamples <= 0) {
    // Avoid a 0 / 0 = NaN accuracy when no test data is available.
    console.log("No test samples available, skipping evaluation.");
    return;
  }

  // tidy() disposes every tensor created inside it before returning: the
  // predictions, argMax and comparison results, and the test tensors if
  // getTestData() allocates them per call. Only the plain count escapes.
  const numCorrect = tf.tidy(() => {
    const [input, expectedOutput] = mnistData.getTestData(countOfTestSamples);
    const output = model.predict(input) as Tensor;

    // Class index per sample: argmax along axis 1.
    // Assumes both tensors are [samples, classes].
    const predictedIndices = output.argMax(1);
    const expectedIndices = expectedOutput.argMax(1);

    return predictedIndices.equal(expectedIndices).sum().dataSync()[0];
  });

  const accuracy =
    Math.round((numCorrect / countOfTestSamples) * 10000) / 100;

  console.log("Correct: " + numCorrect);
  console.log("Accuracy: " + accuracy + "%");
}

/**
 * Fine-tunes `model` on the current drawing in `grid`, using `userInput` as
 * its label. Fitting runs asynchronously; this function does not wait for it.
 *
 * @param userInput - label text; must appear in LABEL_ARRAY. Unknown labels,
 *   including an empty textbox, are rejected with a console hint.
 */
export function manualTraining(userInput: string) {
  const userInputLabelIndex = LABEL_ARRAY.indexOf(userInput);

  // indexOf() returns -1 for "not found". Index 0 is a valid label, so a
  // truthiness check would be wrong in both directions.
  if (userInputLabelIndex === -1) {
    console.log("please insert label to textbox to train Ai");
    return;
  }

  // Same grid re-orientation as used for live prediction in manualTesting().
  const mirroredGrid = mirror2DArray(rotate2DArray(grid));

  // Build the single-sample batch and its one-hot target, shape
  // [1, OUTPUT_UNITS]. This happens inside tidy() so the un-reshaped
  // intermediate tensor is freed immediately.
  const { input, target } = tf.tidy(() => {
    const batch = tf
      .tensor(mirroredGrid)
      .reshape([1, GRID_HEIGHT, GRID_WIDTH, 1]);
    const oneHot = Array.from({ length: OUTPUT_UNITS }, (_, i) =>
      i === userInputLabelIndex ? 1 : 0
    );
    return { input: batch, target: tf.tensor2d([oneHot]) };
  });

  model
    .fit(input, target, { batchSize: 1, epochs: 10 })
    .catch((err: unknown) => {
      // For example, tfjs rejects a fit() that starts while another fit() is
      // still running (such as the initial training).
      console.error("manual training failed:", err);
    })
    .then(() => {
      // fit() has settled, so these tensors are no longer needed here.
      input.dispose();
      target.dispose();
    });
}

/**
 * Runs `model` on the current drawing in `grid`. The index of the highest
 * scoring output unit is stored in `predictionOutputIndex`.
 */
function manualTesting() {
  // For a rectangular grid, rotating then mirroring amounts to a transpose.
  const orientedGrid = mirror2DArray(rotate2DArray(grid));

  // tidy() frees every intermediate tensor; only the plain index escapes.
  predictionOutputIndex = tf.tidy(() => {
    const batch = tf
      .tensor(orientedGrid)
      .reshape([1, GRID_HEIGHT, GRID_WIDTH, 1]);
    const scores = model.predict(batch) as Tensor;
    return scores.argMax(1).dataSync()[0];
  });
}

/**
 * Rotates a 2D array 90° clockwise and returns a new array. The input is not
 * modified.
 *
 * Column `c` of the result collects element `c` of each input row, walking the
 * rows from last to first. Ragged rows are allowed: a row simply contributes
 * to as many output rows as it has elements.
 *
 * @example rotate2DArray([[1, 2], [3, 4]]) // [[3, 1], [4, 2]]
 */
export function rotate2DArray(array: any[]) {
  const result: any[] = [];

  let rowIndex = array.length - 1;
  while (rowIndex >= 0) {
    let column = 0;
    array[rowIndex].forEach((cell: any) => {
      // Grow the output lazily so the widest row decides how many rows it has.
      if (column >= result.length) {
        result.push([]);
      }
      result[column++].push(cell);
    });
    rowIndex -= 1;
  }

  return result;
}

/**
 * Mirrors a 2D array horizontally by reversing each row. Returns a new array;
 * the input and its rows are not modified.
 *
 * Each row is reversed on its own, so ragged input keeps every row's own
 * length. Rows are not padded or truncated to the first row's width, and
 * empty rows stay empty.
 *
 * @example mirror2DArray([[1, 2, 3]]) // [[3, 2, 1]]
 */
export function mirror2DArray(array: any[]): any[][] {
  // Copy before reverse(), which works in place.
  return array.map((row) => [...row].reverse());
}

/** GPU bytes in use above which clearGPURam() reloads the page. */
const GPU_RELOAD_THRESHOLD_BYTES = 500_000_000;

/**
 * Reloads the page once when TensorFlow.js reports more than
 * GPU_RELOAD_THRESHOLD_BYTES of GPU memory in use. This is a last-resort reset
 * for leaked tensors. The `reloaded` flag keeps it from firing repeatedly
 * while the reload is pending.
 */
const clearGPURam = () => {
  // `numBytesInGPU` is a WebGL-backend-specific field that the generic
  // MemoryInfo type does not declare. Read it as optional; other backends
  // leave it undefined and never trigger a reload.
  const memoryInfo = tf.memory() as ReturnType<typeof tf.memory> & {
    numBytesInGPU?: number;
  };
  const gpuBytes = memoryInfo.numBytesInGPU ?? 0;
  // console.log(gpuBytes);

  if (gpuBytes > GPU_RELOAD_THRESHOLD_BYTES && !reloaded) {
    reloaded = true;
    // The former `reload(true)` passed the non-standard, Firefox-only forceGet
    // flag. A normal reload is enough to start over with fresh page state.
    window.location.reload();
  }
};
