import { makeStyles } from "@material-ui/styles";
import * as d3 from "d3";
import React, { useContext, useEffect, useMemo } from "react";
import { SquircleScale } from "../../chart/SquircleScale";
import { LegendMargin } from "./DensityLegend";
import { SizeContext, SizedContainer } from "./SizedContainer";

/**
 * Layout styles for the two elements rendered by `DensityAxis`.
 *
 * NOTE(review): both entries place themselves by named `gridArea`; the grid
 * template that defines "axis" and "max" presumably lives in a parent legend
 * component — not visible here, confirm there.
 */
const useStyles = makeStyles({
  // Container for the axis SVG: sits in the "axis" grid area and clips any
  // content larger than the cell.
  axisContainer: {
    gridArea: "axis",
    overflow: "hidden",
  },
  // The parenthesised max-value label: inline-block so the padding applies,
  // placed in the "max" grid area.
  maxString: {
    display: "inline-block",
    paddingTop: 5,
    paddingLeft: 8,
    gridArea: "max",
  },
});

/** Styles for the SVG and axis group rendered by `AxisContents`. */
const contentsStyles = makeStyles({
  // The axis <svg>: inline-block, aligned to the top of its line box.
  svgAxis: {
    display: "inline-block",
    verticalAlign: "top",
  },
  // Applied to the <g> that d3 draws the axis into, so it sets the size of
  // the tick label text.
  axisGroup: {
    fontSize: 14,
  },
});

// Number formatter shared by the axis tick labels and the max label.
// d3-format specifier: "," groups thousands; ".2~r" rounds to two significant
// digits in decimal notation and trims insignificant trailing zeros.
const formatTick = d3.format(",.2~r");

/** Value range and display settings for the density legend axis. */
export interface DensityThresholds {
  /** Value placed at the bottom of the axis (start of the axis domain). */
  min: number;
  /** Largest value; shown as a formatted label in parentheses, not on the axis. */
  max: number;
  /** Value placed at the top of the axis (end of the axis domain). */
  threshold: number;
  /** Optional parameter forwarded to `SquircleScale` to shape the axis scale. */
  brightness?: number;
}

/** Props for `DensityAxis`: the thresholds combined with the legend margins. */
export type DensityAxisProps = DensityThresholds & LegendMargin;

/**
 * Legend axis for the density view.
 *
 * Renders the formatted `max` value in parentheses, followed by a sized
 * container that hosts the tick axis itself (`AxisContents`), which receives
 * all of this component's props.
 */
export function DensityAxis(props: DensityAxisProps): JSX.Element {
  const { maxString, axisContainer } = useStyles();
  const maxLabel = formatTick(props.max);

  return (
    <>
      <span className={maxString}>({maxLabel})</span>
      <SizedContainer className={axisContainer}>
        <AxisContents {...props} />
      </SizedContainer>
    </>
  );
}

/**
 * Draws the vertical tick axis for the density legend, sized from
 * `SizeContext`.
 *
 * The axis domain is `[min, threshold]`, mapped through
 * `SquircleScale(brightness)`. The axis is inset vertically by
 * `legendMargin.top` / `legendMargin.bottom` so it lines up with the legend
 * strip. d3 renders the axis imperatively into the `<g>` after each change to
 * the inputs it depends on.
 */
function AxisContents(props: DensityAxisProps): JSX.Element {
  const classes = contentsStyles();
  const [svgWidth, svgHeight] = useContext(SizeContext)!;
  const { min, legendMargin, brightness, threshold } = props;
  const axisHeight = svgHeight - (legendMargin.top + legendMargin.bottom);

  const axisScale = useMemo(() => SquircleScale(brightness), [brightness]);
  const axisGroupRef = React.useRef<SVGGElement>(null);

  useEffect(() => {
    const node = axisGroupRef.current;
    if (!node) return;

    // -1 so the tick mark aligns with bottom of strip. Clamp at 0 so a
    // container shorter than the margins does not yield an inverted range.
    const bottom = Math.max(axisHeight - 1, 0);
    axisScale.range([bottom, 0]).domain([min, threshold]);

    // Build the axis here rather than during render: a fresh axis object per
    // render would be a new effect dependency each time and force a redraw on
    // every render.
    const axis = d3.axisRight(axisScale).ticks(5).tickFormat(formatTick);
    axis(d3.select(node));
  }, [axisScale, axisHeight, min, threshold]);

  return (
    <svg className={classes.svgAxis} width={svgWidth} height={svgHeight}>
      <g
        ref={axisGroupRef}
        className={classes.axisGroup}
        transform={translate(0, legendMargin.top)}
      />
    </svg>
  );
}

/**
 * Builds an SVG `transform` attribute value that shifts by (`x`, `y`).
 *
 * @returns A string of the form `translate(x,y)`.
 */
function translate(x: number, y: number): string {
  const offset = [x, y].join(",");
  return "translate(" + offset + ")";
}
