from typing import Optional

import numpy as np
import statistics


class OD:
    def __init__(self, od_json: dict):
        """Keep the detection payload and set up the category configuration.

        Args:
            od_json: Detection data; stored by reference (not copied) as
                ``object_detection``.
        """
        self.object_detection = od_json

        # Each combined-category key lists the individual category ids that
        # are dropped when that key is chosen.
        self.combined_categories = dict(
            comb_0=[2, 3],
            comb_1=[7, 8],
            comb_2=[9, 10, 11],
            comb_3=[12, 13, 14],
            comb_4=[15, 16],
        )

        # Category ids that are always removed from results.
        self.banned_categories = [0, 1, 20, 21]
        # Category ids checked by is_product_applied.
        self.application_categories = [3, 4, 11, 14]

        # Minimum frame count used by is_product_applied.
        self.consistency_threshold = 15
        # Minimum stable-frame count used by is_change_of_aspect_ratio.
        self.consecutive_threshold = 5

        # Category ids ignored by percentage_product_is_visible.
        self.not_product_related_categories = [0, 1, 2, 20, 21]

    def clean_object_list(self, combined_categories_chosen: Optional[list] = None):
        """Return a filtered shallow copy of the detection data.

        Entries for banned categories (``obj_<id>`` and ``<id>_state_change``)
        are always removed. When combined categories are chosen, the entries of
        the individual categories they cover are removed as well; otherwise
        every key containing ``"comb"`` is removed instead.

        Args:
            combined_categories_chosen: Keys of ``self.combined_categories``.
                Unknown keys are ignored, matching ``get_objects``.

        Returns:
            A new dict; ``self.object_detection`` itself is left untouched, but
            the per-key values are shared with it (shallow copy).
        """
        cleaned_object_list = self.object_detection.copy()

        # Collect every category id whose entries must go, then drop them in
        # one pass.
        categories_to_drop = list(self.banned_categories)

        if combined_categories_chosen:
            for cat in combined_categories_chosen:
                # .get() keeps an unknown key from raising KeyError.
                categories_to_drop.extend(self.combined_categories.get(str(cat)) or ())
        else:
            # No combined categories chosen: discard the combined entries.
            cleaned_object_list = {
                key: value
                for key, value in cleaned_object_list.items()
                if "comb" not in key
            }

        for cat in categories_to_drop:
            cleaned_object_list.pop(f"obj_{cat}", None)
            cleaned_object_list.pop(f"{cat}_state_change", None)

        return cleaned_object_list

    def get_frames_ids(self, start_time: float, end_time: float) -> list:
        """List the frame ids whose timestamp lies in ``[start_time, end_time]``.

        Both bounds are inclusive. Ids are returned in the iteration order of
        the ``"timestamp"`` mapping of the detection data.
        """
        frame_timestamps = self.object_detection["timestamp"]
        return [
            frame_id
            for frame_id, moment in frame_timestamps.items()
            if start_time <= moment <= end_time
        ]

    def get_objects(
        self,
        object_type: int,
        frames_ids: Optional[list],
        combined_categories_chosen: Optional[list] = None,
    ) -> dict:
        """Count, per detection key, the selected frames whose value equals 1.

        Args:
            object_type: ``0`` selects keys containing ``"obj_"``; ``1`` selects
                keys containing ``"_state_change"``. Any other value yields an
                empty dict.
            frames_ids: Frame ids to count over. ``None`` yields an empty dict;
                an empty list yields a count of 0 for every retained key.
            combined_categories_chosen: Keys of ``self.combined_categories``.
                When given, the individual categories they cover are excluded
                (unknown keys are ignored). When empty or ``None``, every key
                containing ``"comb"`` is excluded instead.

        Returns:
            A dict mapping each retained key to its count. Banned categories
            are never included.
        """
        # Determine the search term and key pattern based on the object type
        if object_type == 0:
            search_term = "obj_"
            pattern = "obj_{0}"
        elif object_type == 1:
            search_term = "_state_change"
            pattern = "{0}_state_change"
        else:
            # Invalid object type
            return {}

        if frames_ids is None:
            return {}

        # A set makes the per-frame membership test O(1) instead of O(len(frames_ids)).
        selected_frames = set(frames_ids)

        # Build the full exclusion set once rather than rebuilding the result
        # dict for every category.
        excluded_keys = {pattern.format(cat) for cat in self.banned_categories}
        if combined_categories_chosen:
            for cat in combined_categories_chosen:
                obj_ids = self.combined_categories.get(str(cat)) or ()
                excluded_keys.update(pattern.format(obj_id) for obj_id in obj_ids)
        drop_combined_keys = not combined_categories_chosen

        objects_list = {}
        for object_id, frames_data in self.object_detection.items():
            if search_term not in object_id or object_id in excluded_keys:
                continue
            if drop_combined_keys and "comb" in object_id:
                continue
            objects_list[object_id] = sum(
                1
                for frame_id, value in frames_data.items()
                if value == 1 and frame_id in selected_frames
            )

        return objects_list

    def get_number_of_objects(
        self,
        frames_ids: list,
        combined_categories_chosen: Optional[list] = None
    ) -> dict:
        """Return per-key counts for the ``obj_`` entries over ``frames_ids``.

        Thin wrapper around ``get_objects`` with ``object_type=0``.
        """
        return self.get_objects(
            object_type=0,
            frames_ids=frames_ids,
            combined_categories_chosen=combined_categories_chosen,
        )

    def get_number_of_state_changes(
        self,
        frames_ids: list,
        combined_categories_chosen: Optional[list] = None
    ) -> dict:
        """Return per-key counts for the ``_state_change`` entries over ``frames_ids``.

        Thin wrapper around ``get_objects`` with ``object_type=1``.
        """
        return self.get_objects(
            object_type=1,
            frames_ids=frames_ids,
            combined_categories_chosen=combined_categories_chosen,
        )

    def calculate_correlation(self, dict1: dict, dict2: dict) -> float:
        """Return the Pearson correlation between the values of two dicts.

        Values are paired by position (iteration order of each dict), not by
        key.

        Args:
            dict1: First mapping; only its values are used.
            dict2: Second mapping; only its values are used.

        Returns:
            The correlation coefficient as a builtin ``float``. ``0.0`` is
            returned when either dict is empty, when either has fewer than two
            values, or when either series has zero standard deviation.

        Raises:
            ValueError: If both series vary but have different lengths.
        """
        if not dict1 or not dict2:
            return 0.0

        values1 = list(dict1.values())
        values2 = list(dict2.values())

        if len(values1) <= 1 or len(values2) <= 1:
            return 0.0

        # A constant series has no defined correlation; report 0.0 rather than NaN.
        if statistics.stdev(values1) <= 0 or statistics.stdev(values2) <= 0:
            return 0.0

        if len(values1) != len(values2):
            raise ValueError(
                "cannot correlate series of different lengths: "
                f"{len(values1)} != {len(values2)}"
            )

        # Convert numpy's scalar so every return path yields a plain float.
        return float(np.corrcoef(values1, values2)[0, 1])

    def get_state_changes_deviation(
        self,
        frames_ids: list,
        combined_categories_chosen: Optional[list] = None
    ) -> float:
        """Return the mean standard deviation of the state-change series.

        For every ``*_state_change`` entry left after ``clean_object_list``,
        the sample standard deviation of its values over ``frames_ids`` is
        computed; the mean of those deviations is returned.

        Args:
            frames_ids: Frame ids to include.
            combined_categories_chosen: Forwarded to ``clean_object_list``.

        Returns:
            The mean deviation, or ``0.0`` when no entry has at least two
            values within ``frames_ids``.
        """
        selected_frames = set(frames_ids)
        objects_to_analyze = self.clean_object_list(combined_categories_chosen)

        state_changes_std_devs = []
        for object_id, object_estado_data in objects_to_analyze.items():
            if not object_id.endswith("_state_change"):
                continue

            # Order is irrelevant for a standard deviation, so no sorting is needed.
            state_changes = [
                value
                for frame, value in object_estado_data.items()
                if frame in selected_frames
            ]

            # stdev needs at least two data points.
            if len(state_changes) > 1:
                state_changes_std_devs.append(statistics.stdev(state_changes))

        if not state_changes_std_devs:
            return 0.0

        return statistics.mean(state_changes_std_devs)

    def starts_or_ends_with_state_changes(
        self,
        frames_ids: list,
        combined_categories_chosen: Optional[list] = None
    ) -> dict:
        """Report the first and last state-change value of each object.

        For every ``*_state_change`` entry left after ``clean_object_list``,
        the values at the earliest and latest frame within ``frames_ids`` are
        stored under ``"<key>_at_start"`` and ``"<key>_at_end"``.

        Args:
            frames_ids: Frame ids to include.
            combined_categories_chosen: Forwarded to ``clean_object_list``.

        Returns:
            A dict with two keys per object. Objects with no frame inside
            ``frames_ids`` are omitted rather than raising ``IndexError``.
        """

        def frame_order(frame):
            # Decimal-string ids (presumably JSON object keys) are ordered by
            # numeric value so that "10" comes after "9"; anything else keeps
            # its natural ordering.
            if isinstance(frame, str) and frame.isdecimal():
                return (0, int(frame), "")
            return (1, 0, frame)

        selected_frames = set(frames_ids)
        objects_to_analyze = self.clean_object_list(combined_categories_chosen)

        starts_or_ends_with_state_changes = {}
        for object_id, object_estado_data in objects_to_analyze.items():
            if not object_id.endswith("_state_change"):
                continue

            ordered_frames = sorted(
                (frame for frame in object_estado_data if frame in selected_frames),
                key=frame_order,
            )
            if not ordered_frames:
                # Nothing observed for this object in the requested interval.
                continue

            starts_or_ends_with_state_changes[f"{object_id}_at_start"] = object_estado_data[ordered_frames[0]]
            starts_or_ends_with_state_changes[f"{object_id}_at_end"] = object_estado_data[ordered_frames[-1]]

        return starts_or_ends_with_state_changes

    def is_product_applied(
        self,
        frames_ids: list,
        combined_categories_chosen: Optional[list] = None
    ) -> bool:
        """Tell whether any application category is present in enough frames.

        An ``obj_<id>`` entry qualifies when ``<id>`` is listed in
        ``self.application_categories`` and its value is non-zero in at least
        ``self.consistency_threshold`` of the frames in ``frames_ids`` (the
        frames need not be consecutive).

        Args:
            frames_ids: Frame ids to include.
            combined_categories_chosen: Forwarded to ``clean_object_list``.

        Returns:
            ``True`` as soon as one qualifying entry is found, else ``False``.
        """
        selected_frames = set(frames_ids)
        objects_to_analyze = self.clean_object_list(combined_categories_chosen)

        for object_id, object_data in objects_to_analyze.items():
            if "obj_" not in object_id:
                continue

            try:
                object_id_int = int(object_id.replace("obj_", ""))
            except ValueError:
                # Keys without a numeric category id (e.g. combined entries
                # that survive cleaning) cannot be application categories.
                continue

            if object_id_int not in self.application_categories:
                continue

            consistency_counter = sum(
                1
                for frame, value in object_data.items()
                if value != 0.0 and frame in selected_frames
            )
            if consistency_counter >= self.consistency_threshold:
                return True

        return False

    def percentage_product_is_visible(
        self,
        frames_ids: list,
        combined_categories_chosen: Optional[list] = None,
    ) -> float:
        """Return the fraction of frames in which any product object is visible.

        Considers every ``obj_<id>`` entry left after ``clean_object_list``
        whose ``<id>`` is not in ``self.not_product_related_categories``. A
        frame counts as visible when at least one such entry has a non-zero
        value for it.

        Visibility is merged per frame id, so entries that cover different
        sets of frames are combined correctly instead of being compared by
        list position.

        Args:
            frames_ids: Frame ids to include.
            combined_categories_chosen: Forwarded to ``clean_object_list``.

        Returns:
            A value in ``[0, 1]`` (despite the method name, not scaled to 100):
            visible frames divided by the number of selected frames covered by
            at least one considered entry; ``0.0`` when there are none.
        """
        selected_frames = set(frames_ids)
        objects_to_analyze = self.clean_object_list(combined_categories_chosen)

        # frame id -> True once any product-related object is seen in it
        visibility_by_frame = {}
        for object_id, object_data in objects_to_analyze.items():
            if "obj_" not in object_id:
                continue

            try:
                object_id_int = int(object_id.replace("obj_", ""))
            except ValueError:
                # Keys without a numeric category id (e.g. combined entries
                # that survive cleaning) carry no category to check; skip them.
                continue

            if object_id_int in self.not_product_related_categories:
                continue

            for frame, value in object_data.items():
                if frame in selected_frames:
                    already_visible = visibility_by_frame.get(frame, False)
                    visibility_by_frame[frame] = already_visible or value != 0.0

        if not visibility_by_frame:
            return 0.0

        visible_frames = sum(1 for seen in visibility_by_frame.values() if seen)
        return visible_frames / len(visibility_by_frame)

    def is_change_of_aspect_ratio(self, frames_ids: list, od_object: dict) -> bool:
        """Detect one abrupt change followed by a run of stable frames.

        Walks consecutive pairs of ``frames_ids`` and computes the relative
        change of ``od_object`` between them, in percent (taken as 0 when the
        earlier value is 0). A change of 50% or more is "abrupt". While exactly
        one abrupt change has been seen, each following pair with a change
        below 50% extends the stable run.

        Args:
            frames_ids: Ordered frame ids; each must be a key of ``od_object``.
            od_object: Mapping of frame id to a numeric value.

        Returns:
            ``True`` once the stable run reaches ``self.consecutive_threshold``;
            ``False`` otherwise (including when a second abrupt change occurs
            before the run is long enough).
        """
        abrupt_changes = 0
        stable_run = 0

        for earlier_id, later_id in zip(frames_ids, frames_ids[1:]):
            earlier = od_object[earlier_id]
            later = od_object[later_id]

            if earlier != 0:
                relative_change = abs((later - earlier) / earlier) * 100
            else:
                # Avoid dividing by zero; treated as "no change".
                relative_change = 0

            if relative_change >= 50:  # Adjust the threshold as needed
                abrupt_changes += 1
                continue

            # Stable frames only count after the first (and only) abrupt change.
            if relative_change < 50 and abrupt_changes == 1:
                stable_run += 1
                if stable_run >= self.consecutive_threshold:
                    return True

        return False

    def is_focus_on_face(self, frames_ids: list) -> bool:
        """Run ``is_change_of_aspect_ratio`` on the ``"face_on_screen"`` series."""
        return self.is_change_of_aspect_ratio(
            frames_ids,
            self.object_detection["face_on_screen"],
        )

    def is_focus_on_product(self, frames_ids: list) -> bool:
        """Run ``is_change_of_aspect_ratio`` on the ``"product_on_screen"`` series."""
        return self.is_change_of_aspect_ratio(
            frames_ids,
            self.object_detection["product_on_screen"],
        )
