import { useTheme } from "@mui/material";
import {
  DataGridPremium,
  GridColDef,
  GridColumnGroupingModel,
  GridRenderCellParams,
} from "@mui/x-data-grid-premium";
import { useTenantTranslation, useUnitsFormatter } from "hooks/formatters";
import { useLanguage } from "hooks/settings";
import { useMemo } from "react";
import {
  MaterialGroupingName,
  MaterialGroupName,
  SupportedLanguage,
  TimeAggregation,
} from "src/store/api/generatedApi";
import { formatNumber } from "src/utils/formatNumber";
import { BarCell } from "../BarCell";
import { defaultTableSx } from "../tableFormatting";
import { ConsumptionSource, GroupedConsumptionSeries } from "../types";

/** Props accepted by the ConsumptionTable component. */
type ConsumptionTableProps = {
  // Grouped series to display: `items` supplies the keyed series data and
  // `availableConsumptionSources` decides which source columns are rendered.
  consumptionSeries: GroupedConsumptionSeries;
  // Only entries whose date matches this value are shown; when null, entries
  // for every date in the matching series are included.
  selectedDate: Date | null;
  // A series is kept only when its key's `timeAggregation` is strictly equal
  // to this value (so null only matches keys whose value is also null).
  selectedTimeAggregation: TimeAggregation | null;
  // A series is kept only when its key's `materialGrouping` is strictly equal
  // to this value.
  selectedMaterialGrouping: MaterialGroupingName | null;
};

/**
 * A single (material group, consumption source) observation in long format,
 * before it is pivoted into one table row per material group.
 */
type ConsumptionRecord = {
  materialGroup: MaterialGroupName;
  consumptionSource: ConsumptionSource;
  // Mass taken from the series' `consumed_mass` array; summed per source to
  // produce the totals row.
  consumedMass: number;
  // Value taken from the series' `consumed_fraction` array; it is multiplied
  // by 100 for display, so it is presumably a 0..1 ratio — TODO confirm.
  consumedFraction: number;
};

/**
 * One grid row: a material group with a `<source>_mass` and a
 * `<source>_fraction` field for each consumption source.
 */
type ConsumptionRow = {
  // Consumption records pivoted from long format so that every
  // consumption source gets its own pair of columns.
  // Regular rows are numbered from 0 in first-seen order; the totals row is
  // given the id `rows.length` by the component.
  id: number;
  materialGroup: MaterialGroupName;
} & {
  [key: `${ConsumptionSource}_mass`]: number;
  [key: `${ConsumptionSource}_fraction`]: number;
};

/**
 * Table of consumed mass and fraction per material group, with one column
 * group per available consumption source and a pinned totals row at the
 * bottom.
 *
 * The series are filtered by the selected material grouping, time aggregation
 * and (when set) date, then pivoted so that each material group becomes a row.
 */
export const ConsumptionTable = ({
  consumptionSeries,
  selectedDate,
  selectedTimeAggregation,
  selectedMaterialGrouping,
}: ConsumptionTableProps) => {
  const records = useMemo(
    () =>
      buildRecords(
        consumptionSeries,
        selectedDate,
        selectedTimeAggregation,
        selectedMaterialGrouping
      ),
    [
      consumptionSeries,
      selectedDate,
      selectedTimeAggregation,
      selectedMaterialGrouping,
    ]
  );

  // Materialise the sources once so both builders share the same array.
  const availableSources = useMemo(
    () => Array.from(consumptionSeries.availableConsumptionSources),
    [consumptionSeries.availableConsumptionSources]
  );

  const rows = useMemo(
    () => buildRowsFromRecords(records, availableSources),
    [records, availableSources]
  );

  // buildTotalsRow calls the useTenantTranslation hook internally, so it has
  // to run unconditionally on every render. Wrapping it in useMemo would skip
  // the hook whenever the dependencies are unchanged, which violates the
  // Rules of Hooks.
  const totalsRow = buildTotalsRow(rows.length, records, availableSources);

  // buildColumns also calls hooks and is therefore invoked directly as well.
  const { columns, columnGroupingModel } = buildColumns(
    consumptionSeries,
    rows,
    totalsRow
  );

  return (
    <div style={{ display: "flex", flexDirection: "column" }}>
      <DataGridPremium
        rows={rows}
        columns={columns}
        columnGroupingModel={columnGroupingModel}
        sx={defaultTableSx}
        pinnedRows={{
          bottom: [totalsRow],
        }}
        initialState={{
          sorting: { sortModel: [{ field: "materialGroup", sort: "asc" }] },
        }}
      />
    </div>
  );
};

/**
 * Flattens the grouped series into long-format consumption records.
 *
 * Only series whose key matches both the selected material grouping and the
 * selected time aggregation are considered. Within those series, an entry is
 * kept when no date is selected or when its date is the same instant as
 * `selectedDate`. Entries whose mass or fraction is missing or zero are
 * dropped.
 *
 * @param consumptionSeries grouped series; `items` is iterated as
 *   `[key, series]` pairs
 * @param selectedDate instant to keep, or null to keep every date
 * @param selectedTimeAggregation time aggregation a series key must equal
 * @param selectedMaterialGrouping material grouping a series key must equal
 * @returns one record per retained (series, date) entry, in iteration order
 */
const buildRecords = (
  consumptionSeries: GroupedConsumptionSeries,
  selectedDate: Date | null,
  selectedTimeAggregation: TimeAggregation | null,
  selectedMaterialGrouping: MaterialGroupingName | null
): ConsumptionRecord[] => {
  const records: ConsumptionRecord[] = [];

  // Compare epoch milliseconds rather than ISO strings: the result is the same
  // for valid dates, the selected value is computed once instead of per entry,
  // and an unparseable date string yields NaN (never equal) instead of making
  // toISOString() throw a RangeError.
  const selectedTime = selectedDate ? selectedDate.getTime() : null;

  for (const [key, series] of consumptionSeries.items) {
    if (
      key.materialGrouping !== selectedMaterialGrouping ||
      key.timeAggregation !== selectedTimeAggregation
    ) {
      continue;
    }

    for (const [index, dateString] of series.date.entries()) {
      if (
        selectedTime !== null &&
        new Date(dateString).getTime() !== selectedTime
      ) {
        continue;
      }

      const consumedMass = series.consumed_mass[index];
      const consumedFraction = series.consumed_fraction[index];

      // Truthiness check on purpose: skips missing values as well as zeros.
      if (consumedFraction && consumedMass) {
        records.push({
          materialGroup: key.materialGroup,
          consumptionSource: key.consumptionSource,
          consumedFraction: consumedFraction,
          consumedMass: consumedMass,
        });
      }
    }
  }

  return records;
};

/**
 * Pivots long-format records into one row per material group.
 *
 * A row is created the first time its material group is seen, with the mass
 * and fraction of every available consumption source preset to 0; row ids
 * follow that first-seen order starting at 0. Each record then writes its
 * values into the fields of its own source, so a later record for the same
 * group and source overwrites an earlier one.
 *
 * @param records long-format consumption records
 * @param availableConsumptionSources sources that get zero-initialised fields
 * @returns the pivoted rows in first-seen order of their material group
 */
const buildRowsFromRecords = (
  records: ConsumptionRecord[],
  availableConsumptionSources: ConsumptionSource[]
): ConsumptionRow[] => {
  const pivoted = new Map<string, ConsumptionRow>();

  for (const record of records) {
    const groupKey = String(record.materialGroup);
    let row = pivoted.get(groupKey);

    if (row === undefined) {
      // First sighting of this material group: start from an all-zero row.
      const zeroed: {
        [key: `${ConsumptionSource}_mass`]: number;
        [key: `${ConsumptionSource}_fraction`]: number;
      } = {};
      availableConsumptionSources.forEach((available) => {
        zeroed[`${available}_mass`] = 0;
        zeroed[`${available}_fraction`] = 0;
      });

      // The map size before insertion is the next sequential row id.
      row = {
        id: pivoted.size,
        materialGroup: record.materialGroup,
        ...zeroed,
      };
      pivoted.set(groupKey, row);
    }

    const sourceName = String(record.consumptionSource);
    row[`${sourceName}_mass`] = record.consumedMass;
    row[`${sourceName}_fraction`] = record.consumedFraction;
  }

  return [...pivoted.values()];
};

/**
 * Builds the totals row shown pinned beneath the regular rows.
 *
 * For every available consumption source the row carries the sum of the
 * consumed mass over all records of that source (0 when there are none) and a
 * fraction fixed at 1.0. Its material group is the translated "Total" label.
 *
 * This function calls the useTenantTranslation hook, so it is subject to the
 * Rules of Hooks: it must run unconditionally during a component render.
 *
 * @param id id to give the totals row
 * @param records long-format consumption records to sum
 * @param availableConsumptionSources sources that get a total
 * @returns the totals row
 */
const buildTotalsRow = (
  id: number,
  records: ConsumptionRecord[],
  availableConsumptionSources: ConsumptionSource[]
): ConsumptionRow => {
  const { t } = useTenantTranslation();

  // Accumulate the consumed mass per source.
  const massBySource = records.reduce(
    (acc, { consumptionSource, consumedMass }) =>
      acc.set(consumptionSource, (acc.get(consumptionSource) ?? 0) + consumedMass),
    new Map<ConsumptionSource, number>()
  );

  const totals: {
    [key: `${ConsumptionSource}_mass`]: number;
    [key: `${ConsumptionSource}_fraction`]: number;
  } = {};
  availableConsumptionSources.forEach((source) => {
    const prefix = String(source);
    totals[`${prefix}_mass`] = massBySource.get(source) ?? 0;
    totals[`${prefix}_fraction`] = 1.0;
  });

  return {
    id,
    materialGroup: t("Total"),
    ...totals,
  };
};

/**
 * Builds the grid column definitions and the column grouping model.
 *
 * The first column is the material group. Each available consumption source
 * then contributes a mass column and a percentage column, grouped under a
 * header named after the source. Mass cells render as a bar scaled against
 * the source's total, except in the totals row, which shows the plain
 * formatted total.
 *
 * This function calls hooks (useUnitsFormatter, useLanguage, useTheme), so it
 * must run unconditionally during a component render.
 *
 * @param consumptionSeries provides `availableConsumptionSources`
 * @param rows regular rows; the totals row is recognised by an id equal to
 *   `rows.length`
 * @param totalsRow per-source totals, used as the bar maximum and as the
 *   totals-row cell content
 * @returns the column definitions and their grouping model
 */
const buildColumns = (
  consumptionSeries: GroupedConsumptionSeries,
  rows: ConsumptionRow[],
  totalsRow: ConsumptionRow
): {
  columns: GridColDef<ConsumptionRow>[];
  columnGroupingModel: GridColumnGroupingModel;
} => {
  const units = useUnitsFormatter(false);
  const language = useLanguage();
  const theme = useTheme();

  const formatNumberLocal = formatNumber(language as SupportedLanguage);
  // Create the formatters once rather than on every cell render.
  const formatMass = formatNumberLocal(0, true);
  const formatPercent = formatNumberLocal(1, true);

  const columns: GridColDef<ConsumptionRow>[] = [
    { field: "materialGroup", headerName: "Material Group", flex: 3 },
  ];
  const columnGroupingModel: GridColumnGroupingModel = [];

  for (const source of consumptionSeries.availableConsumptionSources) {
    // Total mass for this source: bar maximum and totals-row cell content.
    const totalMass = totalsRow[`${source}_mass`] ?? 0;

    columns.push({
      field: `${source}_mass`,
      type: "number",
      headerName: units("mass"),
      flex: 2,
      align: "right",
      renderCell: (params: GridRenderCellParams<ConsumptionRow, number>) => {
        // The totals row carries the id that follows the last regular row.
        if (params.row.id === rows.length) {
          return <div>{formatMass(totalMass)}</div>;
        }

        // Fall back to 0 instead of asserting the value is defined.
        const mass = params.value ?? 0;
        return (
          <BarCell
            value={mass}
            formattedValue={formatMass(mass)}
            maxValue={totalMass}
            color={theme.palette.data.blue}
          />
        );
      },
      cellClassName: "coloured-cell",
    });

    columns.push({
      field: `${source}_fraction`,
      type: "number",
      headerName: "%",
      flex: 1.5,
      align: "right",
      // Missing and zero fractions are rendered as an empty cell.
      valueFormatter: (value?: number) =>
        !value ? "" : `${formatPercent(value * 100)} %`,
    });

    columnGroupingModel.push({
      groupId: source,
      headerName: source,
      children: [
        {
          field: `${source}_mass`,
        },
        { field: `${source}_fraction` },
      ],
    });
  }

  return { columns, columnGroupingModel };
};
