import { Box, LinearProgress, Typography } from "@mui/material"
import { flexRender, getCoreRowModel, getExpandedRowModel, useReactTable } from "@tanstack/react-table"
import { useVirtualizer } from "@tanstack/react-virtual"
import React, { useCallback, useEffect, useRef, useState } from "react"

import { columns } from "./productColumns"
import { ProductRow } from "./ProductRow"
import { ExpandedContent } from "./ProductRowExpandedContent"

import type { Product } from "@/graphql/codegen/graphql"

/** Props accepted by {@link ProductListTable}. */
interface ProductListTableProps {
  /** Products to render, one table row per entry. An empty array shows the "no results" row. */
  data: Product[]
  /** Requests the next page of products; invoked when the list is scrolled close to its bottom edge. */
  loadMore: () => Promise<void>
  /** While true, a progress bar is shown and further `loadMore` calls are suppressed. */
  fetching: boolean
  /** When set, its `message` is rendered in an error line below the table. */
  error?: { message: string }
}

// Height in pixels of the header row, the empty-state row, and the collapsed part of every
// product row; also the base value used when estimating virtual row sizes.
const ROW_HEIGHT = 64

// Expansion state kept by the table: maps a productId to whether that product's row is
// currently expanded. A missing entry is treated the same as `false`.
type ProductExpandedState = Record<number, boolean>

// Distance from the bottom of the scroll container (in px) below which the next page is requested.
const LOAD_MORE_THRESHOLD_PX = 500

// TanStack Table's default column size. A column still reporting this size is treated as
// "no explicit width configured" and is given a flexible `1fr` track instead.
const DEFAULT_COLUMN_SIZE = 150

/**
 * Virtualized, infinitely scrolling product table with expandable rows.
 *
 * Rows are rendered through `@tanstack/react-virtual`, so only the rows near the viewport are
 * mounted. Each row can be expanded to show {@link ExpandedContent}; the expanded content reports
 * its rendered height back so the virtualizer can re-measure.
 *
 * @param data      Products to display, one row each.
 * @param loadMore  Called when the user scrolls within `LOAD_MORE_THRESHOLD_PX` of the bottom.
 * @param fetching  Shows a progress bar and blocks further `loadMore` calls while true.
 * @param error     Optional error whose message is rendered below the table.
 */
export const ProductListTable: React.FC<ProductListTableProps> = ({ data, loadMore, fetching, error }) => {
  // NOTE(review): the initial state expands only the products present at mount time. Rows that
  // arrive later through `loadMore` start collapsed — confirm this is intended.
  const [expanded, setExpanded] = useState<ProductExpandedState>(() =>
    Object.fromEntries(data.map(({ productId }) => [productId, true]))
  )
  // Per-product payload for expanded rows. The value is not read by this component yet, so only
  // the setter is bound.
  const [, setExpandedData] = useState<Record<number, unknown>>({})
  // Last reported height of each product's expanded content, keyed by productId.
  const expandedHeightsRef = useRef<Record<number, number>>({})
  const parentRef = useRef<HTMLDivElement>(null)

  // Ensures the expanded-row payload for `productId` exists. The "already loaded" check runs
  // inside the functional update so it always sees the latest state, and returning `prev`
  // unchanged lets React skip the re-render.
  const loadExpandedData = useCallback((productId: number) => {
    setExpandedData((prev) => {
      if (prev[productId]) return prev
      const simulatedData = {
        /* ... */
      }
      return { ...prev, [productId]: simulatedData }
    })
  }, [])

  // Load data for every expanded row. This lives in an effect (not in the `setExpanded` updater)
  // because state updater functions must be pure; it also covers rows expanded at mount.
  useEffect(() => {
    for (const [rowId, isOpen] of Object.entries(expanded)) {
      if (isOpen) {
        loadExpandedData(Number.parseInt(rowId, 10))
      }
    }
  }, [expanded, loadExpandedData])

  const table = useReactTable({
    data,
    columns,
    state: { expanded },
    onExpandedChange: (updater) => {
      setExpanded((prev) => {
        const next = typeof updater === "function" ? updater(prev) : updater
        // TanStack represents "everything expanded" as the literal `true`; normalise it to an
        // explicit per-product map so `expanded[productId]` lookups keep working.
        return next === true ? Object.fromEntries(data.map(({ productId }) => [productId, true])) : next
      })
    },
    getCoreRowModel: getCoreRowModel(),
    getExpandedRowModel: getExpandedRowModel(),
    getRowCanExpand: () => true,
    getRowId: (row) => row.productId.toString(),
  })

  const { rows } = table.getRowModel()

  // Collapsed rows are ROW_HEIGHT tall; expanded rows add the last height their content reported.
  const estimateSize = useCallback(
    (index: number) => {
      const row = rows[index]
      // The virtualizer may briefly ask about an index that no longer exists.
      if (!row) return ROW_HEIGHT

      const { productId } = row.original
      return expanded[productId] ? ROW_HEIGHT + (expandedHeightsRef.current[productId] || 0) : ROW_HEIGHT
    },
    [rows, expanded]
  )

  const rowVirtualizer = useVirtualizer({
    count: rows.length,
    getScrollElement: () => parentRef.current,
    estimateSize,
    overscan: 5,
    // Key by productId so measurements follow the product; fall back to the index if the row is gone.
    getItemKey: (index) => rows[index]?.original.productId ?? index,
  })

  const loadMoreItems = useCallback(async () => {
    if (!fetching) {
      await loadMore()
    }
  }, [fetching, loadMore])

  // Infinite scroll: request the next page once the user is close to the bottom of the list.
  useEffect(() => {
    const scrollElement = parentRef.current
    if (!scrollElement) return

    const onScroll = () => {
      const { scrollHeight, scrollTop, clientHeight } = scrollElement
      if (scrollHeight - scrollTop - clientHeight < LOAD_MORE_THRESHOLD_PX) {
        // Event listeners cannot await; catch here so a failed page load does not become an
        // unhandled promise rejection.
        loadMoreItems().catch((err: unknown) => {
          console.error("Failed to load more products", err)
        })
      }
    }

    scrollElement.addEventListener("scroll", onScroll)
    return () => scrollElement.removeEventListener("scroll", onScroll)
  }, [loadMoreItems])

  // Called by expanded content once it knows its rendered height; re-measures only on change.
  const handleContentLoad = useCallback(
    (productId: number, height: number) => {
      if (expandedHeightsRef.current[productId] !== height) {
        expandedHeightsRef.current[productId] = height
        rowVirtualizer.measure()
      }
    },
    [rowVirtualizer]
  )

  // Maps a header to its CSS grid track: a fixed pixel width when one is configured, `1fr` otherwise.
  const getFlexibleColumnWidth = useCallback((header: { column: { getSize: () => number } }) => {
    const configuredWidth = header.column.getSize()
    return configuredWidth === DEFAULT_COLUMN_SIZE || !configuredWidth ? "1fr" : `${configuredWidth}px`
  }, [])

  // Computed once per render and shared by the header row and the empty-state row.
  const flatHeaders = table.getFlatHeaders()
  const gridTemplateColumns = flatHeaders.map((header) => getFlexibleColumnWidth(header)).join(" ")

  return (
    <Box className='flex h-full flex-col'>
      {fetching && <LinearProgress className='sticky top-0 z-10' />}
      <Box className='sticky top-0 z-10 bg-gray-100 text-gray-700'>
        <Box
          className='grid border-b border-gray-200'
          style={{
            height: ROW_HEIGHT,
            gridTemplateColumns,
          }}
        >
          {flatHeaders.map((header) => (
            <Box key={header.id} className='flex items-center px-2'>
              {flexRender(header.column.columnDef.header, header.getContext())}
            </Box>
          ))}
        </Box>
      </Box>
      <Box className='grow overflow-auto' ref={parentRef}>
        <Box
          className='relative w-full'
          style={{ height: data.length ? `${rowVirtualizer.getTotalSize()}px` : `${ROW_HEIGHT}px` }}
        >
          {!data.length ? (
            <Box
              className='absolute left-0 top-0 flex w-full flex-col'
              style={{
                transform: "translateY(0px)",
              }}
            >
              <Box
                className='grid items-center border-b border-gray-200 bg-white text-gray-700'
                style={{
                  gridTemplateColumns,
                  height: ROW_HEIGHT,
                }}
              >
                <Box className='col-span-full px-2'>
                  <Typography variant='body2' className='truncate text-center'>
                    No results found according to your search criteria
                  </Typography>
                </Box>
              </Box>
            </Box>
          ) : (
            rowVirtualizer.getVirtualItems().map((virtualRow) => {
              const row = rows[virtualRow.index]
              // Skip virtual items whose backing row has disappeared since the last measurement.
              if (!row) return null

              const { productId } = row.original
              const isExpanded = Boolean(expanded[productId])

              return (
                <ProductRow
                  key={productId}
                  row={row}
                  isExpanded={isExpanded}
                  virtualRow={virtualRow}
                  getFlexibleColumnWidth={getFlexibleColumnWidth}
                >
                  {isExpanded && (
                    <ExpandedContent
                      row={row.original as Product}
                      productId={productId}
                      companyId={row.original.companyId}
                      sourceId={row.original.locationsAssociations[0]?.locationAssociationId}
                      onContentLoad={(height) => handleContentLoad(productId, height)}
                    />
                  )}
                </ProductRow>
              )
            })
          )}
        </Box>
      </Box>
      {error && (
        <Typography color='error' className='p-4'>
          {error.message}
        </Typography>
      )}
    </Box>
  )
}
