import { useCallback, useRef } from 'react';
import {
  XYPosition,
  NodeChange,
  Node as ReactFlowNode,
  applyNodeChanges,
  NodePositionChange,
  EdgeChange,
  applyEdgeChanges,
} from '@xyflow/react';
import { CanvasEdgeConnectionType, CanvasElementType } from '../canvas-types';
import { snapPositionToGrid } from '../helpers/gridPositioningUtils';
import { SweepCanvasRfNode } from '../canvas-types/nodeTypesData';
import { useOnRegularNodePositionChange } from './useOnRegularNodePositionChange';
import { useOnGroupNodePositionChange } from './useOnGroupNodePositionChange';
import { createObjectOperations } from './objectOperations';
import { useSweepCanvasState } from '../internal-context/CanvasStateContext';
import { RFEdgeFloatingEdge } from '../edges';
import { useSweepCanvasPropsCtx } from '../internal-context/SweepCanvasPropsCtx';

/**
 * Bag of derived changes that is passed to the per-node-type position
 * handlers while one batch of node changes is being processed.
 */
export type NodeTransformations = {
  /**
   * Additional position changes that `useOnNodeChange` applies together with
   * the original changes of the batch.
   */
  extraNodePositionChanges: NodePositionChange[];
};

/**
 * Map from node id to the position that was recorded for that node the first
 * time it appeared in a position change (see `maybeRecordInitialPosition`).
 *
 * NOTE(review): the capitalised `OR` in the name looks like a typo; the name is
 * exported, so it is kept as-is here.
 */
export interface MovingNodesORiginalPositionRef {
  [nodeId: string]: { originalPosition: XYPosition };
}

/**
 * Builds the `onNodesChange` / `onEdgesChange` handlers for the canvas.
 *
 * `onNodesChange`:
 *  - applies batches without any position change straight to the node state;
 *  - for batches containing position changes, snaps finished (non-dragging)
 *    moves to the grid, delegates to the regular/group position handlers,
 *    forwards whatever changes they return to `onSweepNodesChange`, applies the
 *    original plus extra position changes, and finally flags edges attached to
 *    nodes that are currently being dragged.
 *
 * `onEdgesChange` applies edge changes to the edge state unchanged.
 *
 * @returns the two memoised change handlers.
 */
export const useOnNodeChange = () => {
  const {
    setCanvasNodes: setNodes,
    setCanvasEdges: setEdges,
    getCanvasNodes: getNodes,
  } = useSweepCanvasState();

  // Shared with the regular/group position handlers below, which receive the same ref.
  const movingNodesRef = useRef<MovingNodesORiginalPositionRef>({});

  // Records the position of a node the first time it is seen moving; later
  // calls for the same id are ignored so the first recorded position is kept.
  const maybeRecordInitialPosition = useCallback((node?: ReactFlowNode) => {
    if (node && !movingNodesRef.current[node.id]) {
      movingNodesRef.current[node.id] = { originalPosition: node.position };
    }
  }, []);

  const { onRegularNodePositionChange } = useOnRegularNodePositionChange(movingNodesRef);
  const { onGroupNodePositionChange } = useOnGroupNodePositionChange(movingNodesRef);
  const { onSweepNodesChange } = useSweepCanvasPropsCtx();

  const onNodesChange = useCallback(
    (nodeChanges: NodeChange<SweepCanvasRfNode>[]) => {
      // Explicit type guard: narrows the union without relying on the
      // compiler inferring a predicate from the arrow function body.
      const positionTransformations = nodeChanges.filter(
        (nodeChange): nodeChange is NodePositionChange => nodeChange.type === 'position',
      );

      // Fast path: nothing moved, so the changes can be applied as-is.
      if (!positionTransformations.length) {
        setNodes((nodes) => applyNodeChanges(nodeChanges, nodes));
        return;
      }

      // Passed to the position handlers; presumably they append follow-up
      // moves here (nothing in this function adds to it) — verify in the handlers.
      const extraNodePositionChanges: NodePositionChange[] = [];
      const nodeTransformations: NodeTransformations = {
        extraNodePositionChanges,
      };

      const nodes = getNodes();
      const nodeOperations = createObjectOperations(nodes);
      const { getObject: getNode } = nodeOperations;

      positionTransformations.forEach((nodeChange) => {
        const node = getNode(nodeChange.id);

        maybeRecordInitialPosition(node);
        if (!node || !nodeChange.position) {
          return;
        }

        if (!nodeChange.dragging) {
          // The change object is mutated in place on purpose: the same object
          // is part of `nodeChanges`, which is applied below, so the snapped
          // position is what ends up on the node.
          nodeChange.position = snapPositionToGrid(nodeChange.position);
        }

        switch (node.type) {
          case CanvasElementType.REGULAR: {
            const changes = onRegularNodePositionChange({
              nodeChange,
              node,
              nodeTransformations,
              nodeOperations,
            });
            if (changes?.length) {
              onSweepNodesChange?.(changes);
            }
            break;
          }
          case CanvasElementType.GROUP_LABEL:
          case CanvasElementType.GROUP: {
            const changes = onGroupNodePositionChange({
              nodeChange,
              node,
              nodeTransformations,
              nodeOperations,
            });
            if (changes?.length) {
              onSweepNodesChange?.(changes);
            }
            break;
          }
          default:
            break;
        }
      });

      // Nodes are re-read from `nodeOperations` rather than reusing `nodes`;
      // presumably the handlers may have updated objects through it — TODO confirm.
      // NOTE(review): this sets state from the `getNodes()` snapshot instead of
      // a functional update — confirm `getNodes()` always returns the latest nodes.
      const appliedNodes = applyNodeChanges(
        [...nodeChanges, ...extraNodePositionChanges],
        nodeOperations.getObjects(),
      );
      setNodes(appliedNodes);

      // Built once, outside the updater, so every edge gets a constant-time
      // membership check instead of two linear scans of the change list.
      const draggingNodeIds = new Set(
        [...positionTransformations, ...extraNodePositionChanges]
          .filter((nodeChange) => nodeChange.dragging)
          .map((nodeChange) => nodeChange.id),
      );

      setEdges((edges) =>
        edges.map((edge) => {
          // Edges without a data payload are left untouched.
          if (!edge.data) {
            return edge;
          }

          const connectedNodesDragging =
            draggingNodeIds.has(edge.source) || draggingNodeIds.has(edge.target);

          if (connectedNodesDragging) {
            // While either endpoint is dragged the edge is flagged and drawn as a Bezier.
            return {
              ...edge,
              data: {
                ...edge.data,
                connectedNodesDragging,
                connectionType: CanvasEdgeConnectionType.Bezier,
              },
            };
          }

          if (edge.data.connectedNodesDragging) {
            // Dragging stopped: clear the flag.
            // NOTE(review): `connectionType` is not restored here — confirm it
            // is recomputed elsewhere.
            return {
              ...edge,
              data: {
                ...edge.data,
                connectedNodesDragging: false,
              },
            };
          }

          return edge;
        }),
      );
    },
    [
      getNodes,
      setNodes,
      setEdges,
      maybeRecordInitialPosition,
      onRegularNodePositionChange,
      onSweepNodesChange,
      onGroupNodePositionChange,
    ],
  );

  const onEdgesChange = useCallback(
    (edgeChanges: EdgeChange<RFEdgeFloatingEdge>[]) => {
      setEdges((edges) => applyEdgeChanges(edgeChanges, edges));
    },
    [setEdges],
  );

  return {
    onNodesChange,
    onEdgesChange,
  };
};
