import p5 from "p5";
import {
  CANVAS_HEIGHT,
  CANVAS_WIDTH,
  WEIGHT_GRID_HEIGHT,
  WEIGHT_GRID_WIDTH,
} from "../../globals";
import {
  predictionOutputIndex,
  model,
  rotate2DArray,
  mirror2DArray,
} from "./ai";
import * as tf from "@tensorflow/tfjs";

// ---- Module-level state shared by the sketch and the helpers below ----

/** Multiplier applied to each weight to obtain the grey fill value in `drawGrid`. */
const brightness = 400;

/** Index of the output unit whose weights are shown; reassigned from `predictionOutputIndex` in `draw`. */
let outputToVisualize = 1;

/** Current weight grid indexed `[x][y]`; replaced whenever new weights are read. */
let weightGrid: any[] = [];

/** Transposed first weight tensor of `model`, as nested JS arrays (written by the sketch). */
export let weightArray: any;

/** Canvas pixels per grid cell along x and y; computed in `setup`. */
let weightGridtileSizeX: number;
let weightGridtileSizeY: number;

/**
 * p5 instance-mode sketch that draws the weights of one output unit of `model`
 * as a grid of grey tiles.
 *
 * Every frame (target 10 fps) it reads `predictionOutputIndex`, pulls the first
 * weight tensor of `model`, reshapes the selected unit's row into a grid,
 * rotates and mirrors it, and draws it. A "Weights not available" message is
 * always drawn on top.
 *
 * @param p - The p5 instance the sketch is bound to.
 */
export const cnnWeightSketch = (p: p5) => {
  p.setup = () => {
    p.createCanvas(CANVAS_WIDTH, CANVAS_HEIGHT);
    p.frameRate(10);

    weightGridtileSizeX = p.width / WEIGHT_GRID_WIDTH;
    weightGridtileSizeY = p.height / WEIGHT_GRID_HEIGHT;

    // Placeholder grid until weights are read. The inner dimension must be the
    // grid HEIGHT (it was WIDTH, which is wrong for non-square grids).
    weightGrid = Array.from({ length: WEIGHT_GRID_WIDTH }, () =>
      Array.from({ length: WEIGHT_GRID_HEIGHT }),
    );
  };

  p.draw = () => {
    outputToVisualize = predictionOutputIndex;

    // Only draw the grid when fresh weights were actually read; otherwise the
    // placeholder message below is all that is drawn this frame.
    if (getWeightsOfOutputUnitIntoWeightGrid()) {
      const rotatedAndMirroredGrid = rotateAndMirrorArray(weightGrid);
      // NOTE(review): only the X tile size is passed, so tiles are square;
      // weightGridtileSizeY is computed in setup but not used here.
      drawGrid(p, rotatedAndMirroredGrid, weightGridtileSizeX);
    }

    // Write "Weights not available" to canvas
    p.textSize(40);
    p.textStyle(p.BOLD);
    p.fill('PURPLE');
    p.textAlign(p.CENTER, p.CENTER);
    p.text("Weights\nnot available\nright now...\n:(((", p.width / 2, p.height / 2);
    p.fill('WHITE');
  };

  /**
   * Copies the weights belonging to output unit `outputToVisualize` into `weightGrid`.
   *
   * Reads `model.getWeights()[0]` (presumably the first layer's kernel — confirm),
   * transposes it, stores the nested array in `weightArray`, and reshapes the row
   * at `outputToVisualize` with `mapGridToCanvas`.
   *
   * @returns `true` if `weightGrid` was updated; `false` when the model exposes no
   *   weights or the selected row is missing / not an array.
   */
  function getWeightsOfOutputUnitIntoWeightGrid(): boolean {
    let updated = false;
    // tidy() disposes the intermediate transposed tensor after arraySync() copies it out.
    tf.tidy(() => {
      const weights = model.getWeights();
      if (weights.length === 0) return;

      weightArray = weights[0].transpose().arraySync();
      const weightsOfOutputUnit = weightArray[outputToVisualize];
      // Out-of-range index (or unexpected tensor rank) would otherwise crash mapGridToCanvas.
      if (!Array.isArray(weightsOfOutputUnit)) return;

      weightGrid = mapGridToCanvas(weightsOfOutputUnit);
      updated = true;
    });
    return updated;
  }
};

/**
 * Reshapes a flat list of weights into a 2D grid indexed `[x][y]`.
 *
 * Cell `[x][y]` takes `weightLabel[x * height + y]`, so each of the `width`
 * columns consumes `height` consecutive values and exactly the first
 * `width * height` entries are used. (The previous stride of `width` was only
 * correct for square grids.) Indices past the end of `weightLabel` yield `undefined`.
 *
 * NOTE(review): this treats the flat data as x-major; callers rotate/mirror the
 * result afterwards — confirm against the model's input layout.
 *
 * @param weightLabel - Flat weights for one output unit.
 * @param width - Number of columns (x); defaults to `WEIGHT_GRID_WIDTH`.
 * @param height - Number of rows (y); defaults to `WEIGHT_GRID_HEIGHT`.
 * @returns A `width` × `height` nested array.
 */
export function mapGridToCanvas(
  weightLabel: readonly number[],
  width = WEIGHT_GRID_WIDTH,
  height = WEIGHT_GRID_HEIGHT,
): number[][] {
  return Array.from({ length: width }, (_, x) =>
    Array.from({ length: height }, (_, y) => weightLabel[x * height + y]),
  );
}

/**
 * Applies `rotate2DArray` and then `mirror2DArray` to a 2D array.
 *
 * @param array - Grid to transform; handed unchanged to `rotate2DArray`.
 * @returns Whatever `mirror2DArray` produces for the rotated grid.
 */
export function rotateAndMirrorArray(array: any[]) {
  return mirror2DArray(rotate2DArray(array));
}

/**
 * Draws a 2D array as a grid of filled rectangles.
 *
 * Cell `array[x][y]` is filled with the grey value `array[x][y] * brightness`
 * and drawn at `(tileSize * x, tileSizeY * y)` with size `tileSize` × `tileSizeY`.
 * Each column's own length is used, so ragged input never reads past the end
 * of a shorter column (previously every column used `array[0].length`).
 *
 * @param p - p5 instance to draw on.
 * @param array - Grid indexed `[x][y]`.
 * @param tileSize - Tile width in pixels (and height when `tileSizeY` is omitted).
 * @param tileSizeY - Tile height in pixels; defaults to `tileSize` (square tiles).
 */
export function drawGrid(p: p5, array: any[], tileSize: number, tileSizeY = tileSize) {
  for (let x = 0; x < array.length; x++) {
    const column = array[x];
    for (let y = 0; y < column.length; y++) {
      p.fill(column[y] * brightness);
      p.rect(tileSize * x, tileSizeY * y, tileSize, tileSizeY);
    }
  }
}