import * as React from "react";
import { ScaleMinimum, XUnit } from "Types/CartesianSpace";
import { CSSProperties, FunctionComponent, useState, useEffect, useRef } from "react";
import GraphAxis from "Components/Graphs/GraphAxis";
import LinearRegressionTrendLine from "Components/Graphs/LinearRegressionTrendLine";
import Pixels from "Types/Pixels";
import ScatterPoints from "Components/Graphs/Scatter/ScatterPoints";
import TrendingDownOutlinedIcon from "@mui/icons-material/TrendingDownOutlined";
import { CircularProgress, ToggleButton } from "@mui/material";
import { DateCoordinate } from "Types/DateCoordinate";
import { PhotoSizeSelectSmall, PanTool, ZoomOutMap, DeleteForever } from "@mui/icons-material";
import { interpolateRgb, D3BrushEvent, ZoomTransform, D3ZoomEvent } from "d3";
import useResizeObserver from "use-resize-observer";
import WarningCriticalAreas from "../Line/WarningCriticalAreas";
import * as d3 from "d3";
import IconButton from "../../Buttons/IconButton";
import FlexibleAndFixedPanel from "../../../Panels/FlexibleAndFixedPanel";
import DateTickFormat from "./DateTickFormat";
import { isNotNull, isNotNumberArr } from "../D3SupportFunctions";
import ErrorPoints from "./ErrorPoints";
import NoDataMessage from "../../../Panels/Composite/Primary/NoDataMessage";

/**
 * Fraction of the plotted time span added as x-axis padding, split evenly between
 * the space before the earliest reading and the space after the latest one.
 */
const paddingToPointsRatioOnXAxis = 5 / 100;

/** Props for ScatterPlotGraph. */
interface ScatterPlotProps {
    /** Content rendered first in the panel's top fixed area, before the recalculating message and legend. */
    aboveSegment?: React.ReactNode;
    /** The series to draw; each series is coloured according to its index in this array. */
    plots: {
        /** Series name, shown (trimmed) in the legend when `showLegend` is set. */
        label: string;
        /** Regular readings; drawn as points and, outside the trend tool, joined by faded lines. */
        points: DateCoordinate[];
        /** Passed to ScatterPoints as `outliers`; also included in the axis extents. */
        outliers?: DateCoordinate[];
        /** Reading passed to ScatterPoints and ErrorPoints as the highlighted point, if any. */
        highlightedPoint?: DateCoordinate | null;
        /** Error readings, drawn by ErrorPoints; they widen the x extent but not the y extent. */
        errorDates?: DateCoordinate[];
        /** Called when this series' legend entry is clicked. */
        onEditClick?(): void;
    }[];
    /** NOTE(review): not currently read by ScatterPlotGraph. */
    startXAxisScaleFrom: ScaleMinimum;
    /** NOTE(review): not currently read by ScatterPlotGraph. */
    startYAxisScaleFrom: ScaleMinimum;
    /** Text shown at the right-hand end of the toolbar. */
    xAxisLabel: string;
    /** Label passed to the y GraphAxis. */
    yAxisLabel: string;
    /** Forwarded from ScatterPoints/ErrorPoints hover reports (the hovered reading, or null). */
    onHoveredPointChange?(point: DateCoordinate | null): void;
    /** Forwarded from ScatterPoints/ErrorPoints double-click reports. */
    onPointDoubleClicked?(point: DateCoordinate): void;
    /** Forwarded from ScatterPoints/ErrorPoints right-click reports. */
    onPointRightClicked?(point: DateCoordinate): void;
    /**
     * Called with the readings inside a box drawn with the delete tool: regular points (not outliers),
     * plus error readings when the box reaches down to the x axis.
     */
    onPointsDeleteRequested?(points: DateCoordinate[]): void;
    /** Shows the delete tool in the toolbar. */
    showDeleteButton?: boolean;
    /** Point radius passed to ScatterPoints. */
    pointsRadius: Pixels;
    /** NOTE(review): only `container` (applied to the SVG) is currently read; the other entries are unused. */
    styles?: {
        container?: CSSProperties;
        points?: CSSProperties;
        outliers?: CSSProperties;
        highlight?: CSSProperties;
        trendLine?: CSSProperties;
    };
    /** Passed to WarningCriticalAreas as `warningPosition`. */
    warningValue?: number;
    /** Passed to WarningCriticalAreas as `criticalPosition`. */
    criticalValue?: number;
    /** Passed to each trend line (presumably for its gradient annotation — confirm in LinearRegressionTrendLine). */
    yUnits?: string;
    /** Whenever this value changes, the tool resets to pan and zoom, box selection and trend lines are cleared. */
    resetIfThisChanges?: number | string;
    /** Shows a coloured, clickable legend of plot labels. */
    showLegend?: boolean;
    /** Shows a spinner with "Recalculating..." in the top area. */
    showRecalculatingMessage?: boolean;
    /** Shows the local-trend tool in the toolbar (defaults to true). */
    showLocalTrendButton?: boolean;
}

/** Series colours, cycled through by plot index. */
const lineColours = ["red", "blue", "darkOrange", "green", "purple", "deepPink", "black", "darkCyan", "goldenRod",
    "indigo", "olive", "saddleBrown", "slateGrey", "navy", "cornflowerBlue"];

/**
 * Colour for the plot at `index` (the palette wraps around). With `faded` the colour is mixed 70% of the
 * way towards white; the graph uses that for connecting lines and as the "light" point colour.
 */
const getLineColour = (index: number, faded = false): string => {
    const colour = lineColours[index % lineColours.length];
    return faded ? interpolateRgb(colour, "white")(0.7) : colour;
};

/**
 * Interactive, zoomable scatter plot of one or more date-indexed series.
 *
 * Each plot is drawn in its own colour: its points (joined by faded lines except while the trend tool is
 * active), outliers and error markers, over optional warning/critical areas. The toolbar offers:
 *  - "pan": d3 zoom/pan behaviour attached to the SVG;
 *  - "drawBox": drag a rectangle to zoom the axes to it;
 *  - "trend": drag a rectangle to draw a linear-regression line for each plot's points inside it;
 *  - "delete": drag a rectangle and pass the readings inside it to `onPointsDeleteRequested`.
 * The container's wheel handler also zooms around the cursor.
 *
 * Renders nothing when `plots` is empty, and a "No data to show" message when the plots hold no readings.
 */
const ScatterPlotGraph: FunctionComponent<ScatterPlotProps> = props => {
    const { plots, pointsRadius,
        styles, xAxisLabel, yAxisLabel, warningValue, criticalValue, yUnits, resetIfThisChanges, onPointDoubleClicked,
        onPointRightClicked, onPointsDeleteRequested, showDeleteButton, showRecalculatingMessage, aboveSegment,
        showLocalTrendButton = true } = props;

    // Every hook must run before the early returns further down; otherwise React sees a different number
    // of hooks when `plots` switches between empty and non-empty (Rules of Hooks).

    const { ref, ...sizeU } = useResizeObserver<HTMLDivElement>({ "box": "content-box" });
    // Fall back to 1px until the observer has measured the container.
    const size = { width: sizeU.width || 1, height: sizeU.height || 1 };

    // Pan/wheel zoom applied on top of the base (or box-zoomed) scales; null means no extra zoom.
    const [transform, setTransform] = useState<ZoomTransform | null>(null);
    const [selectedTool, setSelectedTool] = useState<"drawBox" | "trend" | "pan" | "delete">("pan");
    // Point groups picked with the trend tool; each one gets its own regression line.
    const [trendPointLines, setTrendPointLines] = useState<DateCoordinate[][]>([]);
    // Data-space rectangle chosen with the box-zoom tool; overrides the data-derived axis domains.
    const [selectedRect, setSelectedRect] = useState<{ x1: XUnit; x2: XUnit; y1: number; y2: number } | null>(null);
    // Generated once per mounted graph: stable across renders, yet distinct between graphs on the same page.
    const [clipPathId] = useState(() => `graphClip${Math.round(Math.random() * 100000)}`);

    const svgRef = useRef<SVGSVGElement>(null);
    // Group the d3 brushes attach to; it is a sibling of (not inside) the margin-translated plot group.
    const gRef = useRef<SVGGElement>(null);

    // Reset the state when e.g. the user changes sensor
    useEffect(() => {
        setTrendPointLines([]);
        setSelectedTool("pan");
        setSelectedRect(null);
        setTransform(null);
    }, [resetIfThisChanges]);

    const margin = { top: 20, right: 20, bottom: 25, left: 60 };
    const width = size.width - (margin.left + margin.right);
    const height = size.height - (margin.top + margin.bottom);

    const onHoveredPointChange = (point: DateCoordinate | null) => {
        const { onHoveredPointChange: handler } = props;
        if (handler) handler(point);
    };

    // ---- Data extents ----

    const byTime = (a: DateCoordinate, b: DateCoordinate) => a.x.getTime() - b.x.getTime();
    const points = plots.reduce<DateCoordinate[]>((all, plot) => all.concat(plot.points), []).sort(byTime);
    const outliers = plots.reduce<DateCoordinate[]>((all, plot) => all.concat(plot.outliers || []), []).sort(byTime);
    const errorDates = plots.reduce<DateCoordinate[]>((all, plot) => all.concat(plot.errorDates || []), []).sort(byTime);

    const pointsAndOutliers = [...points, ...outliers];
    const pointsOutliersAndErrors = [...pointsAndOutliers, ...errorDates];

    // The SVG (and therefore svgRef/gRef) only exists when there is something to draw.
    const isGraphRendered = plots.length !== 0 && pointsOutliersAndErrors.length !== 0;

    const xMinimum = d3.min(pointsOutliersAndErrors, p => p.x) || new Date();
    const xMaximum = d3.max(pointsOutliersAndErrors, p => p.x) || new Date();
    // errors aren't included in the yMinimum/yMaximum because their values are 0
    const yMinimum = d3.min(pointsAndOutliers, p => p.y) as number;
    const yMaximum = d3.max(pointsAndOutliers, p => p.y) as number;

    // Pad the time extent by paddingToPointsRatioOnXAxis of its span, split evenly before and after.
    const xSpan = xMaximum.getTime() - xMinimum.getTime();
    const xHalfPadding = (xSpan * paddingToPointsRatioOnXAxis) / 2;
    const xMinimumWithPadding = new Date(xMinimum.getTime() - xHalfPadding);
    const xMaximumWithPadding = new Date(xMaximum.getTime() + xHalfPadding);

    // ---- Scales ----

    // `??` (not `||`) so that a legitimate 0 bound from the box-zoom rectangle is kept.
    const xScalePre = d3
        .scaleTime()
        .domain([selectedRect?.x1 ?? xMinimumWithPadding, selectedRect?.x2 ?? xMaximumWithPadding])
        .nice()
        .range([0, width]);

    const xScale = transform ? transform.rescaleX(xScalePre) : xScalePre;

    const yPadding = 4; // 2mm above and 2mm below

    const yScalePre = d3
        .scaleLinear()
        .domain([selectedRect?.y2 ?? yMinimum - (yPadding / 2), selectedRect?.y1 ?? yMaximum + (yPadding / 2)])
        .nice()
        .range([height, 0]);

    const yScale = transform ? transform.rescaleY(yScalePre) : yScalePre;

    // ---- Interaction behaviours ----

    const zoom = d3.zoom<SVGSVGElement, unknown>().on("zoom", (ev: D3ZoomEvent<SVGSVGElement, unknown>) => {
        setTransform(ev.transform);
    });

    /**
     * Converts a brush rectangle, given in SVG pixels (the brush group is not margin-translated),
     * into data-space bounds of the current view.
     */
    const toDataBounds = (left: number, top: number, right: number, bottom: number) => ({
        xFrom: xScale.invert(left - margin.left),
        xTo: xScale.invert(right - margin.left),
        yTop: yScale.invert(top - margin.top),
        yBottom: yScale.invert(bottom - margin.top)
    });

    /** True when the reading lies inside (or on the edge of) the given data-space bounds. */
    const isWithinBounds = (p: DateCoordinate, bounds: ReturnType<typeof toDataBounds>) =>
        p.x >= bounds.xFrom && p.x <= bounds.xTo && p.y >= bounds.yBottom && p.y <= bounds.yTop;

    /**
     * Hides the brush rectangle using the brush that owns it. Going through another brush's `move` would
     * emit the resulting events to that other brush's listeners.
     */
    const clearBrush = (owner: d3.BrushBehavior<unknown>) => {
        if (gRef.current) {
            d3.select(gRef.current).call(owner.move, null);
        } else {
            console.error("gRef.current is not defined");
        }
    };

    // Box zoom: the drawn rectangle becomes the new axis domains.
    const brush = d3.brush()
        .on("end", (ev: D3BrushEvent<SVGSVGElement>) => {
            // Ignore the event raised by clearBrush() itself (programmatic moves carry no source event).
            if (!ev.sourceEvent) return;
            const selection = ev.selection;
            if (isNotNull(selection) && isNotNumberArr(selection)) {
                const bounds = toDataBounds(selection[0][0], selection[0][1], selection[1][0], selection[1][1]);
                setTransform(null);
                setSelectedRect({ x1: bounds.xFrom, x2: bounds.xTo, y1: bounds.yTop, y2: bounds.yBottom });
                clearBrush(brush);
            } else {
                console.error("selection not correct type", selection);
            }
        });

    // Local trend: each plot's points inside the rectangle (two or more) form a new trend line.
    const localTrendBrush = d3.brush()
        .on("end", (ev: D3BrushEvent<unknown>) => {
            if (!ev.sourceEvent) return;
            const selection = ev.selection;
            if (isNotNull(selection) && isNotNumberArr(selection)) {
                const bounds = toDataBounds(selection[0][0], selection[0][1], selection[1][0], selection[1][1]);

                const newLines = plots
                    .map(plot => plot.points.filter(p => isWithinBounds(p, bounds)))
                    .filter(line => line.length >= 2);

                // Timestamps of every point (from any plot) inside the rectangle.
                const selectedTimes = new Set(
                    points.filter(p => isWithinBounds(p, bounds)).map(p => p.x.getTime())
                );

                // Drop existing lines sharing two or more points with the new selection, so that
                // re-selecting a region replaces its previous trend lines instead of stacking them.
                const keptLines = trendPointLines.filter(line =>
                    line.filter(point => selectedTimes.has(point.x.getTime())).length <= 1
                );

                setTrendPointLines([...keptLines, ...newLines]);
                clearBrush(localTrendBrush);
            } else {
                console.error("selection not correct type", selection);
            }
        });

    // Deletion: request deletion of the readings inside the rectangle.
    const deletionBrush = d3.brush()
        .on("end", (ev: D3BrushEvent<unknown>) => {
            if (!ev.sourceEvent) return;
            const selection = ev.selection;
            if (isNotNull(selection) && isNotNumberArr(selection)) {
                const bounds = toDataBounds(selection[0][0], selection[0][1], selection[1][0], selection[1][1]);
                // Error readings are only included when the rectangle reaches the x axis
                // (presumably where ErrorPoints draws them — confirm in ErrorPoints).
                const selectionBelowXAxis = (selection[1][1] - margin.top) >= yScale.range()[0];

                const selectedPoints = points.filter(p => isWithinBounds(p, bounds));
                const selectedErrors = selectionBelowXAxis
                    ? errorDates.filter(p => p.x >= bounds.xFrom && p.x <= bounds.xTo)
                    : [];

                if (onPointsDeleteRequested) {
                    onPointsDeleteRequested([...selectedPoints, ...selectedErrors]);
                }

                clearBrush(deletionBrush);
            }
        });

    // Attach d3 zoom/pan to the SVG while the pan tool is selected.
    useEffect(() => {
        const svg = svgRef.current;
        // No SVG while nothing is rendered (no plots / no data): nothing to attach to.
        if (!svg || selectedTool !== "pan") return undefined;

        d3.select(svg).call(zoom);
        // Use the captured element: the ref may be null or point elsewhere by the time cleanup runs.
        return () => {
            d3.select(svg).on(".zoom", null);
        };
    }, [selectedTool, isGraphRendered]);

    // Attach the brush that matches the selected tool. xScale/yScale are rebuilt on every render, so this
    // re-runs each render and the brush handlers always see current scales and state.
    useEffect(() => {
        const g = gRef.current;
        if (!g || !svgRef.current) return undefined;

        const toolBrush = selectedTool === "drawBox" ? brush
            : selectedTool === "trend" ? localTrendBrush
                : selectedTool === "delete" ? deletionBrush
                    : null;
        if (toolBrush === null) return undefined;

        const brushGroup = d3.select(g);
        brushGroup.call(toolBrush);
        return () => {
            brushGroup.on(".brush", null);
            // Scoped to this graph: a document-wide ".overlay" lookup could remove another graph's overlay.
            brushGroup.selectAll(".overlay").remove();
        };
    }, [selectedTool, selectedRect, trendPointLines, isGraphRendered, xScale, yScale]);

    // ---- Early exits (after all hooks) ----

    if (plots.length === 0) {
        return null;
    }

    if (pointsOutliersAndErrors.length === 0) {
        return <NoDataMessage>No data to show</NoDataMessage>;
    }

    // ---- Rendering ----

    const graphTranslation = `translate(${margin.left}, ${margin.top})`;

    const getLineEarliestTime = (coords: DateCoordinate[]) => Math.min(...coords.map(c => c.x.getTime()));

    // Sort a copy: sorting the state array in place would mutate React state.
    const trendPointLinesSorted = [...trendPointLines].sort((a, b) => getLineEarliestTime(a) - getLineEarliestTime(b));

    const xAxisGenerator = d3.axisBottom<Date>(xScale)
        .tickFormat(DateTickFormat)
        .ticks(5);

    const yAxisGenerator = d3.axisLeft(yScale);

    return (
        <FlexibleAndFixedPanel
            topFixedContent={<>
                {aboveSegment}
                {showRecalculatingMessage ? <><CircularProgress size={18} /> Recalculating...&nbsp;&nbsp;</> : null}
                {props.showLegend ? plots.map((plot, index) =>
                    <span
                        key={`text-${index}`}
                        style={{ color: getLineColour(index), display: "inline-block", fontSize: 16, cursor: "pointer" }}
                        onClick={() => {
                            if (plot.onEditClick) plot.onEditClick();
                        }}
                    >
                        {plot.label.trim()} &nbsp;
                    </span>
                ) : null}
            </>}
            flexibleContent={
                <div
                    style={{ position: "relative", height: "100%", width: "100%", backgroundColor: "white" }}
                    ref={ref}
                    onWheel={ev => {
                        const current = transform ?? d3.zoomIdentity;
                        const scaleFactor = ev.deltaY < 0 ? 1.1 : 1 / 1.1;

                        // Measure against the container itself: ev.target may be a point or line under the cursor.
                        const rect = ev.currentTarget.getBoundingClientRect();
                        const cursor: [number, number] = [
                            ev.clientX - rect.left - margin.left,
                            ev.clientY - rect.top - margin.top
                        ];

                        // Zoom about the cursor's un-zoomed position so the data under the cursor stays put
                        // even when a zoom is already applied.
                        const [fixedX, fixedY] = current.invert(cursor);
                        setTransform(current.translate(fixedX, fixedY).scale(scaleFactor).translate(-fixedX, -fixedY));
                    }}
                >
                    <svg width="100%" height="100%" style={{ ...styles?.container, display: "block", position: "absolute", left: 0, top: 0 }} ref={svgRef}>
                        <g width={width} height={height} transform={graphTranslation}>
                            <WarningCriticalAreas
                                xScale={xScale}
                                yScale={yScale}
                                warningPosition={warningValue}
                                criticalPosition={criticalValue}
                            />

                            <GraphAxis axis="x" axisGenerator={xAxisGenerator} containerSize={size} graphMargin={margin} />
                            <GraphAxis axis="y" axisGenerator={yAxisGenerator} containerSize={size} graphMargin={margin} label={yAxisLabel} />

                            <clipPath id={clipPathId}>
                                <rect x={0} y={0} width={Math.max(width, 0)} height={Math.max(height, 0)} />
                            </clipPath>

                            {/* Trend lines are drawn before the points so the points stay on top of them. */}
                            {selectedTool === "trend" ? trendPointLinesSorted.map((line, index) =>
                                <LinearRegressionTrendLine
                                    key={`trend-${index}`}
                                    points={line}
                                    xScale={xScale}
                                    yScale={yScale}
                                    style={{
                                        pointerEvents: "none",
                                        stroke: getLineColour(index),
                                        strokeWidth: 2
                                    }}
                                    extrapolate={false}
                                    gradientAnnotation
                                    yUnits={yUnits}
                                    textColour={getLineColour(index)}
                                />
                            ) : null}

                            <g clipPath={`url(#${clipPathId})`}>
                                {plots.map((plot, index) =>
                                    <g key={index}>
                                        {selectedTool !== "trend" ?
                                            plot.points.slice(1).map((next, i) => {
                                                const previous = plot.points[i];
                                                return (
                                                    <line
                                                        key={`line-${index}-${i}`}
                                                        x1={xScale(previous.x)}
                                                        x2={xScale(next.x)}
                                                        y1={yScale(previous.y)}
                                                        y2={yScale(next.y)}
                                                        stroke={getLineColour(index, true)}
                                                    />
                                                );
                                            })
                                            : null
                                        }
                                        <ScatterPoints
                                            key={`points-${index}`}
                                            points={plot.points}
                                            outliers={plot.outliers}
                                            highlightedPoint={plot.highlightedPoint}
                                            xScale={xScale}
                                            yScale={yScale}
                                            onHoveredPointChange={point => onHoveredPointChange(point)}
                                            onPointDoubleClicked={point => onPointDoubleClicked ? onPointDoubleClicked(point) : null}
                                            onPointRightClicked={point => onPointRightClicked ? onPointRightClicked(point) : null}
                                            radius={pointsRadius}
                                            colour={getLineColour(index, false)}
                                            colourLight={getLineColour(index, true)}
                                        />
                                    </g>
                                )}
                            </g>

                            {plots.map((plot, index) =>
                                <ErrorPoints
                                    key={`errorPoints-${index}`}
                                    errorPoints={plot.errorDates}
                                    highlightedPoint={plot.highlightedPoint}
                                    xScale={xScale}
                                    yScale={yScale}
                                    colour={getLineColour(index, false)}
                                    colourLight={getLineColour(index, true)}
                                    onHoveredPointChange={point => onHoveredPointChange(point)}
                                    onPointDoubleClicked={point => onPointDoubleClicked ? onPointDoubleClicked(point) : null}
                                    onPointRightClicked={point => onPointRightClicked ? onPointRightClicked(point) : null}
                                />
                            )}
                        </g>
                        <g ref={gRef} className="brush"></g>
                    </svg>
                </div>
            }
            bottomFixedContent={
                <div style={{ padding: 5 }}>
                    <IconButton
                        icon={<ZoomOutMap />}
                        title="Reset Zoom"
                        onClick={() => {
                            if (svgRef.current) {
                                d3.select(svgRef.current).call(zoom.transform, d3.zoomIdentity);
                                setSelectedRect(null);
                            } // TODO error
                        }}
                    />
                    <ToggleButton
                        key="drawbox"
                        value="x"
                        selected={selectedTool === "drawBox"}
                        onChange={() => setSelectedTool("drawBox")}
                        title="Draw box to zoom"
                    >
                        <PhotoSizeSelectSmall />
                    </ToggleButton>
                    <ToggleButton
                        key="pan-tool"
                        value="x"
                        selected={selectedTool === "pan"}
                        onChange={() => setSelectedTool("pan")}
                        title="Pan"
                    >
                        <PanTool />
                    </ToggleButton>
                    {
                        showLocalTrendButton ?
                            <ToggleButton
                                key="toggle"
                                value="x"
                                selected={selectedTool === "trend"}
                                onChange={() => {
                                    setSelectedTool("trend");
                                    setTrendPointLines([]);
                                }}
                                title="Draw a box to view local trend"
                            >
                                <TrendingDownOutlinedIcon />
                            </ToggleButton>
                            : null
                    }
                    {
                        showDeleteButton ?
                            <ToggleButton
                                key="deleteForever"
                                value="x"
                                selected={selectedTool === "delete"}
                                title="Delete readings in view forever"
                                onChange={() => {
                                    setSelectedTool("delete");
                                }}
                            >
                                <DeleteForever />
                            </ToggleButton>
                            : null
                    }
                    <div style={{ float: "right" }}>
                        {xAxisLabel}
                    </div>
                </div>
            }
        />
    );
};

export default ScatterPlotGraph;
