import { Fragment, KeyboardEvent, useEffect, useRef } from "react";
import { useTranslation } from "react-i18next";
import {
  Box,
  SxProps,
  Table as MuiTable,
  TableBody,
  TableCell,
  TableContainer,
  TableFooter,
  TableHead,
  TablePagination,
  TableProps as MuiTableProps,
  TableRow,
  TableSortLabel,
  Theme,
  Tooltip,
} from "@mui/material";
import { flexRender, Header, Row, Table } from "@tanstack/react-table";
import { isFunction } from "lodash";
import isEmpty from "lodash/isEmpty";
import ReactTableRow from "components/shared/ReactTable/ReactTableRow";
import { ASCENDING, DESCENDING } from "components/shared/Table/constants";
import tableStyles from "components/shared/Table/Table.styles";
import { Keyboard } from "utils/constants/keyboard.constants";
import { ROWS_PER_PAGE_OPTIONS } from "utils/constants/table.constants";
import SelectionCell from "../Table/components/SelectionCell";
import { getAdjacentElementWithTabIndex } from "../Table/Table.utils";
import { TColumn, TRowAlign, TRowSpanColumn } from "./ReactTable.types";
import { getColumnTitle } from "./ReactTable.utils";
import { TReactTableRowBaseProps } from "./ReactTableRow/ReactTableRow";
import { ReactTableTabs, TReactTableTabs } from "./ReactTableTabs";
import accessibilityStyles from "styles/accessibility.styles";

/** Filter setters stripped from the table type that ReactTable accepts. */
type TOmittedTableSetters = "setColumnFilters" | "setGlobalFilter";

/**
 * A TanStack `Table<T>` whose type omits `setColumnFilters` and
 * `setGlobalFilter`.
 */
export type TReactTableInstance<T extends object = {}> = Omit<
  Table<T>,
  TOmittedTableSetters
>;

/** Props that ReactTable adds on top of MUI's own `<Table>` props. */
type TReactTableOwnProps<T extends object> = {
  /** The TanStack table instance whose headers, rows and state are rendered. */
  tableInstance: TReactTableInstance<T>;
  /** Whether to render the pagination controls in the table footer. */
  isPaginated: boolean;
  /** Options for the rows-per-page selector (ReactTable defaults this to ROWS_PER_PAGE_OPTIONS). */
  rowsPerPageOptions?: number[];

  /** When true, header rows are not rendered and headless styles apply (ReactTable default: false). */
  isHeadless?: boolean;
  /**
   * NOTE(review): not consumed by ReactTable's rendering logic; column
   * visibility comes from the table instance. Confirm this prop is still needed.
   */
  hiddenColumns?: string[];

  /** Body-row hover highlighting: a fixed flag or a per-row predicate (ReactTable default: true). */
  bodyRowHover?: boolean | ((row: Row<T>) => boolean);
  /** Invoked with the clicked row; passed to `rowRenderer` as `onClick`. */
  onRowClick?: (row: Row<T>) => void;
  /** Passed to `rowRenderer` as `rowAlign`. */
  bodyRowAlign?: TRowAlign;
  /** Renders a single body row; ReactTable defaults to `ReactTableRow`. */
  rowRenderer?: (props: TReactTableRowBaseProps<T>) => React.ReactNode;
  /** Rendered instead of the table when the current row model has no rows. */
  EmptyStateComponent?: React.ReactNode;

  /** Tabs rendered above the table, when provided. */
  tabs?: TReactTableTabs[] | null;
  /** Selected tab. The tabs are controlled whenever this key is present on props. */
  tabValue?: string;
  /** Called when the selected tab changes. */
  onTabChange?: (value?: string) => void;

  /** Styles merged onto the surrounding TableContainer. */
  sx?: SxProps<Theme>;
};

export type TReactTableProps<T extends object> = MuiTableProps &
  TReactTableOwnProps<T>;

/** Header-cell `rowSpan` used when the column meta's `rowSpan` is missing or falsy. */
const DEFAULT_ROW_AMOUNT = 1;
/** Header-cell `colSpan` used when the column meta's `colSpan` is missing or falsy. */
const DEFAULT_COLUMN_AMOUNT = 1;
/** Header-cell `width` used when the column meta's `width` is missing or falsy. */
const DEFAULT_COLUMN_WIDTH = 150;

/**
 * Resolves the `sx` styles for a header cell from the column meta's
 * `headCellStyles`. A function is called with the header. Any other value
 * (including undefined) is returned unchanged.
 */
const getHeaderStyles = <T extends object>(header: Header<T, unknown>) => {
  const headCellStyles = header.column.columnDef.meta?.headCellStyles;

  if (isFunction(headCellStyles)) {
    return headCellStyles(header);
  }

  return headCellStyles;
};

/**
 * Renders a header's content wrapped in a `<span>` Box. Columns whose meta sets
 * `SROnly` and that cannot be sorted get screen-reader-only styling.
 * Sortable columns always stay visible.
 */
const renderHeader = <T extends object>(header: Header<T, unknown>) => {
  const { column } = header;
  const content = flexRender(column.columnDef.header, header.getContext());
  const isScreenReaderOnly =
    Boolean(column.columnDef.meta?.SROnly) && !column.getCanSort();

  if (!isScreenReaderOnly) {
    return <Box component="span">{content}</Box>;
  }

  return (
    <Box
      component="span"
      data-testid="SROnlyStyles"
      sx={accessibilityStyles.SROnly}
    >
      {content}
    </Box>
  );
};

/**
 * Page index to hand to MUI's `TablePagination`.
 *
 * MUI expects `page` to lie in `0 .. ceil(count / rowsPerPage) - 1` and warns
 * otherwise. A stale TanStack `pageIndex` (e.g. after the data shrank) is
 * clamped to the last available page. A negative `count` is MUI's marker for an
 * unknown total; in that case any page is valid and `pageIndex` is kept.
 */
const getPaginationPage = (
  count: number,
  pageSize: number,
  pageIndex: number,
): number => {
  if (count < 0) {
    return pageIndex;
  }

  if (pageSize <= 0) {
    return 0;
  }

  const lastPageIndex = Math.max(0, Math.ceil(count / pageSize) - 1);

  return Math.min(Math.max(pageIndex, 0), lastPageIndex);
};

/**
 * Renders a TanStack table instance with MUI table components.
 *
 * Layout:
 * - optional tabs;
 * - then either `EmptyStateComponent` (when the current row model has no rows)
 *   or a table made of:
 *   - header row(s) with sorting and arrow-key navigation;
 *   - body rows produced by `rowRenderer`;
 *   - an optional pagination footer.
 *
 * When multi-row selection is enabled, the selection is reset whenever the
 * core rows change.
 */
function ReactTable<T extends object>(props: TReactTableProps<T>) {
  const {
    tableInstance,
    isHeadless = false,
    // Pulled out only so it is not forwarded to MuiTable via `tableProps`;
    // MUI passes unknown props through to the native <table> element.
    hiddenColumns: _hiddenColumns,
    rowsPerPageOptions = ROWS_PER_PAGE_OPTIONS,
    isPaginated,
    bodyRowHover = true,
    sx,
    onRowClick,
    bodyRowAlign,
    tabs,
    EmptyStateComponent,
    onTabChange,
    rowRenderer = (props) => <ReactTableRow<T> {...props} />,
    tabValue,
    ...tableProps
  } = props;

  const { t } = useTranslation("common");

  const {
    getHeaderGroups,
    getAllColumns,
    getRowModel,
    getState,
    setPageIndex: gotoPage,
    setPageSize,
    getToggleAllPageRowsSelectedHandler,
    resetRowSelection,
  } = tableInstance;

  const allRows = tableInstance.getCoreRowModel().rows;
  const isSelectable = tableInstance.options.enableMultiRowSelection;
  const isRowDisabled = tableInstance.options.meta?.isRowDisabled;
  const expandableRowProps = tableInstance.options.meta?.expandableRowProps;
  const rowStyles = tableInstance.options.meta?.rowStyles;

  useEffect(() => {
    // this is used for resetting the row selection state, if the data provided to the react table has been updated
    if (isSelectable) {
      resetRowSelection();
    }
  }, [isSelectable, resetRowSelection, allRows]);

  const headerGroups = getHeaderGroups();
  const columns = getAllColumns();
  const rows = getRowModel().rows;

  const {
    pagination: { pageIndex, pageSize },
  } = getState();

  // NOTE(review): object spread assumes `sx` is a plain style object; an array
  // or theme-callback `sx` would not merge correctly here. Confirm callers.
  const styles = {
    ...tableStyles.root,
    ...(isHeadless && tableStyles.headless),
    ...sx,
  };

  const tableBodyRef = useRef(null);

  /**
   * Keyboard handling for header cells:
   * - Enter toggles sorting on sortable columns.
   * - ArrowRight/ArrowLeft move focus via `getAdjacentElementWithTabIndex`,
   *   starting from the neighbouring cell.
   */
  const handleKeyDownColumnCell = (
    event: KeyboardEvent<HTMLTableCellElement>,
    column: TColumn<T>,
  ) => {
    // The handler is bound to the header cell itself, so currentTarget is the
    // cell within whichever header row it belongs to. (A single shared row ref
    // would only ever point at the last rendered header row.)
    const currentHeaderCell = event.currentTarget;

    switch (event.code) {
      case Keyboard.Enter: {
        if (column.getCanSort()) {
          column.toggleSorting();
        }
        break;
      }
      case Keyboard.ArrowRight: {
        const nextSibling =
          currentHeaderCell.nextElementSibling as HTMLElement | null;
        const nextFocusableElement =
          nextSibling && getAdjacentElementWithTabIndex(nextSibling, "next");

        nextFocusableElement?.focus();
        break;
      }
      case Keyboard.ArrowLeft: {
        const prevSibling =
          currentHeaderCell.previousElementSibling as HTMLElement | null;
        const prevFocusableElement =
          prevSibling && getAdjacentElementWithTabIndex(prevSibling, "prev");

        prevFocusableElement?.focus();
        break;
      }
      default:
        break;
    }
  };

  // For every column whose meta declares a rowSpan, remember the first header
  // depth it appears at, together with that rowSpan.
  const rowSpanColumns = {} as TRowSpanColumn;

  headerGroups.forEach((headerGroup) => {
    headerGroup.headers.forEach(({ column: { columnDef, id } }) => {
      if (columnDef.meta?.rowSpan && !rowSpanColumns[id]) {
        rowSpanColumns[id] = {
          depth: headerGroup.depth,
          rowSpan: columnDef.meta.rowSpan,
        };
      }
    });
  });

  return (
    <TableContainer sx={styles}>
      {tabs && (
        <ReactTableTabs
          onChange={onTabChange}
          tabs={tabs}
          isControlled={Object.hasOwn(props, "tabValue")}
          value={tabValue}
        />
      )}
      {isEmpty(rows) ? (
        EmptyStateComponent
      ) : (
        <MuiTable
          role="table"
          data-testid="table"
          data-qaid="table"
          {...tableProps}
        >
          {!isEmpty(columns) && !isHeadless && (
            <TableHead
              data-testid="table-header-row"
              data-qaid="table-header-row"
            >
              {headerGroups.map((headerGroup, rowIndex) => (
                <TableRow key={`table-header-row-${headerGroup.id}`}>
                  {isSelectable && (
                    <SelectionCell
                      onChange={getToggleAllPageRowsSelectedHandler()}
                      indeterminate={tableInstance.getIsSomePageRowsSelected()}
                      checked={tableInstance.getIsAllPageRowsSelected()}
                      ariaLabel={t("ariaLabels.selectAllCheckbox")}
                    />
                  )}
                  {headerGroup.headers
                    .filter((header) => {
                      // Skip a header in this row when its column was first seen
                      // at a shallower depth and this row index is still below
                      // the column's declared rowSpan.
                      const rowSpanColumn = rowSpanColumns[header.column.id];

                      return !(
                        rowSpanColumn?.rowSpan > rowIndex &&
                        rowIndex > rowSpanColumn?.depth
                      );
                    })
                    .map((header) => {
                      const column = header.column;

                      return (
                        column.getIsVisible() && (
                          <TableCell
                            id={column.id}
                            sx={getHeaderStyles(header)}
                            key={`table-cell-${column.id}`}
                            data-testid="table-header-cell"
                            data-qaid="table-header-cell"
                            onClick={
                              column.getCanSort()
                                ? column.getToggleSortingHandler()
                                : undefined
                            }
                            title={getColumnTitle<T>(column, t)}
                            width={
                              column?.columnDef?.meta?.width ||
                              DEFAULT_COLUMN_WIDTH
                            }
                            colSpan={
                              column?.columnDef?.meta?.colSpan ||
                              DEFAULT_COLUMN_AMOUNT
                            }
                            rowSpan={
                              column?.columnDef?.meta?.rowSpan ||
                              DEFAULT_ROW_AMOUNT
                            }
                            sortDirection={column.getIsSorted()}
                            onKeyDown={(event) =>
                              handleKeyDownColumnCell(event, column)
                            }
                          >
                            <Tooltip
                              title={column.columnDef.meta?.headerTooltip}
                            >
                              {column.getCanSort() ? (
                                <TableSortLabel
                                  active={Boolean(column.getIsSorted())}
                                  direction={
                                    column.getIsSorted() === DESCENDING
                                      ? DESCENDING
                                      : ASCENDING
                                  }
                                >
                                  {renderHeader(header)}
                                </TableSortLabel>
                              ) : (
                                renderHeader(header)
                              )}
                            </Tooltip>
                          </TableCell>
                        )
                      );
                    })}
                </TableRow>
              ))}
            </TableHead>
          )}
          <TableBody ref={tableBodyRef}>
            {rows.map((pageRow: Row<T>) => {
              const rowHover = isFunction(bodyRowHover)
                ? bodyRowHover(pageRow)
                : bodyRowHover;

              return (
                <Fragment key={`table-row-${pageRow.id}`}>
                  {rowRenderer({
                    rowData: pageRow,
                    onClick: onRowClick,
                    isDisabled: isRowDisabled?.(pageRow),
                    isSelectable: pageRow.getCanSelect(),
                    tableBodyRef: tableBodyRef,
                    hover: rowHover,
                    rowAlign: bodyRowAlign,
                    expandableRowProps: expandableRowProps,
                    rowStyles: rowStyles,
                  })}
                </Fragment>
              );
            })}
          </TableBody>
          {isPaginated && (
            <TableFooter>
              <TableRow>
                <TablePagination
                  rowsPerPageOptions={rowsPerPageOptions}
                  SelectProps={{
                    inputProps: {
                      "aria-label": t("accessibility.label.rowsPerPage"),
                      "aria-labelledby": "rowsPerPage",
                    },
                    id: "rowsPerPage",
                  }}
                  labelRowsPerPage={t("accessibility.label.rowsPerPage")}
                  data-testid="select-rows-per-page"
                  data-qaid="select-rows-per-page"
                  count={tableInstance.getRowCount()}
                  page={getPaginationPage(
                    tableInstance.getRowCount(),
                    pageSize,
                    pageIndex,
                  )}
                  onPageChange={(e, newPage) => gotoPage(newPage)}
                  rowsPerPage={pageSize}
                  onRowsPerPageChange={(e) =>
                    setPageSize(Number(e.target.value))
                  }
                />
              </TableRow>
            </TableFooter>
          )}
        </MuiTable>
      )}
    </TableContainer>
  );
}

export default ReactTable;
