# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: MIT

import glob
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Union

import pandas as pd
from openpyxl.cell import MergedCell
from openpyxl.styles import Alignment

from mpp import ViewAggregationLevel, SummaryViewDataFrameColumns as svdc

# Top-level column header placed above the baseline column in the comparison tables.
BASELINE_HEADER = 'Baseline values'
# Top-level column header placed above every non-baseline column in the comparison tables.
DELTA_HEADER = 'Δ Baseline (%)'

class MppComparer:
    """
    Compare MPP summary-view CSV files and write the comparison to an Excel workbook.

    Each entry of ``output_file_specifiers`` is used as a path prefix: files matching
    ``<prefix>*_system_view_summary.csv`` are loaded, compared against a baseline, and the
    result is written to ``file_name`` in the parent directory of the first specifier.
    """

    def __init__(self, output_file_specifiers: List[Union[str, Path]], delta: float = 5, baseline=None,
                 file_name='comparison.xlsx'):
        """
        @param output_file_specifiers: path prefixes used to glob summary-view CSV files
        @param delta: highlight threshold in percent; cells whose relative difference from the
                      baseline is at least this large are highlighted
        @param baseline: id of the summary file to compare against (file stem without the
                         summary-view suffix); when falsy, the first summary view found is used
        @param file_name: name of the Excel workbook written next to the first specifier
        """
        self.__output_file_specifiers = output_file_specifiers
        self.__delta = delta
        self.__baseline = baseline
        self.__file_name = file_name
        # {aggregation level: {summary file id: DataFrame read from the CSV}}
        self.__summary_views = {}

    @property
    def output_file_specifiers(self):
        return self.__output_file_specifiers

    @property
    def summary_views(self):
        return self.__summary_views

    def compare_files(self):
        """Load all summary views, compare them against the baseline and write the workbook."""
        self.__get_summary_views()
        self.__process_summary_views()

    def __get_summary_views(self):
        """Collect the summary-view CSVs for every specifier, grouped by aggregation level."""
        for agg_level in [ViewAggregationLevel.SYSTEM]:
            self.__summary_views[agg_level] = {}
        for output_file_specifier in self.__output_file_specifiers:
            self.__get_summary_view_files(output_file_specifier, '_system_view_summary',
                                          ViewAggregationLevel.SYSTEM)
            # NOTE: socket-level summaries ('_socket_view_summary', ViewAggregationLevel.SOCKET)
            # are not collected yet; SocketViewComparer is still a stub.

    def __get_summary_view_files(self, output_file_specifier, summary_view_type, agg_level):
        """
        Load every ``<specifier>*<summary_view_type>.csv`` file into ``summary_views[agg_level]``.

        The dictionary key is the file stem with the trailing ``summary_view_type`` removed.
        """
        # Sort so the discovery order (and therefore the default baseline, which is the first
        # view found) does not depend on the file system's directory listing order.
        output_files = sorted(glob.glob(str(output_file_specifier) + f'*{summary_view_type}.csv'))

        for output_file in output_files:
            stem = Path(output_file).stem
            # Strip only the trailing view suffix; str.replace would also remove any occurrence
            # inside the run id itself.
            if summary_view_type and stem.endswith(summary_view_type):
                summary_file_id = stem[:-len(summary_view_type)]
            else:
                summary_file_id = stem
            self.__summary_views[agg_level][summary_file_id] = pd.read_csv(output_file)

    def __process_summary_views(self):
        """Compare the views of each aggregation level and write the styled results."""
        for view_type, summary_views in self.__summary_views.items():
            compare_data_processor = _CompareDataProcessorFactory.create(view_type, summary_views, self.__baseline,
                                                                         self.__delta)
            styled_summary_df, styled_delta_df = compare_data_processor.process()
            self._write(styled_summary_df, styled_delta_df)

    def _write(self, styled_summary_df, styled_delta_df):
        """
        Write the styled summary and delta tables to the workbook and post-format both sheets.

        @param styled_summary_df: styled table of the raw values per summary file
        @param styled_delta_df: styled table of baseline values and relative differences
        """
        # Specifiers may be plain strings (see the __init__ annotation), which have no .parent.
        xl_file = Path(self.__output_file_specifiers[0]).parent / self.__file_name
        # NOTE(review): sheet names are fixed, so every aggregation level processed rewrites the
        # same workbook; this is fine while only the system view is collected.
        with pd.ExcelWriter(xl_file, engine='openpyxl') as writer:
            # Write to Excel with styling
            styled_delta_df.to_excel(writer, sheet_name='system summary delta', engine='openpyxl')
            styled_summary_df.to_excel(writer, sheet_name='system summary', engine='openpyxl')

            # Load the workbook and access the worksheets
            summary_delta_sheet = writer.sheets['system summary delta']
            summary_sheet = writer.sheets['system summary']

            # Delete the index name row that is unnecessary
            summary_delta_sheet.delete_rows(3)
            summary_sheet.delete_rows(3)

            # Auto-adjust column widths for both sheets
            self.auto_adjust_column_widths(summary_delta_sheet)
            self.auto_adjust_column_widths(summary_sheet)

            self.format_columns_as_percentage(summary_delta_sheet)

            self.left_align_index(summary_delta_sheet)
            self.left_align_index(summary_sheet)

        print(f'Output written to: {xl_file}')

    @staticmethod
    def auto_adjust_column_widths(sheet):
        """
        Set each column's width to the length of its longest rendered value plus 2.

        Empty cells (value ``None``) are ignored. Columns consisting only of merged cells are skipped.

        @param sheet: openpyxl worksheet
        @return: the same worksheet
        """
        for col in sheet.iter_cols():
            # Merged cells have no column_letter, so take the letter from the first regular cell.
            anchor = next((cell for cell in col if not isinstance(cell, MergedCell)), None)
            if anchor is None:
                continue
            # Measure the string form of every value (numbers included), not only str cells.
            max_length = max((len(str(cell.value)) for cell in col if cell.value is not None), default=0)
            sheet.column_dimensions[anchor.column_letter].width = max_length + 2
        return sheet

    @staticmethod
    def format_columns_as_percentage(sheet):
        """
        Apply the ``0.00%`` number format to every cell from column C and row 2 onward.

        Columns A and B are left untouched. Presumably these are the index and the baseline column.
        """
        for col in sheet.iter_cols(min_col=3, max_col=sheet.max_column, min_row=2, max_row=sheet.max_row):
            for cell in col:
                cell.number_format = '0.00%'

    @staticmethod
    def left_align_index(sheet):
        """Left-align every cell in column A (the written index)."""
        for cell in sheet['A']:
            cell.alignment = Alignment(horizontal='left')


class _CompareDataProcessor(ABC):

    def __init__(self, summary_views, baseline, delta):
        self._summary_views = summary_views
        self._baseline = baseline
        self._delta = delta
        self._baseline_column = None
        self._baseline_df = None
        self._excel_summary_df = None
        self._diff_df = None
        self._styled_summary_df = None
        self._styled_delta_df = None

    def _validate_preconditions(self):
        if self._baseline and self._baseline not in self._excel_summary_df.columns:
            raise ValueError(f'Baseline "{self._baseline}" not found in summary views. Available baselines: '
                             f'{self._excel_summary_df.columns}')

    def process(self):
        self.get_baseline()
        baseline_str = '--baseline not provided. ' if not self._baseline else ''

        print(f'{baseline_str}Baseline set to \'{self._baseline_column}\'')
        self.merge()
        self.get_diff()
        self.style()
        return self._styled_summary_df, self._styled_delta_df

    @abstractmethod
    def get_baseline(self):
        """
        Get the baseline summary view DataFrame
        @return: None, create self._baseline_df
        """
        pass

    @abstractmethod
    def merge(self):
        """
        Merge all summary views into a single DataFrame
        @return: None, create self._excel_summary_df
        """
        pass

    @abstractmethod
    def get_diff(self):
        """
        Calculate the percentage difference between each summary view and the baseline
        @return: None, create self._diff_df
        """
        pass

    @abstractmethod
    def style(self):
        """
        Apply styling to the summary and delta DataFrames
        @return: None, create self._styled_summary_df and self._styled_delta_df
        """
        pass

    @abstractmethod
    def write(self):
        pass


class SystemViewComparer(_CompareDataProcessor):
    """
    Compare system-level summary views.

    Each view is reduced to its ``svdc.AGGREGATED`` column, indexed by the view's first column
    and renamed to the view's file id. The views are then joined side by side, baseline first.
    The input DataFrames in ``summary_views`` are never modified.
    """

    def __init__(self, summary_views, baseline, delta):
        super().__init__(summary_views, baseline, delta)

    def get_baseline(self):
        """
        Resolve the baseline column and seed the merged summary with it.

        The default baseline is the first entry of ``self._summary_views``.

        @raise ValueError: if there are no summary views to compare
        """
        if not self._summary_views:
            # Without this guard the default-baseline lookup fails with a bare IndexError.
            raise ValueError('No summary views found to compare')
        default_baseline = next(iter(self._summary_views))
        self._baseline_column = self._baseline if self._baseline else default_baseline
        baseline_df = self._get_baseline_df(default_baseline)
        self._baseline_df = self._adjust_summary_df(self._baseline_column, baseline_df)
        self._excel_summary_df = self._baseline_df.to_frame()

    def merge(self):
        """
        Left-join every non-baseline view onto the baseline and add a header level.

        Columns are ordered baseline first, then the remaining ids in sorted order. A top header
        level (BASELINE_HEADER / DELTA_HEADER) is added above the file-id columns.

        @raise ValueError: (from _validate_preconditions) if a requested baseline is not a column
        """
        for file_name, summary_df in self._summary_views.items():
            if file_name == self._baseline_column:
                continue
            summary_series = self._adjust_summary_df(file_name, summary_df)
            # Left join keeps exactly the baseline's rows.
            self._excel_summary_df = pd.merge(self._excel_summary_df, summary_series, left_index=True,
                                              right_index=True, how='left')
        # NOTE(review): if a requested baseline was missing, _get_baseline_df already fell back to
        # the default, but this check still raises because the requested name is not a column.
        self._validate_preconditions()
        # Index.difference sorts the remaining column names.
        self._excel_summary_df = self._excel_summary_df[
            [self._baseline_column] + self._excel_summary_df.columns.difference([self._baseline_column]).tolist()]

        columns_with_headers = [((DELTA_HEADER if col != self._baseline_column else BASELINE_HEADER),
                                 col) for col in self._excel_summary_df.columns]
        self._excel_summary_df.columns = pd.MultiIndex.from_tuples(columns_with_headers)

    def get_diff(self):
        """
        Build the delta table.

        The baseline column is kept as is. Every other column becomes
        (value - baseline) / baseline, row by row.
        """
        self._diff_df = self._excel_summary_df.iloc[:, 1:].apply(lambda x: (x - self._baseline_df) /
                                                                            self._baseline_df)
        self._diff_df = pd.concat([self._excel_summary_df[[BASELINE_HEADER]], self._diff_df], axis='columns')

    def style(self):
        """
        Highlight cells in both tables, using the delta table to decide the colors.

        The first (baseline) column is light gray. Other cells are yellow when
        |relative difference| * 100 >= delta.
        """
        def get_highlight_from_diff(row):
            # Look up the matching delta row so both tables share the same highlights.
            row = self._diff_df.loc[row.name]
            return ['background-color: lightgray' if idx == 0 else 'background-color: yellow' if abs(
                val) * 100 >= self._delta else '' for idx, val in enumerate(row)]

        self._styled_summary_df = self._excel_summary_df.style.apply(get_highlight_from_diff, axis=1)
        self._styled_delta_df = self._diff_df.style.apply(get_highlight_from_diff, axis=1)

    def write(self):
        pass

    def _get_baseline_df(self, default_baseline):
        """
        Return the raw summary view for ``self._baseline_column``.

        If that view is missing, print the available ids and fall back to ``default_baseline``.
        """
        try:
            baseline_df = self._summary_views[self._baseline_column]
        except KeyError:
            print(f'Baseline "{self._baseline}" not found in summary views. Available baselines: '
                  f'{list(self._summary_views.keys())}')
            self._baseline_column = default_baseline
            baseline_df = self._summary_views[self._baseline_column]
        return baseline_df

    @staticmethod
    def _adjust_summary_df(file_name, summary_df):
        """
        Reduce a raw summary view to a Series named ``file_name``.

        The Series holds the aggregated values and is indexed by the view's first column.
        Works on copies: the caller's DataFrame is left unchanged, so views can be processed again.
        """
        renamed = summary_df.rename(columns={svdc.AGGREGATED: file_name})
        return renamed.set_index(renamed.columns[0])[file_name]


class SocketViewComparer(_CompareDataProcessor):
    """
    Placeholder processor for socket-level summary views.

    Every pipeline step is a no-op, so ``process()`` currently returns ``(None, None)``.
    TODO: implement the socket-level comparison.
    """

    def get_baseline(self):
        """No-op placeholder."""

    def merge(self):
        """No-op placeholder."""

    def get_diff(self):
        """No-op placeholder."""

    def style(self):
        """No-op placeholder."""

    def write(self):
        """No-op placeholder."""


class _CompareDataProcessorFactory:
    """Select the comparison processor that matches a view aggregation level."""

    @staticmethod
    def create(view_type, summary_views, baseline, delta):
        """
        Build the processor for ``view_type``.

        @param view_type: ViewAggregationLevel.SYSTEM or ViewAggregationLevel.SOCKET
        @return: a SystemViewComparer or SocketViewComparer built with the remaining arguments
        @raise ValueError: for any other view type
        """
        processor_args = (summary_views, baseline, delta)
        if view_type == ViewAggregationLevel.SYSTEM:
            return SystemViewComparer(*processor_args)
        if view_type == ViewAggregationLevel.SOCKET:
            return SocketViewComparer(*processor_args)
        raise ValueError(f'Invalid view type: {view_type}')
