import { StockData, Strategies } from "../api/api";
import { StrategiesTransaction } from "../api/api";
import { roundNumberToTwoDecimalPlaces } from "../utils";

/**
 * One row of the simulation output: a price bar merged with the transaction
 * (if any) that shares its `dateString`, plus the portfolio figures computed
 * for that row by `StrategyPerformance`.
 */
interface StrategyPerformanceStockData
    extends StockData,
        StrategiesTransaction {
    /** Strategy value for the row: `sharesOwned * close + cash`, rounded to two decimals. Starts at 0 after the merge. */
    mktValue: number;
    /** Shares held by the strategy on this row; filled in by `getMarketValueOfInvestment`. */
    sharesOwned?: number;
    /** Uninvested cash of the strategy on this row; filled in by `getMarketValueOfInvestment`. */
    cash?: number;
    /** Buy-and-hold value for the row: `buyHoldShares * close + buyHoldCash`. Starts at 0 after the merge. */
    buyHoldMktValue: number;
    /** Shares held by the buy-and-hold baseline (same value on every row). */
    buyHoldShares?: number;
    /** Cash left over in the buy-and-hold baseline (same value on every row). */
    buyHoldCash?: number;
}

/**
 * Buy-and-hold baseline: spends the starting balance on whole shares at the
 * first bar's closing price and holds them.
 */
class BuyHoldStrategyPerformance {
    assetData: StockData[];

    constructor(assetData: StockData[]) {
        this.assetData = assetData;
    }

    /**
     * Computes how many whole shares can be bought at the first bar's close
     * after paying `commissions`, and the cash left over.
     *
     * When not even one share is affordable (or the first close is not a
     * usable price), no purchase is made: zero shares are returned, no
     * commission is charged and the whole `startBalance` remains as cash.
     *
     * @param startBalance Cash available before the purchase.
     * @param commissions Flat fee charged for the purchase.
     * @returns The share count and the remaining cash.
     * @throws Error when `assetData` is empty.
     */
    getSharesOwnedAndCashRemainder(
        startBalance: number,
        commissions: number
    ): { sharesOwned: number; cashRemainder: number } {
        const firstBar = this.assetData[0];
        if (firstBar === undefined) {
            throw new Error("No market data.");
        }
        const firstMarketPrice = firstBar.close;
        const sharesOwned = Math.floor(
            (startBalance - commissions) / firstMarketPrice
        );

        // Math.floor of a negative quotient would report a negative holding,
        // and a zero/NaN price yields Infinity/NaN; treat all of these as
        // "nothing was bought".
        if (!Number.isFinite(sharesOwned) || sharesOwned <= 0) {
            return { sharesOwned: 0, cashRemainder: startBalance };
        }

        const cashRemainder =
            startBalance - sharesOwned * firstMarketPrice - commissions;

        return { sharesOwned, cashRemainder };
    }
}

/**
 * Replays dated Buy/Sell transactions against a price series and derives
 * performance figures from the resulting portfolio values, alongside the
 * buy-and-hold baseline inherited from the parent class.
 */
export default class StrategyPerformance extends BuyHoldStrategyPerformance {
    previousTransactions: StrategiesTransaction[];
    /** Rows produced by the latest `getMarketValueOfInvestment` call; empty until then. */
    mergedData: StrategyPerformanceStockData[];

    constructor({
        previousTransactions,
        assetData,
    }: {
        previousTransactions: StrategiesTransaction[];
        assetData: StockData[];
    }) {
        super(assetData);
        this.mergedData = [];
        this.previousTransactions = previousTransactions;
    }

    /**
     * Simulates the strategy over every price bar and records, per row, the
     * shares held, cash, strategy market value and buy-and-hold market value.
     * The result is also stored in `this.mergedData`.
     *
     * @param startBalance Cash the strategy (and the baseline) starts with.
     * @param commissions Flat fee charged per executed order.
     * @param taperingRatio Fraction of the current holding sold on each Sell;
     *        values above 1 are capped at the full holding.
     * @returns One row per price bar, in the order of `assetData`.
     * @throws Error when `taperingRatio` is negative or there is no market data.
     */
    public getMarketValueOfInvestment(
        startBalance: number,
        commissions: number = 4,
        taperingRatio: number = 1
    ): StrategyPerformanceStockData[] {
        if (taperingRatio < 0) {
            throw new Error("Tapering ratio can't be negative.");
        }
        if (this.assetData.length === 0) {
            // Same message as getExposedToTotalDays; previously this surfaced
            // as a TypeError from reading the first bar.
            throw new Error("No market data.");
        }

        const mergedData = this.mergePreviousTransactionsAndStockData();
        const { sharesOwned: buyHoldShares, cashRemainder: buyHoldCash } =
            this.getSharesOwnedAndCashRemainder(startBalance, commissions);

        // Carry the position forward row by row instead of rewriting every
        // later row for each transaction.
        let position = { sharesOwned: 0, cash: startBalance };
        for (const row of mergedData) {
            position = this.applyTransaction(
                row,
                position,
                commissions,
                taperingRatio
            );
            row.sharesOwned = position.sharesOwned;
            row.cash = position.cash;
            row.buyHoldShares = buyHoldShares;
            row.buyHoldCash = buyHoldCash;
            row.mktValue = parseFloat(
                (position.sharesOwned * row.close + position.cash).toFixed(2)
            );
            row.buyHoldMktValue = buyHoldShares * row.close + buyHoldCash;
        }

        this.mergedData = mergedData;
        return mergedData;
    }

    /**
     * Fraction of rows in `this.mergedData` on which the strategy held at
     * least one share, rounded with `roundNumberToTwoDecimalPlaces`.
     *
     * @throws Error when `getMarketValueOfInvestment` has not produced any rows.
     */
    public getExposedToTotalDays() {
        if (this.mergedData.length <= 0) throw Error("No market data.");

        const totalDays = this.mergedData.length;
        const nonOwnershipDays = this.mergedData.filter(
            (d) => d.sharesOwned === undefined || d.sharesOwned <= 0
        ).length;
        return roundNumberToTwoDecimalPlaces(
            (totalDays - nonOwnershipDays) / totalDays
        );
    }

    /**
     * Ratio of summed positive to summed negative price differences between
     * consecutive entries of `previousTransactions`.
     *
     * NOTE(review): every consecutive pair is differenced regardless of the
     * transactions' actions or share counts — confirm this is the intended
     * definition of profit factor.
     *
     * @returns `grossProfit / grossLoss`; `Infinity` when there are gains but
     *          no losses, `0` when there are neither.
     */
    public calculateProfitFactor(): number {
        const transactions = this.previousTransactions;
        const returns: number[] = [];
        for (let i = 1; i < transactions.length; i++) {
            returns.push(transactions[i].price - transactions[i - 1].price);
        }

        const grossProfit = returns.reduce(
            (acc, r) => (r > 0 ? acc + r : acc),
            0
        );
        const grossLoss = Math.abs(
            returns.reduce((acc, r) => (r < 0 ? acc + r : acc), 0)
        );

        if (grossLoss === 0) {
            // Avoid division by zero
            return grossProfit > 0 ? Infinity : 0;
        }

        return grossProfit / grossLoss;
    }

    /**
     * Difference between the last value and the starting balance.
     *
     * @returns `NaN` when `values` is empty; otherwise `last - startBalance`
     *          (a last value or start balance of 0 is a valid input).
     */
    public getNetProfit(values: number[], startBalance: number): number {
        const lastValue = values.at(-1);
        // Explicit undefined check: a truthiness test would reject 0.
        if (lastValue === undefined) {
            return NaN;
        }
        return lastValue - startBalance;
    }

    /**
     * Population standard deviation of the simple period-over-period returns
     * of `values`, expressed in percent and rounded with
     * `roundNumberToTwoDecimalPlaces`.
     *
     * @throws Error when fewer than two values are supplied.
     */
    public getVolatilityOfMktValue(values: number[]): number {
        if (values.length < 2) {
            throw new Error(
                "At least 2 values are required to calculate volatility."
            );
        }
        const returns: number[] = [];
        for (let i = 1; i < values.length; i++) {
            returns.push((values[i] - values[i - 1]) / values[i - 1]);
        }
        const meanReturn =
            returns.reduce((acc, val) => acc + val, 0) / returns.length;

        const variance =
            returns.reduce(
                (acc, val) => acc + Math.pow(val - meanReturn, 2),
                0
            ) / returns.length;

        return roundNumberToTwoDecimalPlaces(Math.sqrt(variance) * 100);
    }

    /**
     * Builds one row per price bar, attaching the transaction whose
     * `dateString` matches the bar's. When several transactions share a date
     * the last one wins; where both objects define a field, the bar's value
     * takes precedence. Market values start at 0.
     */
    private mergePreviousTransactionsAndStockData(): StrategyPerformanceStockData[] {
        const transactionsByDate: {
            [dateString: string]: StrategiesTransaction;
        } = {};
        for (const transaction of this.previousTransactions) {
            transactionsByDate[transaction.dateString] = transaction;
        }

        return this.assetData.map((data) => {
            const transaction = transactionsByDate[data.dateString];
            return {
                ...transaction,
                ...data,
                mktValue: 0,
                buyHoldMktValue: 0,
            };
        });
    }

    /**
     * Compound annual growth from `startBalance` to the last entry of
     * `values`, rounded with `roundNumberToTwoDecimalPlaces`. The period
     * length is `values.length / 365` years.
     *
     * NOTE(review): this treats every entry as one calendar day — confirm
     * against how callers build `values`.
     *
     * @throws Error when `values` is empty or its last entry is undefined.
     */
    public getAnnualizedReturn(startBalance: number, values: number[]): number {
        const totalDays = values.length;
        if (totalDays <= 0) {
            throw Error("Not enough days");
        }
        const years = totalDays / 365;
        const endBalance = values.at(-1);
        // Explicit undefined check: `endBalance && ...` reported a total
        // loss (end balance 0) as a 0% return instead of -100%.
        if (endBalance === undefined) throw Error("Annualized Return");
        const annualizedReturn = (endBalance / startBalance) ** (1 / years) - 1;
        return roundNumberToTwoDecimalPlaces(annualizedReturn);
    }

    /** Formats a date as `YYYY-MM-DD` taken from its ISO (UTC) representation. */
    private formatDateToYYYYMMDD(date: Date): string {
        return date.toISOString().split("T")[0];
    }

    /**
     * Largest peak-to-trough decline of `values`, as a fraction of the
     * running peak, rounded with `roundNumberToTwoDecimalPlaces`.
     *
     * @returns 0 when there are fewer than two values or no decline.
     */
    public getMaxDrawDown(values: number[]): number {
        let maxDrawdown = 0;
        let peak = values[0];

        for (let i = 1; i < values.length; i++) {
            if (values[i] > peak) {
                peak = values[i];
            }
            // A non-positive peak has no meaningful relative decline, and
            // dividing by a zero peak would turn the result into NaN.
            if (peak > 0) {
                maxDrawdown = Math.max(maxDrawdown, (peak - values[i]) / peak);
            }
        }
        return roundNumberToTwoDecimalPlaces(maxDrawdown);
    }

    /**
     * Applies the row's transaction, if any, to the running position and
     * returns the updated position. Orders execute at the row's `open`.
     *
     * - "Buy": spends as much cash as possible on whole shares after paying
     *   the commission; skipped when not even one share is affordable.
     * - "Sell": sells `round(sharesOwned * taperingRatio)` shares, capped at
     *   the current holding; skipped when that is zero shares.
     *
     * A skipped order costs nothing, so no commission is charged for it.
     */
    private applyTransaction(
        // NOTE(review): `action` and `open` are read untyped, as before; the
        // API types are not visible from this file — tighten once confirmed.
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        data: any,
        position: { sharesOwned: number; cash: number },
        commissions: number,
        taperingRatio: number
    ): { sharesOwned: number; cash: number } {
        let { sharesOwned, cash } = position;
        const open: number = data.open;

        if (data.action === "Buy") {
            const sharesToBuy = Math.floor((cash - commissions) / open);
            // Rejects negative counts (cash below commission) as well as the
            // Infinity/NaN produced by a zero or missing price.
            if (Number.isFinite(sharesToBuy) && sharesToBuy > 0) {
                sharesOwned += sharesToBuy;
                cash -= sharesToBuy * open + commissions;
            }
        } else if (data.action === "Sell") {
            const sharesToSell = Math.min(
                sharesOwned,
                Math.round(sharesOwned * taperingRatio)
            );
            if (sharesToSell > 0) {
                sharesOwned -= sharesToSell;
                cash += sharesToSell * open - commissions;
            }
        }

        return { sharesOwned, cash };
    }
}
