import dynamic from 'next/dynamic';
import { ScatterData, Layout, ColorScale, PlotMouseEvent, PlotData } from 'plotly.js';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useRecoilValue } from 'recoil';
import { Container } from './styles';
import {
  INCHES_TO_METERS_FACTOR,
  METERS_TO_INCHES_FACTOR,
  METERS_TO_MILLIMETERS_FACTOR,
  MILLIMETERS_TO_METERS_FACTOR,
} from '@/utils/unitSystem';
import { UnitSystemEnum } from '@/__generated__/graphql';
import { unitSystem as unitSystemState } from '@/components/Analysis/state';
import { UnlockCrossSectionToggle } from './UnlockCrossSectionToggle';
import { CrossSectionPlaneSelector, CrossSectionPlaneEnum } from './CrossSectionPlaneSelector';

// react-plotly.js is loaded lazily and rendered on the client only (`ssr: false`);
// see the issue below for the server-side rendering problem this works around.
// https://github.com/plotly/react-plotly.js/issues/273
const Plot = dynamic(() => import('react-plotly.js'), { ssr: false });

// A point in 3D space. GraphBase converts x/y/z with the same meters-to-display factor as
// `height`, so coordinates are expected in meters as well.
export type Point = { x: number; y: number; z: number };

// Height must be in meters
export type PointWithHeight = Point & { height: number };

type Props = {
  // Measured point cloud; markers are coloured by `height`.
  pointsArray: PointWithHeight[];
  // Fitted-surface points; drawn as light-gray markers.
  fittedSurfacePointsArray: Point[];
  // Called when a point-cloud point is clicked while the cross-section is unlocked.
  // Only `height` is converted back to meters; x/y/z are the plotted values (each axis shifted
  // to start at 0 and converted to display units), not the original input coordinates.
  onSelectPoint?: (point: PointWithHeight) => void; // Height returned is in meters
};

// Cross-section point selection strategy used by findCrossSectionPoints:
// true = Method A (every point in a thin slab around the plane),
// false = Method B (at most one representative point per voxel along the section).
const USE_METHOD_A = false;
// Method A: half-thickness of the slab around the section plane, in the same units as the raw
// input points (meters).
const METHOD_A_DISTANCE_LIMIT = 0.001;

// Method B: number of voxels the section is divided into (upper bound on selected points).
const METHOD_B_NUM_POINTS = 300;
// Method B: a point is only kept if its XY distance to its voxel's centre line on the plane is
// below this limit (raw units, meters).
const METHOD_B_MAX_DISTANCE_LIMIT = 0.001;

// Relative length of the Z axis in the 3D scene; used for both the aspect ratio and the Z range.
const Z_ASPECT = 0.5;
// Plotly config shared by every plot: resize with the container and hide the Plotly logo.
const CONFIG = { responsive: true, displaylogo: false };

// Viridis-style colour stops (dark purple -> yellow). remapColorScale stretches these over the
// fixed height band so values inside it always get the same colour.
// NOTE(review): there is no 0.9 stop (0.8 jumps straight to 1) — confirm this is intentional.
const COLOR_SCALE_VIRIDIS: ColorScale = [
  [0, 'rgb(68,1,84)'],
  [0.1, 'rgb(71,39,117)'],
  [0.2, 'rgb(62,72,135)'],
  [0.3, 'rgb(49,102,141)'],
  [0.4, 'rgb(38,130,141)'],
  [0.5, 'rgb(36,157,136)'],
  [0.6, 'rgb(55,181,120)'],
  [0.7, 'rgb(109,204,88)'],
  [0.8, 'rgb(176,221,49)'],
  [1, 'rgb(253,231,37)'],
];

/**
 * Builds a Plotly colour scale with three zones, so that colours inside a fixed value band always
 * mean the same value regardless of the extremes of the data.
 *
 *      | variable zone |           fixed zone            | variable zone |
 *      |  black -> red |  `colorScale` stretched over    |  red -> black |
 *      |               |  [fixedRangeMin, fixedRangeMax] |               |
 *      -----------------------------------------------------------------
 *  fullScaleValueMin  fixedRangeMin            fixedRangeMax   fullScaleValueMax
 *
 * A variable zone is emitted only on a side where the data extends beyond the fixed band.
 * The returned `fullScaleValueMin` / `fullScaleValueMax` are meant to be used as the trace's
 * `cmin` / `cmax`, so that the stop positions (ratios in [0, 1]) line up with real values.
 *
 * @param colorScale - stops for the fixed zone; positions are expected to span 0..1
 * @param intensityMin - smallest value in the data
 * @param intensityMax - largest value in the data
 * @param fixedRangeMin - lower edge of the fixed band (same units as the data)
 * @param fixedRangeMax - upper edge of the fixed band; must be greater than `fixedRangeMin`
 * @returns the combined scale (first stop at 0, last stop at 1) and the value range it spans
 */
const remapColorScale = (
  colorScale: Array<[number, string]>,
  intensityMin: number,
  intensityMax: number,
  fixedRangeMin: number,
  fixedRangeMax: number
) => {
  // The full scale always covers the fixed band, even when the data lies entirely inside it.
  const fullScaleValueMin = Math.min(intensityMin, fixedRangeMin);
  const fullScaleValueMax = Math.max(intensityMax, fixedRangeMax);

  // Convert values into the 0..1 ratios Plotly uses for stop positions.
  const colourFullRange = fullScaleValueMax - fullScaleValueMin;
  const fixedRangeColorScale = (fixedRangeMax - fixedRangeMin) / colourFullRange;
  const measurementMinRatio = (fixedRangeMin - fullScaleValueMin) / colourFullRange;
  const measurementMaxRatio = (fixedRangeMax - fullScaleValueMin) / colourFullRange;

  // Colour at the outer ends of the variable zones, and where they meet the fixed zone.
  const extremeColor = 'rgb(0,0,0)';
  const edgeColor = 'rgb(255,0,0)';
  // Small gap between a variable zone and the fixed zone: gives a sharp colour change while
  // keeping stop positions strictly increasing.
  const step = 0.0001;

  const variableScaleLow: Array<[number, string]> =
    measurementMinRatio - step > 0
      ? [
          [0, extremeColor],
          [measurementMinRatio - step, edgeColor],
        ]
      : [];
  // Symmetric with the low side: when the data stays inside the band there is no high zone.
  // (Appending [1, black] here would duplicate the fixed zone's own stop at 1.)
  const variableScaleHigh: Array<[number, string]> =
    measurementMaxRatio + step < 1
      ? [
          [measurementMaxRatio + step, edgeColor],
          [1, extremeColor],
        ]
      : [];

  const fixedScale: Array<[number, string]> = colorScale.map(([position, color]) => [
    measurementMinRatio + position * fixedRangeColorScale,
    color,
  ]);

  const colorScaleView: Array<[number, string]> = [
    ...variableScaleLow,
    ...fixedScale,
    ...variableScaleHigh,
  ];

  // Plotly requires the scale to start at 0 and end at 1. Without a variable zone on a side,
  // the fixed zone's outer stop can sit just inside the end (within `step`); snap it there.
  // Every tuple above was created in this call, so mutating them is safe.
  const first = colorScaleView[0];
  const last = colorScaleView[colorScaleView.length - 1];
  if (first && last) {
    first[0] = 0;
    last[0] = 1;
  }

  return {
    colorScaleView,
    fullScaleValueMin,
    fullScaleValueMax,
  };
};

// Structure-of-arrays form of a point list: entry i of x, y and z together describe one point.
// Coordinates are stored as 32-bit floats, so values are rounded to single precision.
type VectorisedPoints = {
  x: Float32Array;
  y: Float32Array;
  z: Float32Array;
};

// VectorisedPoints plus a parallel `height` entry per point (kept as a plain number array).
type VectorisedPointsWithHeight = VectorisedPoints & {
  height: number[];
};

/** Lower and upper bound of a set of numbers. */
type ValueRange = { min: number; max: number };

/** Per-axis bounds of a point set. */
type PointsRange = { x: ValueRange; y: ValueRange; z: ValueRange };

/** Per-axis bounds plus the bounds of the height channel. */
type PointsWithHeightRange = PointsRange & { height: ValueRange };

/**
 * Folds `values` left-to-right with `pick`, starting from `seed`.
 * An empty input yields `seed`; NaN propagates exactly as it does through Math.min/Math.max.
 */
const foldNumbers = (
  values: Iterable<number>,
  seed: number,
  pick: (accumulated: number, value: number) => number
): number => {
  let accumulated = seed;
  for (const value of values) {
    accumulated = pick(accumulated, value);
  }
  return accumulated;
};

/** Smallest element, or +Infinity for an empty array. */
const floatArrayMin = (values: Float32Array): number =>
  foldNumbers(values, Number.POSITIVE_INFINITY, Math.min);

/** Largest element, or -Infinity for an empty array. */
const floatArrayMax = (values: Float32Array): number =>
  foldNumbers(values, Number.NEGATIVE_INFINITY, Math.max);

/** Smallest element, or +Infinity for an empty array. */
const arrayMin = (values: number[]): number =>
  foldNumbers(values, Number.POSITIVE_INFINITY, Math.min);

/** Largest element, or -Infinity for an empty array. */
const arrayMax = (values: number[]): number =>
  foldNumbers(values, Number.NEGATIVE_INFINITY, Math.max);

/** Bounds of a plain number array ({ min: +Infinity, max: -Infinity } when empty). */
const arrayRange = (values: number[]): ValueRange => ({
  min: arrayMin(values),
  max: arrayMax(values),
});

/** Bounds of a Float32Array ({ min: +Infinity, max: -Infinity } when empty). */
const floatArrayRange = (values: Float32Array): ValueRange => ({
  min: floatArrayMin(values),
  max: floatArrayMax(values),
});

/**
 * Converts an array of point objects into parallel per-field arrays:
 * coordinates go into Float32Arrays, heights into a plain number array (same order).
 */
const vectorisePointsWithHeight = (points: PointWithHeight[]): VectorisedPointsWithHeight => {
  const column = (axis: 'x' | 'y' | 'z') => Float32Array.from(points, (point) => point[axis]);
  return {
    x: column('x'),
    y: column('y'),
    z: column('z'),
    height: points.map(({ height }) => height),
  };
};

/** Converts an array of point objects into parallel x/y/z Float32Arrays (same order). */
const vectorisePoints = (points: Point[]): VectorisedPoints => {
  const xs = Float32Array.from(points, ({ x }) => x);
  const ys = Float32Array.from(points, ({ y }) => y);
  const zs = Float32Array.from(points, ({ z }) => z);
  return { x: xs, y: ys, z: zs };
};

/**
 * Per-axis min/max of a vectorised point set.
 * Only x/y/z are read, so any vectorised points are accepted (with or without heights).
 */
const calcPointsRange = (points: VectorisedPoints): PointsRange => ({
  x: floatArrayRange(points.x),
  y: floatArrayRange(points.y),
  z: floatArrayRange(points.z),
});

/** Per-axis min/max of a vectorised point set, plus the min/max of its heights. */
const calcPointsWithHeightRange = (points: VectorisedPointsWithHeight): PointsWithHeightRange => {
  const spatial = calcPointsRange(points);
  const heightRange = arrayRange(points.height);
  return { x: spatial.x, y: spatial.y, z: spatial.z, height: heightRange };
};

/**
 * Returns a display-space copy of `points`. Each coordinate has its axis minimum (from
 * `pointsRange`) subtracted, so every axis starts at 0, and is then multiplied by `scale`.
 * Heights are multiplied by `scale` only. The input arrays are not modified.
 *
 * @param points - vectorised points with heights, in raw units
 * @param pointsRange - per-axis bounds whose minimums become the new origin
 * @param scale - unit conversion factor applied to coordinates and heights
 */
const scaleTranslatePointsWithHeight = (
  points: VectorisedPointsWithHeight,
  pointsRange: PointsRange,
  scale: number
): VectorisedPointsWithHeight => {
  const count = points.x.length;
  // Fresh buffers sized to the input; every element is written in the loop below.
  const x = new Float32Array(count);
  const y = new Float32Array(count);
  const z = new Float32Array(count);
  const height: number[] = [];
  for (let index = 0; index < count; index++) {
    x[index] = (points.x[index] - pointsRange.x.min) * scale;
    y[index] = (points.y[index] - pointsRange.y.min) * scale;
    z[index] = (points.z[index] - pointsRange.z.min) * scale;
    height.push(points.height[index] * scale);
  }
  return { x, y, z, height };
};

/**
 * Returns a display-space copy of `points`. Each coordinate has its axis minimum (from
 * `pointsRange`) subtracted and is then multiplied by `scale`. GraphBase passes the point cloud's
 * range here, so the fitted surface shares the point cloud's origin. Inputs are not modified.
 *
 * @param points - vectorised points in raw units
 * @param pointsRange - per-axis bounds whose minimums become the new origin
 * @param scale - unit conversion factor applied to every coordinate
 */
const scaleTranslateFittedPoints = (
  points: VectorisedPoints,
  pointsRange: PointsRange,
  scale: number
): VectorisedPoints => {
  const count = points.x.length;
  // Fresh buffers sized to the input; every element is written in the loop below.
  const x = new Float32Array(count);
  const y = new Float32Array(count);
  const z = new Float32Array(count);
  for (let index = 0; index < count; index++) {
    x[index] = (points.x[index] - pointsRange.x.min) * scale;
    y[index] = (points.y[index] - pointsRange.y.min) * scale;
    z[index] = (points.z[index] - pointsRange.z.min) * scale;
  }
  return { x, y, z };
};

/**
 * Selects the points forming a vertical cross-section and packs them into a Plotly scatter trace
 * drawn on the secondary axes (`x2` / `y2`).
 *
 * The plane passes through `points0[crossSectionPointIndex]`: for `coord === 'X'` it is the X-Z
 * plane at that point's y, for `'Y'` the Y-Z plane at that point's x.
 *
 * @param coord - which plane to cut ('X' -> X-Z plane, 'Y' -> Y-Z plane)
 * @param crossSectionPointIndex - index into `points0` of the point the plane passes through
 * @param points0 - points used to locate the plane (raw coordinates)
 * @param points - candidate points, in the same raw coordinates as `points0`
 * @param scaledPoints - display-space copy of `points` (same order); supplies the plotted values
 * @param pointsRange - raw bounds used to lay out voxels and reject out-of-range points
 * @param extraProps - extra trace properties, spread last so they override the defaults
 * @returns indexes into `points` of the selected points, and the resulting trace
 */
const findCrossSectionPoints = (
  coord: 'X' | 'Y',
  crossSectionPointIndex: number,
  points0: VectorisedPoints,
  points: VectorisedPoints,
  scaledPoints: VectorisedPoints,
  pointsRange: PointsRange,
  extraProps: Partial<ScatterData>
): {
  crossSectionPointIndexes: number[];
  crossSectionPoints: Partial<ScatterData>;
} => {
  // The plane is pinned on the coordinate perpendicular to it: y for 'X' (X-Z), x for 'Y' (Y-Z).
  const parallelCoordValues0 = coord === 'X' ? points0.y : points0.x;

  const parallelCoordValue = parallelCoordValues0[crossSectionPointIndex];

  // coordValues1: position along the section (the horizontal axis of the section plot).
  // coordValues2: the pinned coordinate, compared against parallelCoordValue.
  const coordValues1 = coord === 'X' ? points.x : points.y;
  const coordValues2 = coord === 'X' ? points.y : points.x;

  const crossSectionPointIndexes: number[] = [];

  if (USE_METHOD_A) {
    // Method A: keep every point within METHOD_A_DISTANCE_LIMIT of the section plane.
    const minParallelCoord = parallelCoordValue - METHOD_A_DISTANCE_LIMIT;
    const maxParallelCoord = parallelCoordValue + METHOD_A_DISTANCE_LIMIT;

    coordValues2.forEach((value, index) => {
      if (value > minParallelCoord && value < maxParallelCoord) {
        crossSectionPointIndexes.push(index);
      }
    });
  } else {
    // Method B: split the section into METHOD_B_NUM_POINTS voxels along coordValues1 and keep,
    // per voxel, the point closest (in XY) to the voxel's centre line on the plane, provided it
    // is within METHOD_B_MAX_DISTANCE_LIMIT.
    const pointIndexes = new Int32Array(METHOD_B_NUM_POINTS);
    pointIndexes.fill(-1); // -1 marks an empty voxel
    const pointDistances = new Float32Array(METHOD_B_NUM_POINTS);
    const coordRange1 = coord === 'X' ? pointsRange.x : pointsRange.y;
    const coordRange2 = coord === 'X' ? pointsRange.y : pointsRange.x;
    const voxelScale = (coordRange1.max - coordRange1.min) / METHOD_B_NUM_POINTS;
    // A zero-width range puts every point into voxel 0 instead of dividing by zero.
    const voxelScaleInverse = voxelScale === 0 ? 0 : 1 / voxelScale;
    coordValues1.forEach((coordValue1, index) => {
      const coordValue2 = coordValues2[index];
      // The coord1 bounds are enforced by the voxel index check below.
      if (coordValue2 < coordRange2.min || coordValue2 > coordRange2.max) return;
      let voxelI = Math.trunc((coordValue1 - coordRange1.min) * voxelScaleInverse);
      // A value on the upper bound maps to index METHOD_B_NUM_POINTS, one past the last voxel.
      // Fold it into the last voxel so the edge of the cloud is not always dropped.
      if (voxelI === METHOD_B_NUM_POINTS && coordValue1 <= coordRange1.max) {
        voxelI = METHOD_B_NUM_POINTS - 1;
      }
      if (voxelI < 0 || voxelI >= METHOD_B_NUM_POINTS) return;
      const voxelValue = coordRange1.min + (voxelI + 0.5) * voxelScale;
      const d0 = voxelValue - coordValue1;
      const d1 = parallelCoordValue - coordValue2;
      const squaredDistance = d0 * d0 + d1 * d1;
      if (squaredDistance < METHOD_B_MAX_DISTANCE_LIMIT * METHOD_B_MAX_DISTANCE_LIMIT) {
        const hasPoint = pointIndexes[voxelI] >= 0;
        if (!hasPoint || squaredDistance < pointDistances[voxelI]) {
          pointIndexes[voxelI] = index;
          pointDistances[voxelI] = squaredDistance;
        }
      }
    });

    // Voxels are ordered along coordValues1, so the selection comes out sorted along the section.
    pointIndexes.forEach((pointIndex) => {
      if (pointIndex >= 0) {
        crossSectionPointIndexes.push(pointIndex);
      }
    });
  }

  const pointsX = new Float32Array(crossSectionPointIndexes.length);
  const pointsY = new Float32Array(crossSectionPointIndexes.length);
  const pointsZ = new Float32Array(crossSectionPointIndexes.length);
  crossSectionPointIndexes.forEach((pointIndex, index) => {
    pointsX[index] = scaledPoints.x[pointIndex];
    pointsY[index] = scaledPoints.y[pointIndex];
    pointsZ[index] = scaledPoints.z[pointIndex];
  });

  const crossSectionPoints: Partial<PlotData> = {
    x: coord === 'X' ? pointsX : pointsY,
    y: pointsZ,
    xaxis: 'x2',
    yaxis: 'y2',
    mode: 'markers',
    ...extraProps,
  };

  return { crossSectionPointIndexes, crossSectionPoints };
};

/**
 * Point-cloud viewer:
 * - a 3D scatter of the measured points (coloured by height) together with the fitted surface;
 * - a 2D cross-section (X-Z or Y-Z plane) through a selected point.
 *
 * Inputs are in meters. Everything is plotted in millimeters or inches, depending on the unit
 * system held in Recoil state, with each axis shifted so it starts at 0.
 *
 * While the cross-section is unlocked, clicking a point in the 3D plot moves the cross-section to
 * it. For point-cloud points it also calls `onSelectPoint` with the plotted x/y/z and the height
 * converted back to meters. Clicks are ignored entirely when `onSelectPoint` is not provided.
 */
const GraphBase = ({ pointsArray, fittedSurfacePointsArray, onSelectPoint }: Props) => {
  const [crossSectionIsUnlocked, setCrossSectionIsUnlocked] = useState(false);
  const unitSystem = useRecoilValue(unitSystemState);
  const unitScaleFromRawToDisplay = useMemo(() => {
    return unitSystem === UnitSystemEnum.Imperial
      ? METERS_TO_INCHES_FACTOR
      : METERS_TO_MILLIMETERS_FACTOR;
  }, [unitSystem]);

  const unitScaleFromDisplayToRaw = useMemo(() => {
    return unitSystem === UnitSystemEnum.Imperial
      ? INCHES_TO_METERS_FACTOR
      : MILLIMETERS_TO_METERS_FACTOR;
  }, [unitSystem]);

  const unitString = useMemo(() => {
    return unitSystem === UnitSystemEnum.Imperial ? '"' : 'mm';
  }, [unitSystem]);

  // Raw points (meters) as typed arrays, and their per-axis bounds.
  const points = useMemo(() => vectorisePointsWithHeight(pointsArray), [pointsArray]);
  const pointsRange = useMemo(() => calcPointsRange(points), [points]);

  // Display copies: each axis shifted to start at 0, then converted to display units.
  const scaledPoints = useMemo(
    () => scaleTranslatePointsWithHeight(points, pointsRange, unitScaleFromRawToDisplay),
    [points, pointsRange, unitScaleFromRawToDisplay]
  );
  const scaledPointsRange = useMemo(() => calcPointsWithHeightRange(scaledPoints), [scaledPoints]);

  const fittedPoints = useMemo(
    () => vectorisePoints(fittedSurfacePointsArray),
    [fittedSurfacePointsArray]
  );
  // The fitted surface is shifted by the point cloud's range so both traces share one origin.
  const scaledFittedPoints = useMemo(() => {
    return scaleTranslateFittedPoints(fittedPoints, pointsRange, unitScaleFromRawToDisplay);
  }, [fittedPoints, pointsRange, unitScaleFromRawToDisplay]);

  // Index (into `points`) of the point the cross-section plane passes through.
  const [crossSectionPointIndex, setCrossSectionPointIndex] = useState<number>();

  // Index of the point nearest to (needX, needY) in the XY plane; undefined when there are none.
  const findClosestPoint = useCallback(
    (points0: VectorisedPoints, needX: number, needY: number): number | undefined => {
      let bestPointIndex: number | undefined;
      let bestDistance = Number.POSITIVE_INFINITY;

      for (let index = 0; index < points0.x.length; index++) {
        const dx = points0.x[index] - needX;
        const dy = points0.y[index] - needY;
        // Squared distance is sufficient for finding the minimum.
        const distance = dx * dx + dy * dy;
        if (distance < bestDistance) {
          bestPointIndex = index;
          bestDistance = distance;
        }
      }
      return bestPointIndex;
    },
    []
  );

  const findCenterPoint = useCallback(() => {
    const centerX = (pointsRange.x.min + pointsRange.x.max) * 0.5;
    const centerY = (pointsRange.y.min + pointsRange.y.max) * 0.5;

    return findClosestPoint(points, centerX, centerY);
  }, [points, pointsRange, findClosestPoint]);

  // Default to the point nearest the XY centre of the cloud. An existing selection is kept across
  // re-renders, but is replaced when it no longer indexes into the current points (e.g. a smaller
  // point cloud was loaded); a stale index would make the cross-section read past the arrays.
  useEffect(() => {
    setCrossSectionPointIndex((current) =>
      current !== undefined && current < points.x.length ? current : findCenterPoint()
    );
  }, [points, findCenterPoint]);

  const { coordMax, heightMin, heightMax } = useMemo(
    () => ({
      coordMax: Math.max(scaledPointsRange.x.max, scaledPointsRange.y.max, scaledPointsRange.z.max),
      heightMin: scaledPointsRange.height.min,
      heightMax: scaledPointsRange.height.max,
    }),
    [scaledPointsRange]
  );

  // A fixed band of ±2 mm (2/1000 m, converted to display units) always maps to the same colours.
  const { colorScaleView, fullScaleValueMin, fullScaleValueMax } = useMemo(() => {
    const fixedRangeMin = -1 * (2 / 1000) * unitScaleFromRawToDisplay;
    const fixedRangeMax = (2 / 1000) * unitScaleFromRawToDisplay;
    return remapColorScale(COLOR_SCALE_VIRIDIS, heightMin, heightMax, fixedRangeMin, fixedRangeMax);
  }, [unitScaleFromRawToDisplay, heightMin, heightMax]);

  // The Plot's onClick handler was observed to see a stale `crossSectionIsUnlocked`,
  // so the latest value is mirrored into a ref that handleClick reads instead.
  const crossSectionIsUnlockedReference = useRef(false);
  useEffect(() => {
    crossSectionIsUnlockedReference.current = crossSectionIsUnlocked;
  }, [crossSectionIsUnlocked]);

  const handleClick = useCallback(
    (event: Readonly<PlotMouseEvent>) => {
      // Trace order in the 3D plot's `data` array: [fitted surface, point cloud].
      const fittedSurfaceTrace = 0;
      const pointCloudTrace = 1;

      if (!crossSectionIsUnlockedReference.current) return;
      if (!onSelectPoint) return;
      if (event.points.length === 0) return;

      const pointData = event.points[0];
      // Only point-cloud markers carry a numeric height in customdata.
      if (
        typeof pointData.x === 'number' &&
        typeof pointData.y === 'number' &&
        typeof pointData.z === 'number' &&
        typeof pointData.customdata === 'number'
      ) {
        onSelectPoint({
          x: pointData.x,
          y: pointData.y,
          z: pointData.z,
          height: pointData.customdata * unitScaleFromDisplayToRaw,
        });
      }

      if (pointData.curveNumber === pointCloudTrace) {
        setCrossSectionPointIndex(pointData.pointNumber);
      } else if (
        pointData.curveNumber === fittedSurfaceTrace &&
        typeof pointData.x === 'number' &&
        typeof pointData.y === 'number'
      ) {
        // A fitted-surface point: move the section to the nearest point-cloud point instead.
        const pointIndex = findClosestPoint(scaledPoints, pointData.x, pointData.y);
        // Compare against undefined: index 0 is a valid point.
        if (pointIndex !== undefined) {
          setCrossSectionPointIndex(pointIndex);
        }
      }
    },
    [onSelectPoint, unitScaleFromDisplayToRaw, scaledPoints, findClosestPoint]
  );

  // Builds the 2D cross-section plot (point cloud + fitted surface) for the X-Z or Y-Z plane
  // through the selected point; undefined until a point has been selected.
  const extractCrossSectionData = useCallback(
    (coord: 'X' | 'Y') => {
      // Compare against undefined: index 0 is a valid selection.
      if (crossSectionPointIndex === undefined) return undefined;

      const { crossSectionPointIndexes, crossSectionPoints } = findCrossSectionPoints(
        coord,
        crossSectionPointIndex,
        points,
        points,
        scaledPoints,
        pointsRange,
        {
          name: 'point cloud'.toLocaleUpperCase(),
          marker: {
            size: 5,
            colorscale: colorScaleView,
            cmin: fullScaleValueMin,
            cmax: fullScaleValueMax,
          },
        }
      );

      // Section points are coloured by their display-unit heights (also attached as customdata).
      const heights = crossSectionPointIndexes.map((pointIndex) => scaledPoints.height[pointIndex]);
      crossSectionPoints.customdata = heights;
      crossSectionPoints.marker = { ...crossSectionPoints.marker, color: heights };

      const { crossSectionPoints: crossSectionPlane } = findCrossSectionPoints(
        coord,
        crossSectionPointIndex,
        points,
        fittedPoints,
        scaledFittedPoints,
        pointsRange,
        {
          name: 'fitted surface'.toLocaleUpperCase(),
          marker: { size: 3, color: 'lightgray' },
        }
      );

      const rangeX = coord === 'X' ? scaledPointsRange.x : scaledPointsRange.y;
      const rangeY = scaledPointsRange.z;

      const plotData = [crossSectionPlane, crossSectionPoints];
      const plotLayout: Partial<Layout> = {
        margin: {
          l: 0,
          r: 0,
          b: 0,
          t: 0,
          pad: 0,
        },
        autosize: true,

        legend: {
          orientation: 'h',
          yanchor: 'bottom',
          xanchor: 'left',
          valign: 'top',
          font: { size: 12 },

          itemsizing: 'constant',
          itemwidth: 9,
        },

        yaxis2: {
          domain: [0.2, 0.9],
          anchor: 'x2',
          range: [rangeY.min * 0.95, rangeY.max * 1.05],
          // title: 'z',
        },
        xaxis2: {
          domain: [0.2, 0.9],
          anchor: 'y2',
          range: [rangeX.min * 0.97, rangeX.max * 1.03],
          // title: coord.toLowerCase(),
        },
      };
      return { data: plotData, layout: plotLayout };
    },
    [
      colorScaleView,
      fullScaleValueMin,
      fullScaleValueMax,
      crossSectionPointIndex,
      points,
      pointsRange,
      scaledPoints,
      scaledPointsRange,
      fittedPoints,
      scaledFittedPoints,
    ]
  );

  // 3D scene: fitted surface first, point cloud second (handleClick relies on this order).
  const {
    layout,
    data,
  }: {
    layout: Partial<Layout>;
    data: Partial<ScatterData>[];
  } = useMemo(() => {
    const heightsData = scaledPoints.height;

    const plotData: Partial<ScatterData> & { projection?: object } = {
      type: 'scatter3d',
      x: scaledPoints.x,
      y: scaledPoints.y,
      z: scaledPoints.z,
      customdata: heightsData,
      hovertemplate: `%{customdata:.3f}${unitString}<extra></extra>`,
      mode: 'markers',
      marker: {
        size: 12,
        color: heightsData, // what values the color's based on
        colorbar: {
          lenmode: 'fraction',
          ticksuffix: unitString,
          len: 0.8,
          thicknessmode: 'fraction',
          thickness: 0.08,
          tickfont: { size: 10, family: 'Roboto', color: 'gray' },
        },
        colorscale: colorScaleView,
        cmin: fullScaleValueMin,
        cmax: fullScaleValueMax,
        opacity: 1,
      },
    };

    const fittedSurfaceData: Partial<ScatterData> = {
      x: scaledFittedPoints.x,
      y: scaledFittedPoints.y,
      z: scaledFittedPoints.z,
      mode: 'markers',
      marker: {
        size: 4,
        color: 'lightgray',
      },
      type: 'scatter3d',
    };

    const plotLayout: Partial<Layout> = {
      margin: {
        l: 0,
        r: 0,
        b: 0,
        t: 0,
      },
      autosize: true,
      scene: {
        aspectmode: 'manual',
        aspectratio: {
          x: 1,
          y: 1,
          z: Z_ASPECT,
        },
        xaxis: {
          nticks: 9,
          range: [0, coordMax],
        },
        yaxis: {
          nticks: 9,
          range: [0, coordMax],
        },
        zaxis: {
          nticks: 10,
          range: [0, coordMax * Z_ASPECT],
        },
      },
      showlegend: false,
    };

    return {
      layout: plotLayout,
      data: [fittedSurfaceData, plotData],
    };
  }, [
    colorScaleView,
    fullScaleValueMin,
    fullScaleValueMax,
    scaledPoints,
    scaledFittedPoints,
    coordMax,
    unitString,
  ]);

  const crossSectionXZ = useMemo(() => {
    return extractCrossSectionData('X');
  }, [extractCrossSectionData]);

  const crossSectionYZ = useMemo(() => {
    return extractCrossSectionData('Y');
  }, [extractCrossSectionData]);

  const [crossSectionPlane, setCrossSectionPlane] = useState<CrossSectionPlaneEnum>(
    CrossSectionPlaneEnum.XZ
  );

  return (
    <Container>
      <UnlockCrossSectionToggle
        crossSectionIsUnlocked={crossSectionIsUnlocked}
        setCrossSectionIsUnlocked={setCrossSectionIsUnlocked}
      />

      <Container>
        <Plot
          data={data}
          layout={layout}
          config={CONFIG}
          style={{ width: '100%', height: '100%', marginTop: '6px' }}
          useResizeHandler
          onClick={handleClick}
        />
      </Container>

      <CrossSectionPlaneSelector
        crossSectionPlane={crossSectionPlane}
        setCrossSectionPlane={setCrossSectionPlane}
      />

      <Container>
        {crossSectionPlane === CrossSectionPlaneEnum.XZ && crossSectionXZ && (
          <Container>
            <Plot
              {...crossSectionXZ}
              config={CONFIG}
              style={{ width: '100%', height: '100%', marginTop: '6px' }}
              useResizeHandler
            />
          </Container>
        )}
        {crossSectionPlane === CrossSectionPlaneEnum.YZ && crossSectionYZ && (
          <Container>
            <Plot
              {...crossSectionYZ}
              config={CONFIG}
              style={{ width: '100%', height: '100%' }}
              useResizeHandler
            />
          </Container>
        )}
      </Container>
    </Container>
  );
};

// Public component: GraphBase wrapped in React.memo, so a parent re-render does not re-render it
// unless its props change (shallow comparison). Passing stable array references for
// `pointsArray` / `fittedSurfacePointsArray` avoids recomputing the derived plot data.
export const Graph = memo(GraphBase);
