import { TableColumnId } from '@fluentui/react-components'
import { ColumnWidthState, ExtendedTableColumnDefinition } from './ExtendedTableColumn'

/**
 * Finds the width-state entry that belongs to `columnId`.
 *
 * @returns The first entry whose `columnId` is strictly equal to the given id,
 * or `undefined` when no entry matches.
 */
function getColumnById(state: ColumnWidthState[], columnId: TableColumnId) {
  for (const entry of state) {
    if (entry.columnId === columnId) {
      return entry
    }
  }
  return undefined
}

/**
 * Adds up the space taken by all columns, counting both `width` and `padding`
 * of every entry.
 *
 * @returns The summed width, or `0` for an empty state.
 */
function getTotalWidth(state: ColumnWidthState[]): number {
  let total = 0
  state.forEach((column) => {
    total = total + column.width + column.padding
  })
  return total
}

/**
 * Returns a state in which the column identified by `columnId` has `property`
 * set to `value`.
 *
 * Neither the input array nor its entries are mutated. When no column matches
 * `columnId`, or the matching column already holds `value`, the original
 * `localState` array is returned unchanged (same reference).
 *
 * @param localState - Current width state of all columns.
 * @param columnId - Id of the column to update.
 * @param property - Name of the field to overwrite.
 * @param value - New value for that field.
 * @returns A new array containing the updated column, or `localState` itself
 * when nothing needed to change.
 */
function setColumnProperty(
  localState: ColumnWidthState[],
  columnId: TableColumnId,
  property: keyof ColumnWidthState,
  value: number,
): ColumnWidthState[] {
  const currentColumn = getColumnById(localState, columnId)

  if (!currentColumn || currentColumn[property] === value) {
    return localState
  }

  const updatedColumn = { ...currentColumn, [property]: value }

  // Single pass: every entry carrying the updated column's id is swapped for the new copy.
  return localState.map(
    (column): ColumnWidthState => (column.columnId === updatedColumn.columnId ? updatedColumn : column),
  )
}

/**
 * Computes the width state of every column for a container of the given width.
 *
 * Every column starts at its ideal width. When the columns (widths plus
 * padding) are narrower than the container, the spare space is split equally
 * among the columns whose `flex` is `2`; all other columns keep their ideal
 * width. When the columns already fill or overflow the container, or there is
 * no `flex === 2` column to absorb the spare space, every column keeps its
 * ideal width.
 *
 * NOTE(review): `minWidth` is copied into the state but is never used here to
 * shrink columns when they overflow the container — confirm whether shrinking
 * is expected to happen elsewhere.
 * NOTE(review): assumes `columnId` values are unique across `columns`.
 *
 * @param columns - Column definitions providing `columnId`, `flex` and `columnSize`.
 * @param containerWidth - Width available to lay the columns out in.
 * @returns One width-state entry per column, in the order of `columns`.
 */
export function adjustColumnWidthsToFitContainer<TItem>(
  columns: ExtendedTableColumnDefinition<TItem>[],
  containerWidth: number,
) {
  let newState: ColumnWidthState[] = columns.map((column) => ({
    columnId: column.columnId,
    width: column.columnSize.idealWidth,
    minWidth: column.columnSize.minWidth,
    idealWidth: column.columnSize.idealWidth,
    padding: column.columnSize.padding,
  }))

  const totalWidth = getTotalWidth(newState)
  const flex2Columns = columns.filter((col) => col.flex === 2)

  // Only grow columns when there is spare room AND someone to give it to;
  // the length check also avoids dividing by zero.
  if (totalWidth < containerWidth && flex2Columns.length > 0) {
    const extraPerColumn = (containerWidth - totalWidth) / flex2Columns.length
    for (const col of flex2Columns) {
      newState = setColumnProperty(newState, col.columnId, 'width', col.columnSize.idealWidth + extraPerColumn)
    }
  }

  return newState
}
