import cloneDeep from "lodash/cloneDeep";

import { varWithAggToName } from "formatters";
import {
  Filter,
  OutcomeSelectionConfig,
  AggregationType,
} from 'types';
import { SimResultRow } from 'api/types';


/**
 * One (variable, aggregation) pairing, as produced by
 * `flattenMainVarCollections` and consumed by `getHistogramData`.
 */
type VarWithAggSetting = {
  // Aggregation to apply; this file branches on "mean", "total" and "point".
  type: AggregationType;
  // Name of the variable; used as a property key into result rows.
  varName: string;
  // Period to read; within this file it is only consulted when `type` is "point".
  time?: number;
};

/**
 * Expands every aggregation collection into one entry per variable it lists,
 * so that each (variable, aggregation) pair becomes its own object.
 *
 * @param varCollections - One config per aggregation, each holding the
 *   variable names that aggregation applies to.
 * @returns A flat list with one `{ varName, type, time }` object per variable,
 *   in collection order and then variable order.
 */
export const flattenMainVarCollections = (
  varCollections: OutcomeSelectionConfig[]
): VarWithAggSetting[] =>
  varCollections.flatMap(({ variableNames, type, time }) =>
    variableNames.map((varName) => ({ varName, type, time }))
  );

/**
 * Computes simulation-level totals and means: for every (simId, policyId)
 * group found in `data`, each requested variable is summed and averaged across
 * the group's rows. These are used in filtering and in creating the filtering
 * histogram.
 *
 * Rows with `t <= 0` are ignored. `null` values are skipped: they add nothing
 * to the total and are not counted in the mean's denominator. If every value
 * of a variable in a group is `null`, its total is 0 and its mean is NaN (0/0).
 *
 * @param data - Result rows; each needs `simId`, `policyId`, `t` and one
 *   property per entry of `variableNames`. The array is not modified.
 * @param variableNames - Names of the variables to aggregate.
 * @returns `{ mean, total }`, two arrays holding one
 *   `{ simId, policyId, ...valuesByVariable }` row per group, ordered by simId
 *   and then policyId.
 */
export function getSimSummaryStats(data, variableNames) {
  // `filter` already returns a fresh array, so sorting it in place leaves the
  // caller's array untouched. Rows are only read, so no deep copy is required.
  const rows = data.filter((row) => row.t > 0);
  rows.sort((r1, r2) =>
    r1.simId === r2.simId ? r1.policyId - r2.policyId : r1.simId - r2.simId
  );

  const zeroes = () =>
    Object.fromEntries(variableNames.map((vname) => [vname, 0]));

  const means = [];
  const totals = [];
  let sums = zeroes();
  let nonNullCounts = zeroes();

  rows.forEach((row, index) => {
    for (const vname of variableNames) {
      const value = row[vname];
      if (value !== null) {
        sums[vname] += value;
        nonNullCounts[vname] += 1;
      }
    }

    // Emit as soon as the current row is the last one of its group. Deciding
    // this by looking ahead (instead of flushing when the *next* group starts)
    // guarantees the final group is emitted and no placeholder group is
    // produced before the first real one.
    const next = rows[index + 1];
    const isLastOfGroup =
      next === undefined ||
      next.simId !== row.simId ||
      next.policyId !== row.policyId;
    if (!isLastOfGroup) {
      return;
    }

    totals.push({ simId: row.simId, policyId: row.policyId, ...sums });
    means.push({
      simId: row.simId,
      policyId: row.policyId,
      ...Object.fromEntries(
        variableNames.map((vname) => [
          vname,
          sums[vname] / nonNullCounts[vname],
        ])
      ),
    });
    sums = zeroes();
    nonNullCounts = zeroes();
  });

  return { mean: means, total: totals };
}

/**
 * filterData keeps only the rows whose (simId, policyId) pair appears in
 * simPolicyPairsAsJson and whose `t` is greater than zero. Which pairs qualify
 * is decided elsewhere; this function only applies that decision.
 *
 * @param data - Rows to filter; the array is not modified.
 * @param simPolicyPairsAsJson - Allowed pairs, each encoded as
 *   `JSON.stringify([simId, policyId])`.
 * @returns A new array with the surviving rows in their original order.
 */
function filterData(data: SimResultRow[], simPolicyPairsAsJson: string[]): SimResultRow[] {
  // A Set gives constant-time membership checks; scanning the pair list for
  // every row would cost rows x pairs string comparisons.
  const allowedPairs = new Set(simPolicyPairsAsJson);
  return data.filter(
    (row) =>
      // time zero will have lots of nulls, should be ignored in aggregation.
      // Checked first so dropped rows are never stringified.
      row.t > 0 &&
      allowedPairs.has(JSON.stringify([row.simId, row.policyId]))
  );
}

/**
 * getSimPolicyPairsPassingFilters evaluates the UI-specified filters against
 * the data and returns the (simId, policyId) pairs that satisfy all of them.
 * Each pair is returned as `JSON.stringify([simId, policyId])`; applying the
 * pairs to the data happens elsewhere.
 *
 * Filters of type "point" are checked against the rows of a single period.
 * Every other filter with a non-null type is checked against the per-pair
 * summary statistic of that type from `getSimSummaryStats`. A `null` bound
 * (`min` or `max`) means that side is unbounded.
 *
 * @param filters - Filters to apply; with none, every pair in `data` passes.
 * @param data - Simulation result rows.
 * @param policies - Not used by this function.
 * @returns The stringified pairs passing every filter.
 */
function getSimPolicyPairsPassingFilters(
  filters: Filter[],
  data: SimResultRow[],
  policies: string[],
): string[] {
  // Split the filters by kind in one pass; filters with a null type are dropped.
  const pointFilters: Filter[] = [];
  const aggFilters: Filter[] = [];
  for (const f of filters) {
    if (f.type === "point") {
      pointFilters.push(f);
    } else if (f.type !== null) {
      aggFilters.push(f);
    }
  }

  // Start from every distinct pair present in the data, then narrow it down.
  // HACK: using strings b/c js comparison of arrays doesn't do what we want ([1, 2] is not equal to [1, 2])
  let usablePairs = Array.from(
    new Set(data.map((row) => JSON.stringify([row.simId, row.policyId])))
  );

  for (const f of pointFilters) {
    const passing = data
      .filter(
        (row) =>
          Number(row.t) === Number(f.time) &&
          (f.max === null || row[f.variable] <= f.max) &&
          (f.min === null || row[f.variable] >= f.min)
      )
      .map((row) => JSON.stringify([row.simId, row.policyId]));
    usablePairs = intersect(usablePairs, passing);
  }

  if (aggFilters.length) {
    const simSummaryStats = getSimSummaryStats(
      data,
      aggFilters.map((f) => f.variable)
    );
    for (const f of aggFilters) {
      const passing = simSummaryStats[f.type]
        .filter(
          (row) =>
            (f.max === null || row[f.variable] <= f.max) &&
            (f.min === null || row[f.variable] >= f.min)
        )
        .map((row) => JSON.stringify([row.simId, row.policyId]));
      usablePairs = intersect(usablePairs, passing);
    }
  }

  return usablePairs;
}

/**
 * One aggregated row per (simId, policyId) pair, as built by `getHistogramData`.
 */
interface AggRow {
  simId: number;
  policyId: number;
  // plus one additional field per "long name", e.g. "Mean Line Length", whose
  // value is the aggregated figure (normally a number). These fields are added
  // dynamically and are not declared on this interface.
}

/**
 * getHistogramData groups `data` by (simId, policyId) and computes, for each
 * requested pair, the aggregations shown in the histogram graphs.
 *
 * Aggregation per setting type:
 *  - "mean":  sum of the variable over the pair's rows divided by the row
 *             count; `null` if any value is `null`.
 *  - "total": sum of the variable over the pair's rows; `null` if any value is
 *             `null`.
 *  - "point": the variable's value in the first row whose `t` equals the
 *             setting's `time`; `undefined` when no such row exists.
 *
 * @param simPolicyPairsAsJson - Pairs to report on, each encoded as
 *   `JSON.stringify([simId, policyId])`. One output row is produced per entry.
 * @param data - Rows to aggregate. Rows belonging to pairs that were not
 *   requested are ignored.
 * @param variablesWithAggSettings - The (variable, aggregation) pairs to compute.
 * @returns One row per requested pair holding `simId`, `policyId` and one field
 *   per setting, keyed by the name `varWithAggToName` returns for it.
 * @throws Error("unexpected type") for a setting type other than the three above.
 */
function getHistogramData(
  simPolicyPairsAsJson: string[],
  data: SimResultRow[],
  variablesWithAggSettings: VarWithAggSetting[],
): AggRow[] {
  // Do something akin to Pandas command dataBySimIdPid = data.set_index(['simId', 'policyId']).
  // Output is object of sim data keyed by ['simId', 'policyId']
  const dataBySimIdPid: { [stringifiedSimPolPair: string]: SimResultRow[] } = Object.fromEntries(
    simPolicyPairsAsJson.map((s) => [s, []])
  );
  for (const row of data) {
    const group = dataBySimIdPid[JSON.stringify([row.simId, row.policyId])];
    // Only requested pairs have a bucket; skip any other row instead of
    // failing on `undefined.push`.
    if (group !== undefined) {
      group.push(row);
    }
  }

  const aggRows = [];
  for (const key of simPolicyPairsAsJson) {
    const [simId, policyId] = JSON.parse(key);
    const dataGroup = dataBySimIdPid[key];
    const aggRow = { simId: simId, policyId: policyId };
    for (const varSettings of variablesWithAggSettings) {
      const longVarName = varWithAggToName(varSettings);
      let value = null;
      if (varSettings.type === "mean") {
        const sum = arrSum(dataGroup.map((row) => row[varSettings.varName]));
        // arrSum signals missing data with null. Dividing null would coerce it
        // to 0 and hide the problem, so propagate it instead.
        value = sum === null ? null : sum / dataGroup.length;
      } else if (varSettings.type === "total") {
        value = arrSum(dataGroup.map((row) => row[varSettings.varName]));
      } else if (varSettings.type === "point") {
        const time = varSettings.time;
        const rowAtTime = dataGroup.find((row) => row.t === time);
        /* TODO: Sometimes there is no row for the requested time, but conditions
           are tricky to reproduce. The optional chaining is a bandaid so that we
           at least avoid a hard crash (value becomes undefined).
           But should try to track down root cause.
        */
        value = rowAtTime?.[varSettings.varName];
      } else {
        throw new Error("unexpected type");
      }
      aggRow[longVarName] = value;
    }
    aggRows.push(aggRow);
  }
  return aggRows;
}

/**
 * One leaderboard row per policy, as built by `getLeaderboardData`.
 */
interface LeaderboardDatum {
  policyId: number;
  // plus one additional field per "long name" holding that metric's value for
  // the policy. These fields are added dynamically and are not declared here.
}

/**
 * Aggregates all simulations for each var in histogramData to show just the
 * expected value (average across simulations) of that metric per policy for
 * the leaderboard.
 *
 * One output row is produced for every policy index `0..policies.length - 1`,
 * or a single row for policy 0 when `policies` is empty. For each name in
 * `vars` the row holds the average of that field over the policy's histogram
 * rows. The average is `null` if any contributing value is `null`, and NaN
 * when the policy has no histogram rows (0/0).
 *
 * @param histogramData - Per-(simId, policyId) aggregates. Rows whose
 *   `policyId` is outside the range above are ignored.
 * @param policies - Policy names; only the count is used.
 * @param vars - "Long" names of the fields to average.
 * @returns One `{ policyId, ...averagesByLongName }` row per policy.
 */
function getLeaderboardData(
  histogramData: AggRow[],
  policies: string[],
  // "Long" names
  vars: string[],
): LeaderboardDatum[] {
  const numPolicies = Math.max(policies.length, 1);

  const dataByPid = {};
  for (let i = 0; i < numPolicies; i++) {
    dataByPid[i] = [];
  }

  for (const row of histogramData) {
    const bucket = dataByPid[row.policyId];
    // Only policies 0..numPolicies-1 have a bucket; skip anything else rather
    // than failing on `undefined.push`.
    if (bucket !== undefined) {
      bucket.push(row);
    }
  }

  const outputRows = [];
  for (let i = 0; i < numPolicies; i++) {
    const outputRow = { policyId: i };
    for (const varName of vars) {
      const values = dataByPid[i].map((row) => row[varName]);
      const sum = arrSum(values);
      // arrSum signals missing data with null. Dividing null would coerce it
      // to 0 and hide the problem, so propagate it instead.
      outputRow[varName] = sum === null ? null : sum / values.length;
    }
    outputRows.push(outputRow);
  }
  return outputRows;
}

/**
 * Adds up the entries of `arr`.
 *
 * @param arr - Values to add.
 * @returns The sum (0 for an empty array), or `null` as soon as a `null` entry
 *   is encountered.
 */
function arrSum(arr: (number|null)[]): (number|null) {
  let total = 0;
  for (const entry of arr) {
    if (entry === null) {
      // maybe not the best approach, but if we don't do something like this
      // we'll never know there was a problem
      return null;
    }
    total += entry;
  }
  return total;
}

/**
 * Returns the elements of `a` that also occur in `b`, keeping `a`'s order and
 * any duplicates it contains.
 */
function intersect(a, b) {
  const lookup = new Set(b);
  const shared = [];
  for (const item of a) {
    if (lookup.has(item)) {
      shared.push(item);
    }
  }
  return shared;
}

/**
 * Primary function called by Explorer to get all data it needs to show.
 *
 * Determines which (simId, policyId) pairs pass the filters, restricts the
 * results to those pairs (always dropping rows with `t <= 0`), and derives the
 * histogram and leaderboard aggregates from the restricted rows.
 *
 * @returns An array of four elements, in this order:
 *   1. number of passing simulations per policy index (when there are no
 *      filters, `EXPLORE_NUM_SIMS` for every policy);
 *   2. the filtered result rows;
 *   3. histogram rows, limited to `selectedPolicies` unless that is `null` or
 *      there are no policy names;
 *   4. leaderboard rows, computed from the histogram rows of all policies.
 */
export const getAggsForExplorer = (
  filters: Filter[],
  simResults: SimResultRow[],
  policyNames: string[],
  selectedPolicies: number[],
  // e.g. "Mean Line Length", "Period 180 Cumulative Profit"
  mainLongNames: string[],
  EXPLORE_NUM_SIMS: number,
  // Object per (variable, aggregation) pair
  mainVariablesWithAggSettings: VarWithAggSetting[],
) => {
  const simPolicyPairsAsJson = getSimPolicyPairsPassingFilters(
    filters,
    simResults,
    policyNames,
  );
  const simPolicyPairs: [number, number][] = simPolicyPairsAsJson.map(
    (pairStr: string) => JSON.parse(pairStr)
  );

  const hasFilters = filters.length > 0;

  // Number of simIds passing filters per policy id.
  let countsByPolicyId: number[];
  if (!hasFilters) {
    countsByPolicyId = policyNames.map(() => EXPLORE_NUM_SIMS);
  } else if (policyNames.length === 0) {
    countsByPolicyId = [0];
  } else {
    countsByPolicyId = policyNames.map(
      (_, iPol) => simPolicyPairs.filter((pair) => pair[1] === iPol).length
    );
  }

  const mainFilteredData: SimResultRow[] = hasFilters
    ? filterData(simResults, simPolicyPairsAsJson)
    : simResults.filter((row) => row.t > 0);

  const mainHistogramDataAllPolicies: AggRow[] = getHistogramData(
    simPolicyPairsAsJson,
    mainFilteredData,
    mainVariablesWithAggSettings
  );

  // Without policy names or an explicit selection, every policy is shown.
  const showEveryPolicy = policyNames.length === 0 || selectedPolicies === null;
  const mainHistogramData: AggRow[] = mainHistogramDataAllPolicies.filter(
    (row) => showEveryPolicy || selectedPolicies.includes(row.policyId)
  );

  // The leaderboard always covers all policies, regardless of the selection.
  const mainLeaderboardData = getLeaderboardData(
    mainHistogramDataAllPolicies,
    policyNames,
    mainLongNames
  );

  return [
    countsByPolicyId,
    mainFilteredData,
    mainHistogramData,
    mainLeaderboardData,
  ];
};
