import { Margin } from '@nivo/core';
import {
  ResponsiveScatterPlot,
  ScatterPlotCustomSvgLayer,
  ScatterPlotDatum,
  ScatterPlotRawSerie,
} from '@nivo/scatterplot';
import { ThemeContext } from 'styled-components';
import { ScaleLinearSpec, ScaleLogSpec } from '@nivo/scales/dist/types/types';
import { Theme } from '@nivo/core';
import { useContext } from 'react';
import { getDatavizTheme } from '@plotting/single-plot-view/plot-panel/plot.themes';
import {
  DEFAULT_LEGEND_CONFIG,
  DEFAULT_TITLE_SIZE,
  DEFAULT_X_AXIS_STYLE,
  DEFAULT_Y_AXIS_STYLE,
} from '@dataviz/constants';
import { getScatterplotTooltip } from './getScatterplotTooltip';
import { scaleOrdinal, scaleSqrt } from 'd3-scale';
import { extent } from 'd3-array';
import { runLinearRegression } from './runLinearRegression';
import { COLORS } from '@utils/scales/color/ColorSchemes';
import { getScatterplotNode } from './getScatterplotNode';
import { KeyColor } from '@plotting/single-plot-view/plot.types';
import { getScatterplotRegressionLayer } from './getScatterplotRegressionLayer';

/**
 * Space, in pixels, reserved around the inner plotting area when the caller
 * does not supply a `margin` prop.
 */
const DEFAULT_MARGIN = { top: 60, right: 160, bottom: 90, left: 116 };

/**
 * Node size `ScatterPlot` falls back to when `circleSize` does not provide a
 * single fixed number.
 */
export const DEFAULT_CIRCLE_SIZE = 6;

/** Props accepted by the `ScatterPlot` component. */
interface ScatterPlotProps {
  /** One nivo series per group; each point has a numeric `size` and an optional `color`. */
  data: ScatterPlotRawSerie<
    ScatterPlotDatum & { size: number; color?: string }
  >[];
  /** Text drawn centered above the plotting area. */
  title?: string;
  /** Title font size; a falsy value falls back to `DEFAULT_TITLE_SIZE`. */
  titleSize?: number;
  /** Legend text for the bottom (x) axis. */
  xAxisName?: string;
  /** nivo scale spec for the x axis (linear or logarithmic). */
  xAxisScale?: ScaleLinearSpec | ScaleLogSpec;
  /** Legend text for the left (y) axis. */
  yAxisName?: string;
  /** nivo scale spec for the y axis (linear or logarithmic). */
  yAxisScale?: ScaleLinearSpec | ScaleLogSpec;
  /** Pre-built nivo theme; when omitted one is derived from the styled-components palette. */
  datavizTheme?: Theme;
  /** Outer margins around the plotting area; defaults to `DEFAULT_MARGIN`. */
  margin?: Margin;
  /**
   * Node size: a single number fixes every node's size, while a `[min, max]`
   * tuple is the output range of a square-root scale over the points' `size`.
   */
  circleSize?: number | [number, number];
  /** When true, a linear regression line is drawn for each series. */
  isLinearRegressionEnabled?: boolean;
  /** When true, a legend is shown (only if there is more than one series). */
  isLegendEnabled?: boolean;
  /** Per-series color overrides matched on series `id`; unmatched series use the default palette. */
  colorConfig?: KeyColor[];
}

/**
 * Responsive nivo scatter plot with an optional centered title, per-series
 * colors, size-encoded nodes and optional per-series linear regression lines.
 *
 * Node size: a point with a finite numeric `size` (including 0) is mapped
 * through a square-root scale whose output range comes from `circleSize`. A
 * single number pins every node to that size, and a `[min, max]` tuple spreads
 * sizes over that range. Points without a usable `size` fall back to
 * `circleSize` when it is a number, otherwise to `DEFAULT_CIRCLE_SIZE`.
 *
 * Colors: each series takes its color from `colorConfig` (matched on series
 * id) when present, otherwise from an ordinal scale over the 'aseda' scheme.
 */
export const ScatterPlot = ({
  data,
  title,
  titleSize,
  xAxisName,
  xAxisScale,
  yAxisName,
  yAxisScale,
  datavizTheme,
  margin = DEFAULT_MARGIN,
  circleSize,
  isLinearRegressionEnabled,
  isLegendEnabled,
  colorConfig,
}: ScatterPlotProps) => {
  type SizedDatum = ScatterPlotDatum & { size: number };

  // Dataviz theme can be passed as a prop (for the screenshotting feature for instance) OR built here
  const { palette } = useContext(ThemeContext);
  const finalDatavizTheme = datavizTheme ?? getDatavizTheme({}, palette);

  //
  // TITLE
  //
  const plotTitleStyle = {
    fontFamily: finalDatavizTheme.fontFamily,
    fontSize: titleSize || DEFAULT_TITLE_SIZE,
    fill: finalDatavizTheme.textColor,
    textAnchor: 'middle',
  } as const;

  // Custom SVG layer: horizontally centered over the inner plotting area and
  // placed halfway up the top margin. Renders nothing when there is no title.
  const plotTitle: ScatterPlotCustomSvgLayer<SizedDatum> = ({ innerWidth }) => {
    if (!title) {
      return null;
    }
    return (
      <text x={innerWidth / 2} y={-margin.top / 2} style={plotTitleStyle}>
        {title}
      </text>
    );
  };

  const hasMultipleSeries = data.length > 1;

  //
  // SIZES
  //
  // Only finite sizes feed the scale domain; 0 is a legitimate size, not "missing".
  const allSizes = data.flatMap((serie) =>
    serie.data
      .map((point) => point.size)
      .filter((size) => Number.isFinite(size))
  );
  const [minSize, maxSize] = extent(allSizes);
  const sizeRange: [number, number] =
    circleSize === undefined
      ? [DEFAULT_CIRCLE_SIZE, DEFAULT_CIRCLE_SIZE]
      : typeof circleSize === 'number'
      ? [circleSize, circleSize]
      : circleSize;
  // `extent` yields [undefined, undefined] when no point has a usable size. The
  // scale is never queried in that case, so a [0, 0] placeholder domain is safe.
  const sizeScale = scaleSqrt()
    .domain([minSize ?? 0, maxSize ?? 0])
    .range(sizeRange);
  // Size for points that carry no usable `size`; invariant across nodes, so hoisted.
  const fallbackSize =
    typeof circleSize === 'number' ? circleSize : DEFAULT_CIRCLE_SIZE;

  //
  // COLORS
  //
  const groups = [...new Set(data.map((d) => String(d.id)))];
  // NOTE(review): assumes the 'aseda' scheme exists in COLORS; `.scheme` throws otherwise.
  const defaultColors = COLORS.find((col) => col.id === 'aseda');
  const defaultColorScale = scaleOrdinal<string>()
    .domain(groups)
    .range(defaultColors.scheme);
  // `||` (not `??`) keeps the original behavior of an empty configured color
  // falling back to the default palette.
  const getColorFromGroup = (name: string) =>
    colorConfig?.find((config) => config.id === name)?.color ||
    defaultColorScale(name);

  //
  // LINEAR REGRESSION
  //
  // One regression per series; a no-op layer when the feature is disabled.
  const linearRegressionLayer: ScatterPlotCustomSvgLayer<SizedDatum> =
    isLinearRegressionEnabled
      ? getScatterplotRegressionLayer(
          data.map((serie) => ({
            id: serie.id,
            regression: runLinearRegression(serie.data),
          })),
          getColorFromGroup
        )
      : () => null;

  return (
    <ResponsiveScatterPlot<SizedDatum>
      data={data}
      margin={margin}
      theme={finalDatavizTheme}
      layers={[
        plotTitle,
        'grid',
        'axes',
        'nodes',
        'markers',
        linearRegressionLayer,
      ]}
      xScale={xAxisScale}
      yScale={yAxisScale}
      axisBottom={{ ...DEFAULT_X_AXIS_STYLE, legend: xAxisName }}
      axisLeft={{ ...DEFAULT_Y_AXIS_STYLE, legend: yAxisName }}
      enableGridX={true}
      enableGridY={true}
      animate={false}
      tooltip={getScatterplotTooltip(finalDatavizTheme)}
      legends={
        hasMultipleSeries && isLegendEnabled
          ? [DEFAULT_LEGEND_CONFIG]
          : undefined
      }
      nodeComponent={getScatterplotNode}
      nodeSize={(node) =>
        Number.isFinite(node.data.size)
          ? sizeScale(node.data.size)
          : fallbackSize
      }
      colors={(group) => getColorFromGroup(String(group.serieId))}
      useMesh={false}
    />
  );
};
