import _ from "lodash";
import { Framebuffer2D, Regl, Texture2D } from "regl";
import * as ReglPerf from "./ReglPerf";
import { fullScreenTriangles } from "./WebGLUtil";

/** Arguments for one reduction pass of the draw function returned by `statReplaceNaN`. */
export interface StatReplaceNanProps {
  /** Summary level to produce; the pass reads level `summaryLevel - 1` from `src`. */
  summaryLevel: number;
  /** Texture holding the source level; sampled by the fragment shader. */
  src: Texture2D;
  /** Framebuffer that receives the reduced level. */
  dst: Framebuffer2D;
}

/** Draw function returned by `statReplaceNaN`: runs one reduction pass from `props.src` into `props.dst`. */
export interface StatReplaceNanCmd {
  (props: StatReplaceNanProps): void;
}

/** Per-draw properties consumed by the compiled regl command inside `statReplaceNaN`. */
interface CmdProps {
  /** Texture holding the source level. */
  src: Texture2D;
  /** Texture-coordinate (0..1) left edge of the source level within `src`. */
  srcStart: number;
  /** Fraction of the texture width occupied by the source level. */
  srcScale: number;

  /** Framebuffer the reduced level is rendered into. */
  dst: Framebuffer2D;
  /** Clip-space x coordinate of the center of the destination region. */
  dstStart: number;
  /** Clip-space half-width of the destination region. */
  dstScale: number;
}
/** Create a partial reduction of a series by outputting a series half
 * as long with NaN or negative Y values replaced where possible.
 *
 * Each output pixel combines a pair of adjacent source pixels (a, b),
 * judged only by their y channel; the chosen pixel's xy is written out.
 * With N meaning "y is NaN or negative", the reduction is:
 *   (a,N) => a
 *   (N,b) => b
 *   (a,b) => a
 *   (N,N) => N
 *
 * The reduction ping pongs between two buffers, writing the output
 * into successive sections of buffers. The position in the buffers
 * is controlled by a level parameter to the draw command which
 * is translated (by summaryStartScale) into srcStart, srcScale and
 * dstStart, dstScale for the shader.
 *
 * A final pass (coalesceFill) then traverses the reduced buffers to
 * emit results.
 *
 * @param regl - regl context used to compile the draw command.
 * @returns A draw function. Its `summaryLevel` must be >= 1: level 0 is
 *   the unreduced series and there is no lower level to read from.
 */
export function statReplaceNaN(regl: Regl): StatReplaceNanCmd {
  const cmd = regl({
    // Maps each destination pixel back to the first pixel of its source pair.
    vert: `
        precision highp float;
        attribute vec2 position;
        uniform float dstScale;   // scale position by this amount to map to target buffer.
        uniform float dstStart;   // offset scaled position by this amount to map to target buffer
        uniform float srcScale;   // read from this fraction smaller part of the src buffer
        uniform float srcStart;   // read starting at this position in src buffer
        uniform float srcPixelWidth; // pixel width in (0, 1) texture coordinates in src buffer
        varying float u;
        varying float dd;
        const float scaleToSrc = 2.0;
            
        void main() {
            u = 0.5 * (position.x + 1.0);             // uv range (0, 1) 
            u = u * srcScale + srcStart;
            u -= srcPixelWidth / scaleToSrc;          // align on first pixel of parent pair
            /* choose position in target buffer
             *
             * the target buffer layout is:
             *           2 2 2 2 2 2 2 8 8 32 ...      // for odd summary levels 
             * or        4 4 4 4 4 4 4 16 16 64 ...    // for even summary levels
             */
            float x = position.x * dstScale + dstStart;
            gl_Position = vec4(x, position.y, 1.0, 1.0); // position range (-1, -1)
            dd = u;
        }`,
    // Keeps pixel a unless its y is NaN or negative, in which case pixel b (one texel right) is used.
    frag: `
        precision highp float;
        varying float u; // x coordinate in texture 
        uniform sampler2D src;
        uniform float srcPixelWidth;
        varying float dd; // debug value
        const float Flag = -1e38;

        bool isNan(float f) {
          return !(f < 0.0 || 0.0 < f || f == 0.0 );
        } 

        void main() { 
          vec2 a = texture2D(src, vec2(u, .5)).xy;
          float bu = u + srcPixelWidth;
          vec2 b = texture2D(src, vec2(bu, .5)).xy;

          vec2 value = a;
          if (isNan(a.y) || a.y < 0.0) {
            value = b;
          }
          gl_FragColor = vec4(value, 0.0, 1.0);
        }`,
    uniforms: {
      src: (_ctx, props: CmdProps) => props.src,
      srcPixelWidth: (_ctx, props: CmdProps) => 1 / props.src.width,
      dstScale: (_ctx, props: CmdProps) => props.dstScale,
      dstStart: (_ctx, props: CmdProps) => props.dstStart,
      srcScale: (_ctx, props: CmdProps) => props.srcScale,
      srcStart: (_ctx, props: CmdProps) => props.srcStart,
    },
    attributes: {
      position: fullScreenTriangles,
    },
    framebuffer: (_ctx: any, props: CmdProps) => props.dst,
    depth: { enable: false, mask: false },
    primitive: "triangles",
    count: 6,
  });

  ReglPerf.registerCmd(cmd, "stats-replace-nan");

  return function draw(props: StatReplaceNanProps): void {
    const { src, dst, summaryLevel } = props;
    // Level 0 is the unreduced series. Reducing "into" it would read level -1,
    // for which summaryStartScale yields srcStart = -2 and srcScale = 2, a
    // region entirely outside the source texture.
    console.assert(
      summaryLevel >= 1,
      `statReplaceNaN: summaryLevel must be >= 1, got ${summaryLevel}`
    );
    const { srcStart, srcScale, dstStart, dstScale } = summaryStartScale(summaryLevel);
    const params: CmdProps = { src, dst, srcStart, srcScale, dstStart, dstScale };
    cmd(params);
  };
}

/**
 * Compute where a reduction pass for summary level `level` reads and writes.
 *
 * Summary level n holds the series reduced to 1/(2^n) of its original width:
 * level 1 is 1/2 the original size, level 2 is 1/4, then 1/8, and so on.
 * Odd and even levels live in two separate buffers. In units of 1/32 of the
 * buffer width, the layout is:
 *           1 1 1 1  1 1 1 1  1 1 1 1  1 1 1 1  3 3 3 3  5    // odd buffer
 *           2 2 2 2  2 2 2 2  4 4                             // even buffer
 *
 * @param level - the level being written; its source is level `level - 1`.
 *   Only levels >= 1 have a meaningful source region.
 * @returns
 *   - `dstStart` / `dstScale`: clip-space center and half-width of `level`
 *   - `unitStart`: left edge of `level` in 0..1 buffer units
 *   - `srcStart` / `srcScale`: left edge and width of `level - 1` in 0..1
 *     texture units
 */
export function summaryStartScale(level: number): {
  srcStart: number;
  srcScale: number;
  unitStart: number;
  dstStart: number;
  dstScale: number;
} {
  const target = startAndScale(level);
  const source = startAndScale(level - 1);

  return {
    srcScale: source.scale,
    srcStart: source.unitStart,
    dstScale: target.scale,
    dstStart: target.start,
    unitStart: target.unitStart,
  };
}

/**
 * Locate summary level `level` inside its buffer.
 *
 * Level 0 spans the whole buffer. Odd levels are packed one after another in
 * one buffer (1 at [0, 1/2), 3 at [1/2, 5/8), 5 at [5/8, 21/32), ...), and
 * even levels in the other (2 at [0, 1/4), 4 at [1/4, 5/16), ...). Level L
 * always occupies 2^-L of the buffer width.
 *
 * @param level - summary level; expected to be a non-negative integer
 *   (a non-integer only trips a console assertion, and negative levels
 *   produce regions outside the buffer).
 * @returns `unitStart`: left edge in 0..1 buffer units; `scale`: width in
 *   buffer units, which is also the half-width in clip space; `start`: the
 *   region's center in clip space (-1..1).
 */
function startAndScale(level: number): {
  unitStart: number;
  start: number;
  scale: number;
} {
  console.assert(_.isInteger(level));

  // Each level is half as wide as the one before it (level 0 => 1).
  const scale = 1 / Math.pow(2, level);

  let unitStart = 0; // level 0 begins at the left edge
  if (level !== 0) {
    if ((level & 0x1) !== 0) {
      // Odd level 2k+1 sits after levels 1, 3, ..., 2k-1:
      // 1/2 + 1/8 + ... (k terms) = (2/3)(1 - 4^-k)
      const k = (level - 1) / 2;
      unitStart = (2 / 3) * (1 - Math.pow(4, -k));
    } else {
      // Even level 2k+2 sits after levels 2, 4, ..., 2k:
      // 1/4 + 1/16 + ... (k terms) = (1 - 4^-k) / 3
      const k = (level - 2) / 2;
      unitStart = 1 / 3 - Math.pow(4, -k) / 3;
    }
  }

  // Clip-space center: 2 * (unitStart + scale / 2) - 1.
  const start = scale + unitStart * 2 - 1;

  return { scale, start, unitStart };
}
