import * as d3 from "d3";
import { Regl, Texture2D } from "regl";
import { colorMap, ColorScale } from "../chart/ColorMap";
import { Rectangle } from "../chart/Rectangle";
import { brightnessSquircle } from "../chart/SquircleScale";
import { Vec2, Vec4 } from "../math/Vec";
import { replaceUndefined } from "../util/Utils";
import * as ReglPerf from "./ReglPerf";
import { blendSrcAlpha, fullScreenTriangles } from "./WebGLUtil";

/** names of built-in d3 color scales that may be passed in place of a ColorScale
 * (resolved to d3.interpolateViridis / d3.interpolatePlasma in colorRampTexture) */
type NamedColorScale = "viridis" | "plasma";

/** construction-time options for densityMapCmd */
export interface DensityMapArgs {
  /** color map for density values: a ColorScale, or the name of a built-in
   * d3 scale ("viridis" or "plasma"). Defaults to d3.interpolateViridis. */
  colorScale?: ColorScale | NamedColorScale;
}

/** fallback values for the optional DensityMapArgs fields (applied with replaceUndefined) */
const defaults: Required<DensityMapArgs> = {
  colorScale: d3.interpolateViridis,
};

/** per-draw options for the function returned by densityMapCmd */
export interface MapDensityProps {
  /**
   * To determine the max threshold for brightness scaling, use the percentile
   * value reported in the percentile texture multiplied by this scaling factor.
   * Defaults to 1.0 (i.e. use the percentile value unchanged).
   * This is so the slope-weighted line renderer can use a lower percentile
   * value which is more stable across zooms, and then make it bigger to get
   * a wider range.
   */
  scaleThreshold?: number;

  /** scale density values non-linearly (see SquircleScale).
   * 0 or omitted maps every nonzero density to the coolest color
   * (unless it is discarded by copyAbove). */
  brightness?: number;

  /** densities this value or less are not copied to the target FB (screen).
   * The comparison uses the brightness-scaled density value. Defaults to 0. */
  copyAbove?: number;

  /** fade out the density overlay as the plot grows sparse
   *
   * The fade level is computed once per frame (not per pixel) from a percentile
   * density value (e.g. p95) read from the fade percentile texture. If the same
   * texture was passed as both the percentile texture and the fade percentile
   * texture, the scaleThreshold-adjusted percentile value is used instead.
   *
   * fadeRange[0] (start): percentile density at which the fade out begins
   * fadeRange[1] (end): percentile density at which the fade out ends (must be smaller than start)
   *
   * percentile at or above start: the overlay is drawn with alpha 1.0
   * percentile at or below end: the overlay is drawn with alpha 0.0
   * percentile between end and start: alpha is interpolated linearly
   */
  fadeRange?: Vec2;

  /** 0-255 RGBA color for background. Defaults to [255, 255, 255, 255].
   * NOTE(review): the fragment shader declares a `background` uniform but never
   * reads it, so this currently has no visible effect. */
  background?: Vec4;
}

/** @return a function that draws the density map: it copies pixels, transforming scalar
 * intensity values by a non-linear function to add contrast, and then maps the scalar
 * result to colors through a color map. The draw targets the default framebuffer
 * (`framebuffer: null`) with source-alpha blending.
 *
 * @param regl regl context used to build the command and the color ramp texture
 * @param density map of density per pixel (the red channel is read)
 * @param clipRect clip drawing on left and right. Specified in screen pixel coordinates.
 *  left/right are read once, when the command is created; later changes to
 *  clipRect do not affect the returned function.
 * @param percentileTexture single pixel texture containing statistics about the density
 *  distribution. The midpoint of its r and g channels (times scaleThreshold) is the
 *  threshold: densities at or above it map to the hottest color.
 * @param fadePercentileTexture single pixel texture in the same format, used only to
 *  compute the fade level. Pass percentileTexture itself to fade based on the
 *  scaleThreshold-adjusted percentile value instead.
 * @param args construction options (color scale)
 *
 * (Frame statistics are summarized by framereduce prior to
 * this shader and passed in the percentile textures).
 */
export function densityMapCmd(
  regl: Regl,
  density: Texture2D,
  clipRect: Rectangle,
  percentileTexture: Texture2D,
  fadePercentileTexture: Texture2D,
  args: DensityMapArgs = {}
): (props?: MapDensityProps) => void {
  const { colorScale } = replaceUndefined(args, defaults);
  const colorRamp = colorRampTexture(regl, colorScale);

  const cmd = regl({
    vert: `
      precision highp float;
      uniform sampler2D percentileTexture;
      uniform sampler2D fadePercentileTexture;
      uniform vec2 fadeRange;
      uniform bool fadeEnabled;
      uniform bool fadeTextureEnabled;
      uniform float scaleThreshold;

      attribute vec2 position;

      varying vec2 uv;
      varying float threshold;
      varying float fade;
          
      float fetchPercentile(sampler2D bucketStats) {
        vec4 minMax = texture2D(bucketStats, vec2(.5, .5));
        float min = minMax.r;
        float max = minMax.g;
        float percentileValue = (min + max) / 2.0;
        return percentileValue;
      }

      void main() {
        gl_Position = vec4(position, 1.0, 1.0); // position range (-1, -1)
        uv = 0.5 * (position + 1.0);            // uv range (0, 1)
        float percentileValue = fetchPercentile(percentileTexture) * scaleThreshold;

        // fade out the density color overlay as the graph gets sparser
        if (fadeEnabled) {
          float fadePercentile = percentileValue;
          if (fadeTextureEnabled) {
            fadePercentile = fetchPercentile(fadePercentileTexture);
          }
          float offsetInRange = fadePercentile - fadeRange[1];
          float scaled = offsetInRange / (fadeRange[0] - fadeRange[1]);
          fade = clamp(scaled, 0.0, 1.0);
        } else {
          fade = 1.0;
        }

        threshold = percentileValue;
      }`,
    frag: `
      precision highp float;
      uniform sampler2D srcTexture;
      uniform sampler2D colorRamp;
      uniform float copyAbove;  // only map&copy densities that are greater than this value

      uniform float clipLeft;  // discard pixels outside this area
      uniform float clipRight; // discard pixels outside this area

      uniform float squircleN; // powers for the squircle power scaling of intensity 
      uniform float squircleMinv;

      uniform vec4 background; // color when density = zero (only relevant if colorAbove < 0)

      varying vec2 uv;
      varying float threshold;  // densities at or above threshold are mapped to the hottest color
      varying float fade;

      float minFloat32 = 1E-37; // (small enough. e.g. 9E-38 seems to equal zero.)

      // map an intensity to a color using the colorRamp texture
      vec4 mapColor(float intensity) {
        if (intensity > copyAbove) {
          vec3 color = texture2D(colorRamp, vec2(intensity, 0.0)).rgb;
          return vec4(color, fade);
        } else {
          discard;
        }
      }

      void main() {
        if (gl_FragCoord.x < clipLeft || gl_FragCoord.x > clipRight) {
          discard;
        }
        float rawDensity = texture2D(srcTexture, uv).r;
        float normalDensity = min(rawDensity / threshold, 1.0);  
        float inner = 1.0 - pow(1.0 - normalDensity, squircleN);
        float scaled = pow(inner, squircleMinv);
        if (rawDensity > 0.0 && (squircleMinv > 10000.0 || scaled == 0.0)) {
          scaled = minFloat32;
        } 
        gl_FragColor = mapColor(scaled);
      }`,
    attributes: {
      position: fullScreenTriangles,
    },
    uniforms: {
      srcTexture: density,
      copyAbove: (_ctx, props: MapDensityCmdProps) => props.copyAbove,
      colorRamp,
      percentileTexture,
      fadePercentileTexture,
      // clip bounds are captured once, when the command is created
      clipLeft: clipRect.left,
      clipRight: clipRect.right,
      squircleN: (_ctx, props: MapDensityCmdProps) => props.squircleN,
      squircleMinv: (_ctx, props: MapDensityCmdProps) => props.squircleMinv,
      fadeRange: (_ctx, props: MapDensityCmdProps) => props.fadeRange,
      fadeEnabled: (_ctx, props: MapDensityCmdProps) => props.fadeEnabled,
      background: (_ctx, props: MapDensityCmdProps) => props.backgroundGlColor,
      scaleThreshold: (_ctx, props: MapDensityCmdProps) => props.scaleThreshold,
      // fixed at creation: the fade texture is sampled only when it differs from percentileTexture
      fadeTextureEnabled: fadePercentileTexture !== percentileTexture,
    },
    depth: { enable: false, mask: false },
    primitive: "triangles",
    count: 6,
    framebuffer: null,
    blend: blendSrcAlpha(),
  });
  ReglPerf.registerCmd(cmd, "mapDensity");

  /** draw the density map using the per-frame options in props (all optional) */
  function drawDensity(props?: MapDensityProps): void {
    const squircleProps = intensityCurve(props?.brightness);
    const cssColor = props?.background ?? ([255, 255, 255, 255] as Vec4);
    const backgroundGlColor = cssColor.map((b) => b / 255) as Vec4;

    // Note: no destructuring default for fadeRange. A default would make it always
    // defined, turning fading on unconditionally (and [0, 0] makes the vertex
    // shader divide by fadeRange[0] - fadeRange[1] === 0).
    const { fadeRange, copyAbove = 0, scaleThreshold = 1.0 } = props ?? {};
    const unusedFadeRange: Vec2 = [0, 0]; // uniform still needs a value when fading is off

    const cmdProps: MapDensityCmdProps = {
      ...squircleProps,
      scaleThreshold,
      fadeRange: fadeRange ?? unusedFadeRange,
      copyAbove,
      fadeEnabled: fadeRange !== undefined,
      backgroundGlColor,
    };
    cmd(cmdProps);
  }

  return drawDensity;
}

/** exponents for the fragment shader's squircle intensity curve:
 * scaled = pow(1 - pow(1 - d, squircleN), squircleMinv), for normalized density d */
interface SquircleParams {
  /** exponent n applied to (1 - d) */
  squircleN: number;
  /** reciprocal of the squircle exponent m. The shader treats values above 10000
   * as "map every nonzero density to the coolest color". */
  squircleMinv: number;
}

/** Convert a single 'brightness' value (between 0 and 20) into the exponents used by
 * the density shader's squircle curve. See SquircleScale.ts for details.
 *
 * @param brightness desired brightness. undefined, 0, or NaN selects a curve that
 *  forces every nonzero density to the coolest color.
 * @return squircle exponents: n, and the reciprocal of m
 */
function intensityCurve(brightness?: number): SquircleParams {
  if (brightness) {
    const squircle = brightnessSquircle(brightness);
    // avoid dividing by zero when the squircle exponent m is 0
    const squircleMinv = squircle.m === 0 ? 1000 : 1 / squircle.m;
    return { squircleN: squircle.n, squircleMinv };
  }

  // squircleMinv > 10000 is a sentinel the shader uses to force the coolest color
  return { squircleN: 1, squircleMinv: 20000 };
}

/** props passed to the regl command built in densityMapCmd on each draw;
 * each field is read by a uniform function in that command */
interface MapDensityCmdProps {
  /** squircle exponent n (see intensityCurve) */
  squircleN: number;
  /** reciprocal of squircle exponent m (see intensityCurve) */
  squircleMinv: number;
  /** [start, end] percentile densities for the fade; the shader reads it only when fadeEnabled */
  fadeRange: Vec2;
  /** brightness-scaled densities at or below this value are discarded */
  copyAbove: number;
  /** when false the shader skips the fade computation and uses alpha 1.0 */
  fadeEnabled: boolean;
  /** multiplier applied to the percentile texture value to form the density threshold */
  scaleThreshold: number;
  backgroundGlColor: Vec4; // gl color for background (0-1 components). NOTE(review): the shader's background uniform is declared but unused
}

/** Build a one-texel-tall RGB lookup texture from a color scale, for mapping
 * scaled densities to colors in the density shader.
 *
 * @param regl regl context that creates the texture
 * @param colorScale a ColorScale, or the name of a built-in d3 scale
 * @return linearly filtered RGB texture, with one texel per color entry produced by colorMap
 */
function colorRampTexture(
  regl: Regl,
  colorScale: ColorScale | NamedColorScale
): Texture2D {
  let scale: ColorScale;
  switch (colorScale) {
    case "viridis":
      scale = d3.interpolateViridis;
      break;
    case "plasma":
      scale = d3.interpolatePlasma;
      break;
    default:
      scale = colorScale;
  }

  const rgbComponents = colorMap(scale);
  const texelCount = rgbComponents.length / 3; // r, g, b per texel
  return regl.texture({
    data: rgbComponents,
    format: "rgb",
    width: texelCount,
    height: 1,
    min: "linear",
    mag: "linear",
  });
}
