import logging
from typing import Optional, List, Union, Tuple

import pandas as pd

from .base import BaseFeatureGenerator


# Configure the root logger at import time; logging.INFO is the numeric level 20.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


# A clip interval expressed as a pair of floats, either a list or a 2-tuple.
TimeStamp = Union[List[float], Tuple[float, float]]


class SegmentIdentifierFeatureGenerator(BaseFeatureGenerator):
    """Build per-clip feature rows for the segment identifier model.

    The features are computed through two collaborators read from the instance
    as ``self.nlp`` and ``self.od``. Neither is created in this class
    (presumably ``BaseFeatureGenerator`` provides them -- confirm there).
    """

    @staticmethod
    def _segment_to_timestamps(segment) -> Optional[List[float]]:
        """Turn a segment record into ``[start, end]``; ``None`` passes through.

        ``START`` and ``END`` are each converted to ``float`` and divided by
        1000 (presumably milliseconds to seconds -- confirm with the caller).
        """
        if segment is None:
            return None
        return [float(segment["START"]) / 1000, float(segment["END"]) / 1000]

    @staticmethod
    def _keyword_terms(keywords) -> list:
        """Return the first element of every entry in *keywords*.

        An empty or missing result yields an empty list.
        """
        if not keywords:
            return []
        return [entry[0] for entry in keywords]

    def _sentence_info(self, timestamps: Optional[TimeStamp]):
        """Fetch the ``self.nlp.get_sentence`` result for a clip.

        Returns ``("", False)`` when there is no clip, so callers can always
        index positions 0 and 1.
        """
        if timestamps is None:
            return "", False
        return self.nlp.get_sentence(timestamps[0], timestamps[1])

    def _frame_ids(self, timestamps: Optional[TimeStamp]):
        """Fetch the ``self.od.get_frames_ids`` result for a clip, or ``None`` without one."""
        if timestamps is None:
            return None
        return self.od.get_frames_ids(timestamps[0], timestamps[1])

    def get_nlp_data(
        self,
        previous_clip_timestamps: Optional[TimeStamp],
        current_clip_timestamps: Optional[TimeStamp],
        next_clip_timestamps: Optional[TimeStamp],
    ) -> Optional[pd.DataFrame]:
        """Compute the text-based features of the current clip.

        Args:
            previous_clip_timestamps: ``[start, end]`` of the preceding clip,
                or ``None`` when there is none.
            current_clip_timestamps: ``[start, end]`` of the clip being
                described, or ``None``.
            next_clip_timestamps: ``[start, end]`` of the following clip, or
                ``None`` when there is none.

        Returns:
            A one-row DataFrame (index ``0``) with the columns
            ``ends_with_punctuation``, ``sentence_correlation_with_previous``,
            ``keywords_correlation_with_previous``,
            ``sentence_correlation_with_next``,
            ``keywords_correlation_with_next`` and ``starts_with_fanboys``;
            ``None`` when no sentence is found for the current clip.

        Raises:
            Exception: anything raised by the ``self.nlp`` calls or pandas is
                logged and re-raised unchanged.
        """
        try:
            logger.info("Getting nlp features.")
            current_info = self._sentence_info(current_clip_timestamps)
            current_sentence = current_info[0]

            # Without text for the current clip there is nothing to compare.
            if not current_sentence:
                return None

            next_sentence = self._sentence_info(next_clip_timestamps)[0]
            previous_sentence = self._sentence_info(previous_clip_timestamps)[0]

            # Keywords
            previous_keywords = self._keyword_terms(
                self.nlp.bert_keyword_extraction(previous_sentence)
            )
            current_keywords = self._keyword_terms(
                self.nlp.bert_keyword_extraction(current_sentence)
            )
            next_keywords = self._keyword_terms(
                self.nlp.bert_keyword_extraction(next_sentence)
            )

            data = {
                # Position 1 of the sentence lookup is used as the punctuation flag.
                'ends_with_punctuation': current_info[1],
                "sentence_correlation_with_previous": self.nlp.sentence_correlation(
                    current_sentence, previous_sentence
                ),
                "keywords_correlation_with_previous": self.nlp.keywords_correlation(
                    current_keywords, previous_keywords
                ),
                "sentence_correlation_with_next": self.nlp.sentence_correlation(
                    current_sentence, next_sentence
                ),
                "keywords_correlation_with_next": self.nlp.keywords_correlation(
                    current_keywords, next_keywords
                ),
                "starts_with_fanboys": self.nlp.starts_with_fanboys_conjunction(
                    current_sentence
                ),
            }
            return pd.DataFrame(data, index=[0])

        except Exception as error:
            logger.error(
                "Unexpected error while getting nlp features for Segment identifier model: %s",
                error,
            )
            raise

    def get_od_data(
        self,
        previous_clip_timestamps: Optional[TimeStamp],
        current_clip_timestamps: Optional[TimeStamp],
        next_clip_timestamps: Optional[TimeStamp],
        combined_categories: Optional[list] = None,
    ) -> Optional[pd.DataFrame]:
        """Compute the object-detection features of the current clip.

        Args:
            previous_clip_timestamps: ``[start, end]`` of the preceding clip,
                or ``None`` when there is none.
            current_clip_timestamps: ``[start, end]`` of the clip being
                described, or ``None``.
            next_clip_timestamps: ``[start, end]`` of the following clip, or
                ``None`` when there is none.
            combined_categories: forwarded unchanged to the ``self.od`` calls.

        Returns:
            A one-row DataFrame (index ``0``) with the four correlation
            columns plus the entries returned by
            ``self.od.starts_or_ends_with_state_changes``; ``None`` when no
            frames are found for the current clip.

        Raises:
            Exception: anything raised by the ``self.od`` calls or pandas is
                logged and re-raised unchanged.
        """
        try:
            current_frames = self._frame_ids(current_clip_timestamps)
            if not current_frames:
                return None

            logger.info("Getting od features.")
            # Neighbour frames may be None; they are handed to self.od as is.
            previous_frames = self._frame_ids(previous_clip_timestamps)
            next_frames = self._frame_ids(next_clip_timestamps)

            # Number of objects appearing in each segment
            previous_objects = self.od.get_number_of_objects(previous_frames, combined_categories)
            current_objects = self.od.get_number_of_objects(current_frames, combined_categories)
            next_objects = self.od.get_number_of_objects(next_frames, combined_categories)

            # Number of state changes in each segment
            previous_changes = self.od.get_number_of_state_changes(previous_frames, combined_categories)
            current_changes = self.od.get_number_of_state_changes(current_frames, combined_categories)
            next_changes = self.od.get_number_of_state_changes(next_frames, combined_categories)

            # Whether the clip starts or ends with a state change,
            # i.e. an object entering or leaving the scene
            boundary_changes = self.od.starts_or_ends_with_state_changes(
                current_frames, combined_categories
            )

            data = {
                # Correlation of the object counts between segments
                "n_object_correlation_with_previous": self.od.calculate_correlation(
                    previous_objects, current_objects
                ),
                "n_object_correlation_with_next": self.od.calculate_correlation(
                    current_objects, next_objects
                ),
                # Correlation of the state changes between segments,
                # i.e. an object leaving one segment and entering the next
                "n_changes_correlation_with_previous": self.od.calculate_correlation(
                    previous_changes, current_changes
                ),
                "n_changes_correlation_with_next": self.od.calculate_correlation(
                    current_changes, next_changes
                ),
            }
            data.update(boundary_changes)
            return pd.DataFrame(data, index=[0])

        except Exception as error:
            logger.error("Unexpected error while getting od features: %s", error)
            raise

    def get_features_for_single_clip(
        self,
        previous_clip_timestamps: Optional[TimeStamp],
        current_clip_timestamps: Optional[TimeStamp],
        next_clip_timestamps: Optional[TimeStamp],
        combined_categories_chosen: Optional[list] = None,
    ) -> Optional[pd.DataFrame]:
        """Join the nlp and od features of one clip into a single row.

        Args:
            previous_clip_timestamps: ``[start, end]`` of the preceding clip,
                or ``None`` when there is none.
            current_clip_timestamps: ``[start, end]`` of the clip being
                described.
            next_clip_timestamps: ``[start, end]`` of the following clip, or
                ``None`` when there is none.
            combined_categories_chosen: categories forwarded to
                ``get_od_data``; ``None`` is replaced by an empty list.

        Returns:
            The column-wise inner join of both feature frames, or ``None``
            when either of them is unavailable.

        Raises:
            Exception: any error from the feature getters is logged and
                re-raised unchanged.
        """
        combined_categories_chosen = combined_categories_chosen or []

        try:
            logger.info(
                "Getting features for clip: %s on Segment Selection",
                current_clip_timestamps,
            )
            nlp = self.get_nlp_data(
                previous_clip_timestamps, current_clip_timestamps, next_clip_timestamps
            )
            od = self.get_od_data(
                previous_clip_timestamps,
                current_clip_timestamps,
                next_clip_timestamps,
                combined_categories_chosen,
            )
            if nlp is None or od is None:
                return None
            return pd.concat([nlp, od], axis=1, join="inner")

        except Exception as error:
            logger.error("Unexpected error in main function get_features: %s", error)
            raise

    def get_features_for_video_file(self, segments: List[dict]) -> List[dict]:
        """Compute the features of every segment of a video.

        Args:
            segments: ordered segment records; each one is read through the
                keys ``"START"``, ``"END"`` and ``"video_clip_id"``.

        Returns:
            One dict per segment that produced a non-empty feature frame, with
            the keys ``"segment"`` (the ``[start, end]`` list), ``"features"``
            (the frame serialised with ``to_json(orient="records")``) and
            ``"video_clip_id"``. Segments without features are skipped.

        Raises:
            Exception: any error while reading a segment or computing its
                features is logged and re-raised unchanged.
        """
        try:
            results = []
            last_index = len(segments) - 1

            for i, segment in enumerate(segments):
                current_segment = self._segment_to_timestamps(segment)
                # Neighbours get the same conversion as the current segment:
                # the feature getters index them positionally ([0] and [1]),
                # which a raw segment record does not support.
                previous_segment = self._segment_to_timestamps(
                    segments[i - 1] if i > 0 else None
                )
                next_segment = self._segment_to_timestamps(
                    segments[i + 1] if i < last_index else None
                )
                features = self.get_features_for_single_clip(
                    previous_segment, current_segment, next_segment
                )

                # The single-clip call returns None when features are missing.
                if features is None or features.empty:
                    continue

                results.append(
                    {
                        "segment": current_segment,
                        "features": features.to_json(orient="records"),
                        "video_clip_id": segment["video_clip_id"],
                    }
                )

            return results
        except Exception as error:
            logger.error("Error retrieving features: %s", error)
            raise
