import {
  closestCenter,
  DndContext,
  DragEndEvent,
  DragOverlay,
  DragStartEvent,
  MouseSensor,
  TouchSensor,
  useSensor,
  useSensors,
} from '@dnd-kit/core';
import { Fragment, useCallback, useState } from 'react';
import { SortableContext, verticalListSortingStrategy } from '@dnd-kit/sortable';
import { TableContainer } from '@mui/material';
import { DataTableRow, DataTableDraggableProps, DataTableVariant } from './TableTypes';
import { StyledTable, StyledTableBody } from './StyledTableComponents';
import { DraggableTableRow } from './DraggableTableRow';

/**
 * Renders the body rows of a data table with drag-and-drop reordering.
 *
 * Top-level `rows` are registered as the items of a vertical `SortableContext`. Each row's
 * `nestedRows` (if any) are rendered directly after their parent but are not sortable items.
 * While a row is being dragged, a copy of it is rendered inside a `DragOverlay`.
 *
 * Props:
 * - `rows` — top-level rows; also the sortable items.
 * - `onOrderChange` — called after a top-level row is dropped onto a different top-level row,
 *   with the indices (within `rows`) of the dragged row and of the row it was dropped on.
 *   It is not called when the drop lands outside any row, on the same row, or on an id that
 *   is not a top-level row.
 * - `renderRow` — optional custom renderer. When it returns a truthy value, that value is
 *   rendered instead of the default `DraggableTableRow`.
 * - all remaining props (columns, variant, callbacks, ...) are forwarded to every
 *   `DraggableTableRow`, including the one in the drag overlay.
 */
export function DraggableTableBodyRows<TRow extends DataTableRow = any>({
  rows,
  onOrderChange,
  renderRow,
  ...dataTableProps
}: Pick<
  DataTableDraggableProps<TRow>,
  | 'rows'
  | 'columns'
  | 'renderRow'
  | 'sxRowFunction'
  | 'onRowClick'
  | 'actionableButtonsOnHover'
  | 'onOrderChange'
  | 'allowReorder'
> & {
  variant: DataTableVariant;
}) {
  const sensors = useSensors(useSensor(MouseSensor), useSensor(TouchSensor));
  // Keep the id exactly as dnd-kit reports it (string | number). Row lookups then use the same
  // comparison as the drag-end handler, so numeric row ids also work.
  const [draggedItemId, setDraggedItemId] = useState<DragStartEvent['active']['id'] | null>(
    null,
  );
  const draggedRow = rows.find((row) => row.id === draggedItemId);

  const handleDragStart = useCallback((event: DragStartEvent) => {
    setDraggedItemId(event.active.id);
  }, []);

  // A cancelled drag (e.g. Escape) must also clear the dragged row; otherwise the overlay and
  // the `isBeingDragged` flag would stay stuck.
  const handleDragCancel = useCallback(() => {
    setDraggedItemId(null);
  }, []);

  const handleDragEnd = useCallback(
    (event: DragEndEvent) => {
      const { active, over } = event;
      setDraggedItemId(null);

      // `over` is null when the row is dropped outside every droppable: nothing to reorder.
      if (!over || active.id === over.id) {
        return;
      }

      const sourceIndex = rows.findIndex((row) => row.id === active.id);
      const destinationIndex = rows.findIndex((row) => row.id === over.id);
      // Ids that are not top-level rows (e.g. nested rows) have no index in `rows`.
      // Never report -1 to the consumer.
      if (sourceIndex === -1 || destinationIndex === -1) {
        return;
      }

      onOrderChange?.({ sourceIndex, destinationIndex });
    },
    [rows, onOrderChange],
  );

  const renderSingleRow = (row: TRow, rowIdx: number, rowKey: string) => {
    const renderedRow =
      renderRow &&
      renderRow({
        rowIdx,
        row,
        rowKey,
      });

    if (renderedRow) {
      return renderedRow;
    }

    return (
      <DraggableTableRow
        key={row.id}
        row={row}
        isOverlay={false}
        isBeingDragged={draggedItemId === row.id}
        {...dataTableProps}
      />
    );
  };

  return (
    <DndContext
      sensors={sensors}
      collisionDetection={closestCenter}
      onDragStart={handleDragStart}
      onDragEnd={handleDragEnd}
      onDragCancel={handleDragCancel}
    >
      <SortableContext items={rows} strategy={verticalListSortingStrategy}>
        {rows.map((row, rowIdx) => (
          <Fragment key={`wrapper_${row.id}`}>
            {renderSingleRow(row, rowIdx, `base_row_${rowIdx}`)}
            {/* NOTE(review): nested row type is not visible here; kept as `any` like the original. */}
            {row.nestedRows?.map((nestedRow: any, nestedIdx: number) =>
              renderSingleRow(nestedRow, nestedIdx, `nested_row_${nestedRow.id}`),
            )}
          </Fragment>
        ))}
      </SortableContext>
      <DragOverlay
        adjustScale
        style={{
          transformOrigin: '0 0 ',
        }}
      >
        {/* The dragged row is wrapped in a table so it gets the same layout (column widths)
            as the static rows. */}
        {draggedRow && (
          <TableContainer sx={{ overflow: 'unset', height: '100%' }}>
            <StyledTable>
              <StyledTableBody>
                <DraggableTableRow
                  key={draggedRow.id}
                  row={draggedRow}
                  isOverlay={true}
                  isBeingDragged={true}
                  {...dataTableProps}
                />
              </StyledTableBody>
            </StyledTable>
          </TableContainer>
        )}
      </DragOverlay>
    </DndContext>
  );
}

// Default export of the same component that is also available as the named export above.
export default DraggableTableBodyRows;
