import p5 from "p5";
import {
  CANVAS_HEIGHT,
  CANVAS_WIDTH,
  GRID_HEIGHT,
  GRID_WIDTH,
  PIXEL_COUNT,
} from "../../globals";
import {
  predictionOutputIndex,
  model,
  rotate2DArray,
  mirro2DArray,
} from "./ai";
import * as tf from "@tensorflow/tfjs";

// Transposed first weight tensor of the model, refreshed by the sketch on
// every frame; exported so other modules can read it.
export let weightArray: any;

// Index of the output unit whose incoming weights are drawn. The sketch
// overwrites it from predictionOutputIndex each frame.
let outputToVisualize = 1;

// Multiplier applied to every weight before it is handed to p.fill().
let brightness = 400;

// Grid currently shown by the sketch, indexed as weightGrid[x][y].
let weightGrid: any[] = [];

// Size of one grid tile in canvas pixels along each axis (set in setup()).
let weightGridtileSizeX: number, weightGridtileSizeY: number;

/**
 * p5 instance-mode sketch that visualizes the weights feeding one output unit
 * of `model` as a grid of shaded tiles.
 *
 * On each frame it:
 *  1. syncs `outputToVisualize` from `predictionOutputIndex`,
 *  2. copies that unit's row of the transposed first weight tensor into
 *     `weightGrid` (via mapGridToCanvas),
 *  3. rotates/mirrors the grid and draws it with drawGrid.
 */
export const simpleWeightSketch = (p: p5) => {
  p.setup = () => {
    p.createCanvas(CANVAS_WIDTH, CANVAS_HEIGHT);

    weightGridtileSizeX = p.width / GRID_WIDTH;
    weightGridtileSizeY = p.height / GRID_HEIGHT;

    // Placeholder grid until the first draw() fills it from the model. It is
    // sized from the shared grid constants (outer = x, inner = y), matching the
    // shape mapGridToCanvas produces, instead of a hard-coded 28.
    weightGrid = Array.from({ length: GRID_WIDTH }, () =>
      Array.from({ length: GRID_HEIGHT })
    );
  };

  p.draw = () => {
    outputToVisualize = predictionOutputIndex;
    getWeightsOfOutputUnitIntoWeightGrid();
    const rotatedAndMirroredGrid = rotateAndMirrorArray(weightGrid);
    // NOTE(review): drawGrid is given only the X tile size, so tiles are drawn
    // square; weightGridtileSizeY is computed in setup() but unused here.
    drawGrid(p, rotatedAndMirroredGrid, weightGridtileSizeX);
  };

  /**
   * Refreshes the exported `weightArray` with the transposed first weight
   * tensor of `model`, then copies row `outputToVisualize` of it into
   * `weightGrid`.
   *
   * When the model has no weights yet, or the requested row does not exist
   * (e.g. an out-of-range prediction index), `weightGrid` is left unchanged
   * instead of throwing inside the draw loop, which would stop the sketch.
   */
  function getWeightsOfOutputUnitIntoWeightGrid() {
    tf.tidy(() => {
      const kernel = model.getWeights()[0];
      if (kernel === undefined) return;
      // transpose() allocates a new tensor; tidy() disposes it once we have
      // copied the values out with arraySync().
      weightArray = kernel.transpose().arraySync();
      const weightsOfOutputUnit = weightArray[outputToVisualize];
      if (!Array.isArray(weightsOfOutputUnit)) return;
      weightGrid = mapGridToCanvas(weightsOfOutputUnit);
    });
  }
};

/**
 * Reshapes a flat list of per-cell values into a 2D grid indexed as
 * `grid[x][y]`, with `gridWidth` outer entries of `gridHeight` cells each.
 *
 * The flat input is read with the inner (y) index varying fastest:
 * `grid[x][y] = weightLabel[x * gridHeight + y]`.
 *
 * @param weightLabel Flat values; should hold at least gridWidth * gridHeight
 *   entries (any missing entries come out as `undefined`).
 * @param gridWidth Outer dimension of the result. Defaults to GRID_WIDTH.
 * @param gridHeight Inner dimension of the result. Defaults to GRID_HEIGHT.
 * @returns A newly allocated gridWidth x gridHeight array.
 */
export function mapGridToCanvas(
  weightLabel: any[],
  gridWidth: number = GRID_WIDTH,
  gridHeight: number = GRID_HEIGHT
): any[][] {
  // The stride must be the inner dimension (gridHeight). Striding by the outer
  // dimension only works for square grids; otherwise it skips entries and
  // reads past the end of the input.
  return Array.from({ length: gridWidth }, (_, x) =>
    Array.from({ length: gridHeight }, (_, y) => weightLabel[x * gridHeight + y])
  );
}

/**
 * Applies rotate2DArray to `array`, passes the result through mirro2DArray
 * (both imported from ./ai), and returns the mirrored grid.
 */
export function rotateAndMirrorArray(array: any[]) {
  return mirro2DArray(rotate2DArray(array));
}

/**
 * Draws a 2D grid as filled rectangles, one per cell. Cell `array[x][y]` is
 * placed at `(x * tileSize, y * tileSizeY)`, and its value multiplied by the
 * module-level `brightness` is passed to `p.fill()`.
 *
 * @param p The p5 instance to draw on.
 * @param array Grid indexed as `array[x][y]`; columns may differ in length.
 * @param tileSize Tile width in pixels (and tile height when `tileSizeY` is
 *   omitted).
 * @param tileSizeY Tile height in pixels. Defaults to `tileSize` (square
 *   tiles, as before); pass a separate value for non-square canvases.
 */
export function drawGrid(
  p: p5,
  array: any[],
  tileSize: number,
  tileSizeY: number = tileSize
) {
  for (let x = 0; x < array.length; x++) {
    const column = array[x];
    // Use each column's own length so a ragged grid neither reads past a
    // short column nor truncates a long one.
    for (let y = 0; y < column.length; y++) {
      p.fill(column[y] * brightness);
      p.rect(tileSize * x, tileSizeY * y, tileSize, tileSizeY);
    }
  }
}