import { Margin } from '@nivo/core';
import {
  ResponsiveScatterPlotCanvas,
  ScatterPlotDatum,
  ScatterPlotLayerProps,
  ScatterPlotNodeData,
  ScatterPlotRawSerie,
} from '@nivo/scatterplot';
import { ThemeContext } from 'styled-components';
import { Theme } from '@nivo/core';
import { useContext } from 'react';
import { getDatavizTheme } from '@plotting/single-plot-view/plot-panel/plot.themes';
import {
  DEFAULT_TITLE_SIZE,
  DEFAULT_X_AXIS_STYLE,
  DEFAULT_Y_AXIS_STYLE,
} from '@dataviz/constants';
import { getVolcanoTooltip } from './getVolcanoTooltip';

/** Margins used around the plot when the caller supplies no `margin` prop. */
const DEFAULT_MARGIN = { top: 60, right: 140, bottom: 90, left: 116 };

/** Default raw p-value cutoff; it is converted with -log10 before being compared to y. */
const DEFAULT_PVALUE_THRESHOLD = 0.05;

/** Default fold-change cutoff, applied symmetrically (+/- this value) on the x axis. */
const DEFAULT_FOLD_CHANGE_THRESHOLD = 2;

/** Props accepted by the `VolcanoPlot` component. */
type VolcanoPlotProps = {
  /**
   * Series handed to the scatter plot canvas. Each datum also carries a
   * `geneName`. The axes are labelled 'Log2 Fold Change' (x) and
   * '-log10(pvalue)' (y), and the thresholds below are compared against
   * those x / y values.
   */
  data: ScatterPlotRawSerie<ScatterPlotDatum & { geneName: string }>[];
  /** Chart margins; `DEFAULT_MARGIN` is used when omitted. */
  margin?: Partial<Margin>;
  /** Optional title drawn on the canvas above the plot area. */
  title?: string;
  /** Title font size in px; defaults to `DEFAULT_TITLE_SIZE`. */
  titleSize?: number;
  /**
   * Nivo theme for the chart. When omitted, one is built with
   * `getDatavizTheme` from the styled-components palette.
   */
  datavizTheme?: Theme;
  /**
   * Raw p-value cutoff (not log-transformed); the component applies -log10
   * to it before comparing with y. Defaults to `DEFAULT_PVALUE_THRESHOLD`.
   */
  pValueThreshold?: number;
  /**
   * Fold-change cutoff on the same scale as x, applied symmetrically as
   * +/- this value. Defaults to `DEFAULT_FOLD_CHANGE_THRESHOLD`.
   */
  foldChangeThreshold?: number;
};

/**
 * Canvas-based volcano plot built on `ResponsiveScatterPlotCanvas`.
 *
 * On top of the regular scatter plot it draws:
 * - dashed guide lines at y = -log10(pValueThreshold) and at
 *   x = -foldChangeThreshold / x = +foldChangeThreshold;
 * - an optional title above the plot area;
 * - nodes filled with `palette.accentPrimary` when they lie above the
 *   p-value line AND outside the fold-change band, and grey otherwise.
 *
 * @param data Series to plot (x is labelled 'Log2 Fold Change', y '-log10(pvalue)').
 * @param title Optional title text; nothing is drawn when empty or omitted.
 * @param titleSize Title font size in px.
 * @param datavizTheme Nivo theme; derived from the styled-components palette when omitted.
 * @param pValueThreshold Raw p-value cutoff, converted with -log10 internally.
 * @param foldChangeThreshold Symmetric cutoff on the x scale.
 * @param margin Chart margins. Any side left unspecified falls back to the
 *   matching `DEFAULT_MARGIN` value, so a partial margin never yields an
 *   undefined `top` when positioning the title.
 */
export const VolcanoPlot = ({
  data,
  title,
  titleSize = DEFAULT_TITLE_SIZE,
  datavizTheme,
  pValueThreshold = DEFAULT_PVALUE_THRESHOLD,
  foldChangeThreshold = DEFAULT_FOLD_CHANGE_THRESHOLD,
  margin = DEFAULT_MARGIN,
}: VolcanoPlotProps) => {
  const { palette } = useContext(ThemeContext);
  const finalDatavizTheme = datavizTheme ?? getDatavizTheme({}, palette);

  // `margin` is Partial<Margin>: resolve every side explicitly so the title
  // offset below never computes with `undefined` (which would give NaN and
  // silently drop the title), and so the chart and the title layer agree on
  // the same margins.
  const finalMargin = {
    top: margin.top ?? DEFAULT_MARGIN.top,
    right: margin.right ?? DEFAULT_MARGIN.right,
    bottom: margin.bottom ?? DEFAULT_MARGIN.bottom,
    left: margin.left ?? DEFAULT_MARGIN.left,
  };

  // The p-value cutoff expressed on the plot's y scale; computed once per
  // render instead of once per node.
  const negLogPValueThreshold = -Math.log10(pValueThreshold);

  /** Custom layer: writes the title in the top margin, above the plot area. */
  const drawPlotTitleLayer = (
    ctx: CanvasRenderingContext2D,
    props: ScatterPlotLayerProps<ScatterPlotDatum>
  ) => {
    if (!title) {
      return;
    }

    ctx.save();
    ctx.fillStyle = palette.textPrimary;
    ctx.font = `${titleSize}px ${finalDatavizTheme.fontFamily}`;
    // NOTE(review): with 'left' alignment the text STARTS at the horizontal
    // midpoint rather than being centred on it — confirm whether 'center'
    // was intended.
    ctx.textAlign = 'left';
    // Negative y places the text halfway up the top margin.
    ctx.fillText(title, props.innerWidth / 2, -finalMargin.top / 2);
    ctx.restore();
  };

  /** Custom layer: dashed guide lines for the p-value and fold-change cutoffs. */
  const drawAblineLayer = (
    ctx: CanvasRenderingContext2D,
    props: ScatterPlotLayerProps<ScatterPlotDatum>
  ) => {
    ctx.save();

    ctx.strokeStyle = palette.textPrimary;
    ctx.setLineDash([5, 15]);

    // Horizontal line for the p-value threshold
    const thresholdY = props.yScale(negLogPValueThreshold);
    ctx.beginPath();
    ctx.moveTo(0, thresholdY);
    ctx.lineTo(props.innerWidth, thresholdY);
    ctx.stroke();

    // Vertical lines for the negative, then positive, fold-change threshold
    for (const bound of [-foldChangeThreshold, foldChangeThreshold]) {
      const boundX = props.xScale(bound);
      ctx.beginPath();
      ctx.moveTo(boundX, 0);
      ctx.lineTo(boundX, props.innerHeight);
      ctx.stroke();
    }

    ctx.restore();
  };

  /** Draws one node as a filled circle, highlighted when it passes both cutoffs. */
  const renderNode = (
    ctx: CanvasRenderingContext2D,
    node: ScatterPlotNodeData<ScatterPlotDatum>
  ) => {
    const xValue = Number(node.data.x);
    const yValue = Number(node.data.y);
    const isHighlighted =
      yValue > negLogPValueThreshold &&
      (xValue > foldChangeThreshold || xValue < -foldChangeThreshold);

    ctx.beginPath();
    ctx.arc(node.x, node.y, node.size / 2, 0, 2 * Math.PI);
    ctx.fillStyle = isHighlighted ? palette.accentPrimary : 'grey';
    ctx.fill();
  };

  const tooltip = getVolcanoTooltip(finalDatavizTheme);

  return (
    <ResponsiveScatterPlotCanvas
      data={data}
      renderNode={renderNode}
      margin={finalMargin}
      theme={finalDatavizTheme}
      xScale={{ type: 'linear', min: -4, max: 4 }}
      layers={[drawAblineLayer, 'grid', 'axes', 'nodes', drawPlotTitleLayer]}
      axisBottom={{ ...DEFAULT_X_AXIS_STYLE, legend: 'Log2 Fold Change' }}
      axisLeft={{ ...DEFAULT_Y_AXIS_STYLE, legend: '-log10(pvalue)' }}
      tooltip={tooltip}
      useMesh={false}
    />
  );
};
