import _ from "lodash";
import { default as ndarray, NdArray } from "ndarray";
import ndshow from "ndarray-show";
import { Framebuffer2D, Regl } from "regl";
import sprintfjs from "sprintf-js";
import { Vec2 } from "../math/Vec";
import { dLog } from "../util/DebugLog";
import { fetchPixels } from "./WebGLUtil";

// pull the formatter off the sprintf-js default export so it can be called directly
const { sprintf } = sprintfjs;

/**
 * Print one color component of an RGBA pixel buffer to the console,
 * rendered as text by `ndarray-show`.
 *
 * @param pixels flat pixel buffer with four components per pixel
 * @param size dimensions handed to `channelFromPixels` to shape the buffer
 * @param component which of the four components to print (defaults to "r")
 */
export function printPixels(
  pixels: Uint8Array | Float32Array,
  size: [number, number],
  component: "r" | "g" | "b" | "a" = "r"
): void {
  const rendered = ndshow(channelFromPixels(pixels, size, component));
  console.log(rendered);
}

/**
 * Build a 2-D ndarray view of a single color component of a flat RGBA buffer.
 *
 * The buffer is interpreted with shape `[size[1], size[0], 4]`; the first
 * axis is then reversed and the last axis is fixed to the requested component,
 * so the result has shape `[size[1], size[0]]`.
 *
 * @param pixels flat pixel buffer with four components per pixel
 * @param size the two dimensions used to shape the buffer
 * @param component which of the four components to select (defaults to "r")
 * @returns a view over `pixels` (no data is copied by this function)
 */
function channelFromPixels(
  pixels: Uint8Array | Float32Array,
  size: [number, number],
  component: "r" | "g" | "b" | "a" = "r"
): NdArray<Uint8Array | Float32Array> {
  const componentOrder = ["r", "g", "b", "a"];
  const componentIndex = componentOrder.indexOf(component);

  const rgba = ndarray(pixels, [size[1], size[0], 4]);
  // reverse the first axis so pixels go from top left
  const flipped = rgba.step(-1);
  return flipped.pick(null, null, componentIndex);
}

/**
 * Build a text report of statistics (percentiles, mean, nonzero count) for
 * one color component of an image.
 *
 * @param pixels flat pixel buffer with four components per pixel
 * @param component which component to summarize (defaults to "r")
 * @param percentiles fractions (e.g. 0.5 for the median) to report
 * @returns the percentile line, a newline plus three spaces, then the means line
 */
export function pixelStats(
  pixels: Float32Array | Uint8Array,
  component: "r" | "g" | "b" | "a" = "r",
  percentiles = [0.5, 0.9, 0.95, 0.99, 1.0]
): string {
  const channelValues = oneColorArray(pixels, component);
  const percentileLine = pixelPercentiles(channelValues, percentiles);
  const meansLine = pixelMeans(channelValues);

  return [percentileLine, meansLine].join("\n   ");
}

/**
 * Format the mean and count of all values alongside the count, sum and mean
 * of only the nonzero values.
 *
 * @param values the numbers to summarize
 * @returns the five labelled figures on one line, separated by two spaces
 */
function pixelMeans(values: number[]): string {
  const total = values.length;
  const overallMean = _.sum(values) / total;

  const nonZeroValues = values.filter((v) => v !== 0);
  const nonZeroTotal = nonZeroValues.length;
  const nonZeroSum = _.sum(nonZeroValues);
  const nonZeroMean = nonZeroSum / nonZeroTotal;

  const parts: string[] = [];
  parts.push(`mean: ${np(overallMean)}`);
  parts.push(`count: ${total}`);
  parts.push(`nzZero.count: ${nonZeroTotal}`);
  parts.push(`nzZero.sum: ${np(nonZeroSum)}`);
  parts.push(`nZmean: ${np(nonZeroMean)}`);
  return parts.join("  ");
}

/**
 * Extract one color component from a flat RGBA pixel buffer.
 *
 * Reads every fourth value directly from the buffer rather than building an
 * ndarray view and an intermediate index array.
 *
 * @param pixels flat pixel buffer with four components per pixel; trailing
 *   values that do not form a complete pixel are ignored
 * @param component which of the four components to extract (defaults to "r")
 * @returns one number per complete pixel, in buffer order
 */
function oneColorArray(
  pixels: Float32Array | Uint8Array,
  component: "r" | "g" | "b" | "a" = "r"
): number[] {
  const channelOffset = "rgba".indexOf(component);
  const count = Math.floor(pixels.length / 4);
  const valueArray = new Array<number>(count);
  for (let i = 0; i < count; i++) {
    valueArray[i] = pixels[i * 4 + channelOffset];
  }

  return valueArray;
}

/**
 * Format the requested percentiles of a list of values.
 *
 * Each percentile `p` is read from the ascending-sorted copy of `values` at
 * index `floor(length * p)`, clamped to the last element.
 *
 * @param values the numbers to summarize (not modified)
 * @param percentiles fractions to report, e.g. 0.5 for the median
 * @returns entries like `p50: <value>`, separated by two spaces
 */
function pixelPercentiles(values: number[], percentiles: number[]): string {
  const sorted = values.slice().sort((a, b) => a - b);
  const lastIndex = sorted.length - 1;

  const entries = percentiles.map((p) => {
    const index = Math.min(Math.floor(sorted.length * p), lastIndex);
    // With no values (or a negative percentile) there is nothing to index;
    // report NaN rather than handing `undefined` to the number formatter.
    const value = index >= 0 ? sorted[index] : NaN;
    // Round the label to two decimals instead of flooring, so floating point
    // error (0.57 * 100 = 56.99999999999999) does not shift the label and
    // fractional percentiles such as 0.999 are labelled p99.9.
    const label = Number((p * 100).toFixed(2));
    return `p${label}: ${np(value)}`;
  });

  return entries.join("  ");
}

/** "number print": a concise way to print a fractional value,
 * formatted through sprintf with the "%.4f" (four decimals) pattern */
export function np(n: number): string {
  const fourDecimalFormat = "%.4f";
  return sprintf(fourDecimalFormat, n);
}

/**
 * Log, for every row of the image, how many pixels have the selected
 * component equal to 255.
 *
 * The channel view returned by `channelFromPixels` has shape
 * `[size[1], size[0]]`, so rows are its first axis (assuming `size` is
 * [width, height] — confirm against callers).
 *
 * @param pixels flat 8-bit pixel buffer with four components per pixel
 * @param size dimensions handed to `channelFromPixels` to shape the buffer
 * @param component which of the four components to inspect (defaults to "r")
 */
export function countOnesByRow(
  pixels: Uint8Array,
  size: [number, number],
  component: "r" | "g" | "b" | "a" = "r"
): void {
  const channel = channelFromPixels(pixels, size, component);

  // iterate the first axis: that is the row axis of the channel view
  const rowCount = channel.shape[0];
  for (let row = 0; row < rowCount; row++) {
    const ones = rowOnes(channel, row);
    console.log(row, ones);
  }
}

/**
 * Count the values equal to 255 in one row of a 2-D channel view.
 *
 * @param channel 2-D view whose first axis is the row index
 * @param rowNumber index along the first axis of the row to inspect
 * @returns how many entries of that row are exactly 255
 */
function rowOnes(channel: NdArray<Float32Array | Uint8Array>, rowNumber: number): number {
  // fix the first axis to select a row; leave the second axis free
  const row = channel.pick(rowNumber, null);
  let ones = 0;
  for (let i = 0; i < row.shape[0]; i++) {
    if (row.get(i) === 255) {
      ones++;
    }
  }
  return ones;
}

/**
 * Fetch the pixels of a framebuffer and print one color component to the
 * console, optionally preceded by a label line.
 *
 * @param regl regl context passed to `fetchPixels`
 * @param fb framebuffer to read, or null (passed through to `fetchPixels`)
 * @param size dimensions passed to `printPixels`
 * @param label if non-empty, printed as `label:` before the pixels
 * @param component component to print; "r" is used when omitted
 */
export function debugPrintFB(
  regl: Regl,
  fb: Framebuffer2D | null,
  size: Vec2,
  label?: string,
  component?: "r" | "g" | "b" | "a"
): void {
  const pixels = fetchPixels(regl, fb);
  const colorComponent = component ? component : "r";

  if (label) {
    console.log(`${label}:`);
  }
  printPixels(pixels, size, colorComponent);
}

/**
 * Read every bucket framebuffer, concatenate their values in order, and
 * log the combined values along with their sum.
 *
 * @param regl regl context passed to `fetchPixels`
 * @param bucketFBs framebuffers to read, in the order they should appear
 * @param prefix text placed in parentheses in the log label
 */
export function debugPercentile(
  regl: Regl,
  bucketFBs: Framebuffer2D[],
  prefix = ""
): void {
  const perBuffer = bucketFBs.map((fb) => Array.from(fetchPixels(regl, fb)));
  const buckets = _.flatten(perBuffer);
  const sum = _.sum(buckets);
  dLog(`buckets (${prefix})`, { buckets, sum });
}

/**
 * Log the first four values of the bucket statistics framebuffer as
 * min, max, percentileCount and percentileNth, plus the midpoint of
 * min and max as center.
 *
 * @param regl regl context passed to `fetchPixels`
 * @param bucketStatsFB framebuffer to read
 * @param prefix text placed in parentheses as the log label
 */
export function debugBucketStats(
  regl: Regl,
  bucketStatsFB: Framebuffer2D,
  prefix = ""
): void {
  const stats = fetchPixels(regl, bucketStatsFB);
  const min = stats[0];
  const max = stats[1];
  const report = {
    min,
    max,
    center: (min + max) / 2,
    percentileCount: stats[2],
    percentileNth: stats[3],
  };

  dLog(`(${prefix})`, report);
}

/**
 * Log the values labelled "fade": the midpoint of the first two bucket
 * statistics as percentile, and the first nonzero-statistics value divided
 * by the second as nzMean. When a second bucket statistics framebuffer is
 * supplied, its midpoint is logged first as percentileG.
 *
 * @param regl regl context passed to `fetchPixels`
 * @param nzStatsFB framebuffer whose first two values are read as sum and count
 * @param bucketStatsFB framebuffer whose first two values are read as min and max
 * @param bucketStatsFBG optional second bucket statistics framebuffer
 */
export function debugFade(
  regl: Regl,
  nzStatsFB: Framebuffer2D,
  bucketStatsFB: Framebuffer2D,
  bucketStatsFBG?: Framebuffer2D
): void {
  const nzStats = fetchPixels(regl, nzStatsFB);
  const bucketStats = fetchPixels(regl, bucketStatsFB);

  // keys are inserted in the order they should appear in the log
  const vars: Record<string, unknown> = {};
  if (bucketStatsFBG) {
    const bucketStatsG = fetchPixels(regl, bucketStatsFBG);
    vars.percentileG = (bucketStatsG[1] + bucketStatsG[0]) / 2;
  }
  vars.percentile = (bucketStats[1] + bucketStats[0]) / 2;
  vars.nzMean = nzStats[0] / nzStats[1];

  dLog("fade", vars);
}

/**
 * Log the first four values of the nonzero statistics framebuffer as
 * sum, nzCount, min and max, together with nzMean (sum divided by nzCount).
 *
 * @param regl regl context passed to `fetchPixels`
 * @param nzStatsFB framebuffer to read
 * @param label text appended to "nzStats" in the log label
 */
export function debugNzStats(regl: Regl, nzStatsFB: Framebuffer2D, label = ""): void {
  const [sum, nzCount, min, max] = fetchPixels(regl, nzStatsFB);
  const nzMean = sum / nzCount;

  dLog(`nzStats${label}:`, { nzMean, sum, nzCount, min, max });
}
