import { makeStyles } from "@material-ui/core";
import * as d3 from "d3";
import _ from "lodash";
import React, { useCallback, useContext, useEffect, useRef } from "react";
import { ColorScale } from "../../chart/ColorMap";
import { LegendMargin } from "./DensityLegend";
import { SizeContext, SizedContainer } from "./SizedContainer";

/**
 * Styles for the outer wrapper around the strip: it is placed in the "strip"
 * grid area, laid out as an inline block, and clips anything overflowing it.
 */
const containerStyles = makeStyles({
  stripContainer: {
    gridArea: "strip",
    display: "inline-block",
    overflow: "hidden",
  },
});

/**
 * Color legend strip.
 *
 * Renders the strip contents inside a `SizedContainer` (which presumably
 * supplies the `SizeContext` that the contents read), passing the legend
 * margins straight through.
 */
export function ColorStripInternal(props: LegendMargin): JSX.Element {
  const { stripContainer } = containerStyles();
  const contents = <StripContents {...props} />;

  return <SizedContainer className={stripContainer}>{contents}</SizedContainer>;
}

/** Memoized `ColorStripInternal`: skips re-rendering when props are shallowly equal. */
export const ColorStrip = React.memo(ColorStripInternal);

/**
 * Styles for the strip canvas: an inline, top-aligned block that is shifted
 * down (relative positioning) by the legend's top margin.
 */
const useStyles = makeStyles({
  colorStrip: {
    display: "inline-block",
    verticalAlign: "top",
    boxSizing: "border-box",
    position: "relative",
    // Function value: resolved from the props handed to useStyles(props).
    top: ({ legendMargin }: LegendMargin) => legendMargin.top,
  },
});

/**
 * Renders the color strip as a canvas sized from the enclosing `SizeContext`.
 *
 * The canvas takes the full available width; its height is the available
 * height minus the legend's top and bottom margins (never below zero).
 * The ramp is drawn when the canvas is attached and redrawn whenever the
 * canvas dimensions or the color scale change.
 */
function StripContents(props: LegendMargin): JSX.Element {
  const classes = useStyles(props);
  const colorScale = d3.interpolateViridis;
  const [width, height] = useContext(SizeContext)!;
  const { legendMargin } = props;
  // Clamp at zero: a negative canvas height is not a valid attribute value, and
  // browsers would fall back to the default canvas height (150px) instead.
  const canvasHeight = Math.max(
    0,
    height - (legendMargin.top + legendMargin.bottom)
  );
  const canvasRef = useRef<HTMLCanvasElement | null>(null);

  // Setting a canvas's width/height clears its bitmap, so the dimensions are
  // dependencies here even though drawStrip reads them from the element.
  useEffect(() => {
    if (canvasRef.current) {
      drawStrip(canvasRef.current, colorScale);
    }
  }, [width, canvasHeight, colorScale]);

  // React invokes the ref callback with the element on attach and null on detach.
  const withCanvasRef = useCallback(
    (canvasNode: HTMLCanvasElement | null) => {
      canvasRef.current = canvasNode;
      if (canvasNode) {
        drawStrip(canvasNode, colorScale);
      }
    },
    [colorScale]
  );

  return (
    <canvas
      ref={withCanvasRef}
      className={classes.colorStrip}
      width={width}
      height={canvasHeight}
    />
  );
}

/**
 * Draw a vertical color ramp for `colorScale` across the whole canvas.
 *
 * The canvas is cleared, then each 1px row `y` is filled with
 * `colorScale(1 - y / height)`: the top row uses the scale's value at 1, and
 * the input decreases toward the bottom row. If the canvas cannot provide a
 * 2D rendering context, nothing is drawn.
 *
 * @param canvasNode - Canvas to draw into; its current `width`/`height` set the strip size.
 * @param colorScale - Maps a position in (0, 1] to a value assigned to `fillStyle`.
 */
function drawStrip(canvasNode: HTMLCanvasElement, colorScale: ColorScale): void {
  // getContext returns null when a 2D context is unavailable (e.g. the canvas
  // already holds a different context type); skip drawing instead of throwing.
  const ctx = canvasNode.getContext("2d");
  if (!ctx) {
    return;
  }
  const { width, height } = canvasNode;

  ctx.clearRect(0, 0, width, height);
  for (const y of _.range(0, height)) {
    // Scale input for this row: 1 at the top, approaching 0 at the bottom.
    const position = 1 - y / height;
    ctx.fillStyle = colorScale(position);
    ctx.fillRect(0, y, width, 1);
  }
}
