import _ from "lodash";
import { FrameReduceSnippets } from "./FrameReduce";

/**
 * GLSL statements that flatten the six vec4 accumulators (counts0..counts5)
 * into the 24-entry `buckets` array, one assignment per line, in order
 * buckets[0] = counts0[0] ... buckets[23] = counts5[3].
 */
const bucketsFromCounts: string = (() => {
  const groupCount = 6;
  const lanesPerGroup = 4;
  return _.range(groupCount * lanesPerGroup)
    .map((flatIndex) => {
      const group = Math.floor(flatIndex / lanesPerGroup);
      const lane = flatIndex % lanesPerGroup;
      return `buckets[${flatIndex}] = counts${group}[${lane}];`;
    })
    .join("\n");
})();

/**
 * GLSL source for six functions, incrementGroup0 .. incrementGroup5.
 * incrementGroupG(index) adds 1.0 to lane `index` (0..3) of the vec4 countsG.
 *
 * The lane is chosen with an if/else-if chain rather than `countsG[index]`.
 * NOTE(review): presumably because GLSL ES 1.0 fragment shaders only guarantee
 * constant-index-expressions for vector indexing; confirm.
 */
const incrementGroupN: string = (() => {
  const groupCount = 6;
  const lanesPerGroup = 4;

  // One `if` / `} else if` branch per lane of counts<group>.
  const branchesFor = (group: number): string =>
    _.range(lanesPerGroup)
      .map((lane) => {
        const keyword = lane === 0 ? "      if" : "      } else if";
        return `${keyword} (index == ${lane}) {\n        counts${group}[${lane}] += 1.0;\n`;
      })
      .join("");

  return _.range(groupCount)
    .map(
      (group) =>
        `\n    void incrementGroup${group}(int index) {\n${branchesFor(group)}      } \n    }\n  `
    )
    .join("\n");
})();

/**
 * Calculate a percentile from the nonzero values in the selected color channel:
 * snippets for the bucket-counting pass.
 *
 * Each nonzero source value is counted into one of 24 equal-width buckets
 * spanning [min, max). min and max are read from the `bucketStats` texture
 * (r = min, g = max). Values outside that half-open range are not counted.
 * The 24 counts are packed into six vec4s (counts0..counts5) and written to
 * gl_FragData[0..5].
 *
 * @param channel color channel of the source texel to histogram
 */
export function percentileStartBuckets(
  channel: "r" | "g" | "b" | "a"
): FrameReduceSnippets {
  return {
    // #extension must precede all non-preprocessor tokens, so it goes in the
    // pragma slot (as in the other percentile snippets) rather than in declare.
    pragma: `
    #extension GL_EXT_draw_buffers : require
  `,
    vertexDeclare: `
    uniform sampler2D bucketStats; 

    varying float min;
    varying float max;
  `,
    vertexMain: `
    vec4 stats = texture2D(bucketStats, vec2(.5, .5));
    min = stats.r;
    max = stats.g;
  `,
    declare: `
    varying float min;
    varying float max;
    float range;
    vec4 counts0;
    vec4 counts1;
    vec4 counts2;
    vec4 counts3;
    vec4 counts4;
    vec4 counts5;

    ${incrementGroupN}

    void incrementBucketCount(float value) {
      const float numBuckets = 24.0;
      const float bucketsPerGroup = 4.0;

      // Dividing by a zero range is unspecified in GLSL ES. In that case only
      // values equal to min are counted, into bucket 0.
      float bucket = -1.0;
      if (range > 0.0) {
        bucket = floor(numBuckets * (value - min) / range);
      } else if (value == min) {
        bucket = 0.0;
      }

      // Skip values outside [min, max). Without the lower bound, buckets
      // -1..-3 would truncate to group 0 in int() and be miscounted.
      if (bucket < 0.0 || bucket >= numBuckets) {
        return;
      }

      int bucketGroup = int(bucket / bucketsPerGroup);
      int indexInBucket = int(mod(bucket, bucketsPerGroup));

      if (bucketGroup < 3) {
        if (bucketGroup == 0) {
          incrementGroup0(indexInBucket);
        } else if (bucketGroup == 1) {
          incrementGroup1(indexInBucket);
        } else if (bucketGroup == 2) {
          incrementGroup2(indexInBucket);
        }
      } else {
        if (bucketGroup == 3) {
          incrementGroup3(indexInBucket);
        } else if (bucketGroup == 4) {
          incrementGroup4(indexInBucket);
        } else if (bucketGroup == 5) {
          incrementGroup5(indexInBucket);
        }
      }
    }
  `,
    initialize: `
    range = max - min;
    counts0 = vec4(0.0);
    counts1 = vec4(0.0);
    counts2 = vec4(0.0);
    counts3 = vec4(0.0);
    counts4 = vec4(0.0);
    counts5 = vec4(0.0);
  `,
    accumulate: `
    float value = srcTexel.${channel};
    if (value > 0.0) {
      incrementBucketCount(value);
    }
  `,
    assignColor: `
    gl_FragData[0] = counts0;
    gl_FragData[1] = counts1;
    gl_FragData[2] = counts2;
    gl_FragData[3] = counts3;
    gl_FragData[4] = counts4;
    gl_FragData[5] = counts5;
  `,
  };
}

/**
 * Snippets for the intermediate reduction passes: sum six vec4 bucket-count
 * inputs (srcTexel for group 0, src1..src5 for groups 1..5) and write the sums
 * to gl_FragData[0..5].
 * NOTE(review): srcTexel and spot are presumably supplied by FrameReduce; confirm.
 */
export const percentileAccumulate: FrameReduceSnippets = {
  pragma: `
    #extension GL_EXT_draw_buffers : require
  `,
  declare: `
    vec4 counts0;
    vec4 counts1;
    vec4 counts2;
    vec4 counts3;
    vec4 counts4;
    vec4 counts5;
    uniform sampler2D src1;
    uniform sampler2D src2;
    uniform sampler2D src3;
    uniform sampler2D src4;
    uniform sampler2D src5;
  `,
  // Every accumulator must be zeroed here: GLSL ES leaves uninitialized
  // globals undefined, and counts5 was previously never reset.
  initialize: `
    counts0 = vec4(0.0);
    counts1 = vec4(0.0);
    counts2 = vec4(0.0);
    counts3 = vec4(0.0);
    counts4 = vec4(0.0);
    counts5 = vec4(0.0);
  `,
  accumulate: `
    counts0 += srcTexel;
    counts1 += texture2D(src1, spot);
    counts2 += texture2D(src2, spot);
    counts3 += texture2D(src3, spot);
    counts4 += texture2D(src4, spot);
    counts5 += texture2D(src5, spot);
  `,
  assignColor: `
    gl_FragData[0] = counts0;
    gl_FragData[1] = counts1;
    gl_FragData[2] = counts2;
    gl_FragData[3] = counts3;
    gl_FragData[4] = counts4;
    gl_FragData[5] = counts5;
  `,
};

/**
 * Snippets for the final reduction pass of one percentile iteration.
 *
 * Sums the six vec4 bucket-count inputs like percentileAccumulate. It then
 * scans the 24 summed buckets for the one containing the percentileNth value.
 * The `bucketStats` texture supplies r = min, g = max, b = percentileCount
 * (count already below min) and a = percentileNth.
 *
 * Writes vec4(nextMin, nextMax, foundCount, percentileNth) to gl_FragData[6].
 * The summed counts also go to gl_FragData[0..5], for debugging.
 */
export const percentileLast: FrameReduceSnippets = {
  pragma: `
    #extension GL_EXT_draw_buffers : require
  `,
  declare: `
    uniform sampler2D bucketStats; 

    vec4 counts0;
    vec4 counts1;
    vec4 counts2;
    vec4 counts3;
    vec4 counts4;
    vec4 counts5;
    uniform sampler2D src1;
    uniform sampler2D src2;
    uniform sampler2D src3;
    uniform sampler2D src4;
    uniform sampler2D src5;

    vec4 stats;
    float min;
    float max;
    float percentileCount;
    float percentileNth;

    const int numBuckets = 24;
    const int bucketsPerGroup = 4;
    const int bucketGroups = 6;

    // return the min,max,percentileCount for the next percentile pass
    vec3 nextBuckets() {
      int foundBucket = -1;  
      float buckets[numBuckets];
      ${bucketsFromCounts}
      
      float foundCount = percentileCount;     // total count from buckets prior to current
      float proposedCount = percentileCount;  // count including current bucket

      // check which bucket contains the target nth element
      for (int i = 0; i < numBuckets; i++) {
        if (foundBucket < 0) {
          proposedCount += buckets[i];
          if (proposedCount >= percentileNth) {
            foundBucket = i;
          } else {
            foundCount = proposedCount;
          }
        }
      }

      // If the nth element was never reached, fall back to the last bucket
      // instead of using foundBucket == -1, which would give a range entirely
      // below min. foundCount then holds the total of all buckets, so remove
      // the last bucket to keep it as "count prior to the found bucket".
      if (foundBucket < 0) {
        foundBucket = numBuckets - 1;
        foundCount -= buckets[numBuckets - 1];
      }

      float range = max - min;
      float nextMin = min + range * float(foundBucket) / float(numBuckets);
      float nextMax = min + range * float(foundBucket + 1) / float(numBuckets);
      return vec3(nextMin, nextMax, foundCount);
    }
  `,
  // Every accumulator must be zeroed here: GLSL ES leaves uninitialized
  // globals undefined, and counts5 was previously never reset.
  initialize: `
    stats = texture2D(bucketStats, vec2(.5, .5));
    min = stats.r;
    max = stats.g;
    percentileCount = stats.b;
    percentileNth = stats.a;

    counts0 = vec4(0.0);
    counts1 = vec4(0.0);
    counts2 = vec4(0.0);
    counts3 = vec4(0.0);
    counts4 = vec4(0.0);
    counts5 = vec4(0.0);
  `,
  accumulate: `
    counts0 += srcTexel;
    counts1 += texture2D(src1, spot);
    counts2 += texture2D(src2, spot);
    counts3 += texture2D(src3, spot);
    counts4 += texture2D(src4, spot);
    counts5 += texture2D(src5, spot);
  `,
  assignColor: `
    gl_FragData[6] = vec4(nextBuckets(), percentileNth);

    gl_FragData[0] = counts0; // these aren't used by the next pass, but are handy for debugging
    gl_FragData[1] = counts1;
    gl_FragData[2] = counts2;
    gl_FragData[3] = counts3;
    gl_FragData[4] = counts4;
    gl_FragData[5] = counts5;
  `,
};
