import logging
import os
import re
import string
from collections import Counter, defaultdict
from typing import Dict, Iterable, List, Optional, Tuple

import boto3
import nltk
from keybert import KeyBERT
from nltk.corpus import stopwords, wordnet
from nltk.tokenize import word_tokenize
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer

from ..constants import NUMBERS_AS_WORDS

def _ensure_nltk_resource(resource_path: str, package: str) -> None:
    """Download the NLTK data ``package`` only if ``resource_path`` cannot be found locally.

    Calling ``nltk.download`` unconditionally at import time consults the remote NLTK package
    index even when the data is already installed, adding network I/O to every cold start.
    ``nltk.data.find`` raises LookupError when the resource is missing (it also finds the
    zipped form of a package, e.g. ``corpora/wordnet.zip``).
    """
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(package)


_ensure_nltk_resource("corpora/stopwords", "stopwords")
_ensure_nltk_resource("corpora/wordnet", "wordnet")
# NOTE(review): newer NLTK releases load the "punkt_tab" resource for word_tokenize; if
# sentence_correlation logs LookupError, add _ensure_nltk_resource("tokenizers/punkt_tab", "punkt_tab").
_ensure_nltk_resource("tokenizers/punkt", "punkt")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Shared S3 client used to fetch the reference data (feature buckets, brand vocabulary).
s3 = boto3.client("s3")


class NLP:
    """Extract sentence-level NLP features from a timed speech transcription.

    The transcription is read as ``transcription["results"]["items"]``. Each item has a
    ``"type"`` (``"punctuation"`` or a word type) and an ``"alternatives"`` list whose first
    entry holds the ``"content"`` text. Non-punctuation items also carry numeric-string
    ``"start_time"`` / ``"end_time"`` values. (The layout looks like Amazon Transcribe output;
    NOTE(review): confirm against the producer.)

    Reference data (feature-word buckets and brand names) is downloaded lazily from the S3
    bucket named by the UTILITY_BUCKET_NAME environment variable on first access. It is then
    cached on the instance.

    Most public methods catch every exception, log it, and return None (or ``{}``) instead of
    raising.
    """

    # Leading non-word characters, e.g. punctuation collected before the first in-window word.
    _LEADING_NON_WORD_RE = re.compile(r"^\W+")
    # Whitespace in front of a punctuation run, so "word ." becomes "word.".
    _SPACE_BEFORE_PUNCTUATION_RE = re.compile(r"\s+([^\w\s]+)")
    # Sentence-ending punctuation followed by whitespace, i.e. a sentence boundary inside the
    # text. get_sentence always separates words with a space, so requiring the whitespace
    # keeps decimals such as "3.5" from being mistaken for a boundary.
    _SENTENCE_BREAK_RE = re.compile(r"[.!?](?=\s)")

    # Lazily populated cache of NLTK's English stop words, shared by all instances.
    __english_stop_words_cache: Optional[frozenset] = None

    def __init__(self, transcription_json: dict):
        self.transcription = transcription_json
        self.sid_obj = SentimentIntensityAnalyzer()
        self.kw_model = KeyBERT()
        # Settings passed to KeyBERT.extract_keywords by bert_keyword_extraction.
        self.keyphrase_ngram_range = (1, 2)
        self.stop_words = "english"
        self.highlight = False
        self.top_n = 15

        # Loaded from S3 on first access through the `buckets` / `brands` properties.
        self.__buckets: Optional[dict] = None
        self.__brands: Optional[List[str]] = None

    @property
    def buckets(self) -> dict:
        """Feature-word buckets ``{category: [word, ...]}``, downloaded on first access."""
        if self.__buckets is None:
            self.__buckets = self.__get_buckets()

        return self.__buckets

    @property
    def brands(self) -> List[str]:
        """Brand names from the custom-vocabulary file, downloaded on first access."""
        if self.__brands is None:
            self.__brands = self.__get_brands()

        return self.__brands

    @staticmethod
    def __get_s3_object_from_environment_bucket(environment_variable: str, key: str) -> str:
        """Return the UTF-8 text of object ``key`` in the bucket named by ``environment_variable``.

        Raises:
            ValueError: If the environment variable is unset or empty.
        """
        bucket_name = os.getenv(environment_variable)
        if not bucket_name:
            raise ValueError(f"{environment_variable} is not set")

        return s3.get_object(Bucket=bucket_name, Key=key)["Body"].read().decode("utf-8")

    def __get_buckets(self) -> dict:
        """Download ``features_names_and_lists.csv`` and group its words by category."""
        buckets_and_common_words = self.__get_s3_object_from_environment_bucket(
            "UTILITY_BUCKET_NAME", "features_names_and_lists.csv"
        )
        return self.__parse_buckets_csv(buckets_and_common_words)

    @staticmethod
    def __parse_buckets_csv(csv_text: str) -> Dict[str, List[str]]:
        """Parse ``word,category`` rows (after one header row) into ``{category: [word, ...]}``.

        Blank lines, such as the one left by a trailing newline, are skipped. Fields are
        whitespace-stripped so CRLF line endings do not leak a carriage return into category
        names. A row that does not have exactly two comma-separated fields raises ValueError.
        """
        buckets = defaultdict(list)
        for line in csv_text.splitlines()[1:]:  # the first row is a header
            if not line.strip():
                continue
            word, category = (field.strip() for field in line.split(","))
            buckets[category].append(word)

        return buckets

    def __get_brands(self) -> List[str]:
        """Download ``<prefix>Brand.txt`` and return its brand names.

        ``prefix`` comes from CUSTOM_VOCABULARY_OBJECT_PREFIX (default ``"Trendio"``).
        """
        prefix = os.getenv("CUSTOM_VOCABULARY_OBJECT_PREFIX", "Trendio")
        vocabulary = self.__get_s3_object_from_environment_bucket(
            "UTILITY_BUCKET_NAME", f"{prefix}Brand.txt"
        )
        return self.__parse_brands(vocabulary)

    @staticmethod
    def __parse_brands(vocabulary_text: str) -> List[str]:
        """Return the stripped last tab-separated column of each non-header, non-blank line."""
        brands = []
        for line in vocabulary_text.splitlines()[1:]:  # the first row is a header
            brand = line.split("\t")[-1].strip()
            if brand:
                brands.append(brand)

        return brands

    @staticmethod
    def __token_variants(sentence: str) -> set:
        """Return every whitespace token of ``sentence`` plus its punctuation-stripped form.

        get_sentence attaches punctuation to the preceding word ("battery."). Matching only
        raw tokens would therefore miss dictionary words at the end of a clause. Raw tokens
        are kept too, so entries that themselves contain edge punctuation still match.
        """
        variants = set()
        for token in sentence.split():
            variants.add(token)
            variants.add(token.strip(string.punctuation))
        variants.discard("")
        return variants

    @staticmethod
    def __clean_keywords(keywords):
        """Strip ASCII punctuation from the text in position 1 of each keyword pair.

        Pairs whose text becomes empty are dropped. Duplicates are removed, keeping the first
        occurrence order. NOTE(review): assumes each item is ``(x, text)`` with the text at
        index 1 — confirm against callers.
        """
        punctuation_table = str.maketrans("", "", string.punctuation)

        cleaned_keywords = []
        for keyword in keywords:
            cleaned_keyword = keyword[1].translate(punctuation_table)
            if cleaned_keyword != "":
                cleaned_keywords.append((keyword[0], cleaned_keyword))

        # dict.fromkeys de-duplicates like set() but keeps a deterministic order.
        return list(dict.fromkeys(cleaned_keywords))

    @classmethod
    def __english_stop_words(cls) -> frozenset:
        """Return NLTK's English stop words, loading the corpus only once."""
        if cls.__english_stop_words_cache is None:
            cls.__english_stop_words_cache = frozenset(stopwords.words("english"))
        return cls.__english_stop_words_cache

    def get_sentence(self, start_time: float, end_time: float) -> Optional[Tuple[str, bool]]:
        """Rebuild the text spoken between ``start_time`` and ``end_time``.

        A word is included when it lies fully inside the window. Punctuation items have no
        timestamps and are always collected; any that precede the first included word are
        stripped afterwards. Scanning stops at the first word ending after ``end_time``.

        Returns:
            ``(sentence, ends_with_punctuation)``. The flag is True when the last item
            collected before stopping was punctuation. Returns None (and logs) on any error,
            e.g. a malformed transcription.
        """
        try:
            parts: List[str] = []
            ends_with_punctuation = False
            for item in self.transcription["results"]["items"]:
                if item["type"] == "punctuation":
                    parts.append(item["alternatives"][0]["content"])
                    ends_with_punctuation = True
                    continue

                item_start = float(item["start_time"])
                item_end = float(item["end_time"])

                if item_start >= start_time and item_end <= end_time:
                    parts.append(item["alternatives"][0]["content"])
                    ends_with_punctuation = False
                elif item_end > end_time:
                    break

            sentence = " ".join(parts).strip()
            sentence = self._LEADING_NON_WORD_RE.sub("", sentence)
            sentence = self._SPACE_BEFORE_PUNCTUATION_RE.sub(r"\1", sentence)

            return sentence, ends_with_punctuation

        except Exception as error:
            logger.exception("Error occurred in nlp method get_sentence: %s", error)
            return None

    def bert_keyword_extraction(self, sentence: str) -> Optional[List[Iterable[Tuple[str, float]]]]:
        """Keyword extraction with BERT.

        Calls KeyBERT's ``extract_keywords`` with this instance's n-gram range, stop words,
        highlight and ``top_n`` settings. Returns None (and logs) on any error.
        """
        try:
            return self.kw_model.extract_keywords(
                sentence,
                keyphrase_ngram_range=self.keyphrase_ngram_range,
                stop_words=self.stop_words,
                highlight=self.highlight,
                top_n=self.top_n
            )

        except Exception as error:
            logger.exception("Error occurred in nlp method bert_keyword_extraction: %s", error)
            return None

    def sentiment_scores(self, sentence: str) -> Optional[Dict[str, float]]:
        """Sentiment analysis. Detect an average sentiment of a sentence.

        Uses the VADER analyzer created in ``__init__``. Its ``polarity_scores`` returns a dict
        with pos, neg, neu and compound scores. Returns None (and logs) on any error.
        """
        try:
            return self.sid_obj.polarity_scores(sentence)

        except Exception as error:
            logger.exception("Error occurred in nlp method sentiment_scores: %s", error)
            return None

    @staticmethod
    def calculate_thematic_coherence(keywords: list) -> Optional[float]:
        """Return the mean WordNet path similarity over all keyword pairs.

        Each keyword is represented by its first WordNet synset. Keywords without a synset
        are skipped, as are pairs whose similarity is None. Returns 0.0 when no pair can be
        scored, or None (and logs) on any error.
        """
        try:
            if len(keywords) < 2:
                return 0.0

            # Look each keyword up once rather than once per pair.
            first_synsets = []
            for keyword in keywords:
                synsets = wordnet.synsets(keyword)
                first_synsets.append(synsets[0] if synsets else None)

            similarities = []
            for index, synset1 in enumerate(first_synsets):
                if synset1 is None:
                    continue
                for synset2 in first_synsets[index + 1:]:
                    if synset2 is None:
                        continue
                    similarity = synset1.path_similarity(synset2)
                    if similarity:
                        similarities.append(similarity)

            return sum(similarities) / len(similarities) if similarities else 0.0

        except Exception as error:
            logger.exception("Error occurred in nlp method calculate_thematic_coherence: %s", error)
            return None

    @staticmethod
    def lexical_diversity(sentence: str) -> Optional[float]:
        """Return unique / total whitespace tokens (case- and punctuation-sensitive).

        Returns 0.0 for an empty sentence, or None (and logs) on any error.
        """
        try:
            words = sentence.split()
            return len(set(words)) / len(words) if words else 0.0

        except Exception as error:
            logger.exception("Error occurred in nlp method lexical_diversity: %s", error)
            return None

    @staticmethod
    def sentence_correlation(sentence1: str, sentence2: str) -> Optional[float]:
        """Sentence correlation: cosine similarity of the two sentences' word frequencies.

        Words are tokenized with NLTK, lower-cased, and English stop words are dropped.
        Returns 0.0 if either sentence is empty or has no remaining tokens, or None (and
        logs) on any error.
        """
        try:
            if sentence1 == "" or sentence2 == "":
                return 0.0

            stop_words = NLP.__english_stop_words()
            tokens1 = [
                token for token in map(str.lower, word_tokenize(sentence1))
                if token not in stop_words
            ]
            tokens2 = [
                token for token in map(str.lower, word_tokenize(sentence2))
                if token not in stop_words
            ]

            fdist1 = nltk.FreqDist(tokens1)
            fdist2 = nltk.FreqDist(tokens2)

            # Words absent from one sentence contribute 0 to the dot product.
            dot_product = sum(count * fdist2[token] for token, count in fdist1.items())
            magnitude1 = sum(count ** 2 for count in fdist1.values()) ** 0.5
            magnitude2 = sum(count ** 2 for count in fdist2.values()) ** 0.5

            if magnitude1 == 0 or magnitude2 == 0:
                return 0.0
            return dot_product / (magnitude1 * magnitude2)

        except Exception as error:
            logger.exception("Error occurred in nlp method sentence_correlation: %s", error)
            return None

    @staticmethod
    def keywords_correlation(keywords1: List[str], keywords2: List[str]) -> Optional[float]:
        """Keyword correlation: cosine similarity between two keyword lists.

        Each list becomes a keyword-count vector, so repeated keywords weigh more. Returns 0.0
        when either list is empty, or None (and logs) on any error.
        """
        try:
            if not keywords1 or not keywords2:
                return 0.0

            counts1 = Counter(keywords1)
            counts2 = Counter(keywords2)

            dot_product = sum(count * counts2[keyword] for keyword, count in counts1.items())
            magnitude1 = sum(count ** 2 for count in counts1.values())
            magnitude2 = sum(count ** 2 for count in counts2.values())

            return dot_product / ((magnitude1 * magnitude2) ** 0.5)

        except Exception as error:
            logger.exception("Error occurred in nlp method keywords_correlation: %s", error)
            return None

    def get_vertical_syntactic_dictionaries_mentions(self, sentence: str) -> Dict[str, bool]:
        """Flag, per feature bucket, whether any of its words occurs in ``sentence``.

        Keys are ``"<category>_mentioned"``, with the category lower-cased and spaces replaced
        by underscores. Matching is exact and case-sensitive per token. A token also matches
        after stripping its surrounding punctuation. Returns ``{}`` (and logs) on any error,
        including failure to load the buckets.
        """
        try:
            tokens = self.__token_variants(sentence)

            return {
                f"{category.lower().replace(' ', '_')}_mentioned": not tokens.isdisjoint(keywords)
                for category, keywords in self.buckets.items()
            }

        except Exception as error:
            logger.exception(
                "Error occurred in nlp method get_vertical_syntactic_dictionaries_mentions: %s",
                error,
            )
            return {}

    def is_brand_mentioned(self, sentence: str) -> Optional[bool]:
        """Return True if any token of ``sentence`` (raw or punctuation-stripped) is a brand.

        NOTE(review): matching is per whitespace token, so multi-word brand names never match.
        Returns None (and logs) on any error, including failure to load the brands.
        """
        try:
            return not self.__token_variants(sentence).isdisjoint(self.brands)

        except Exception as error:
            logger.exception("Error occurred in nlp method is_brand_mentioned: %s", error)
            return None

    @staticmethod
    def has_numbers(sentence: str) -> Optional[bool]:
        """Return True if ``sentence`` contains a number.

        Any ASCII digit counts, so "15", "3.5", "$20" and "5%" are detected. Otherwise a word
        counts when, lower-cased and stripped of surrounding punctuation, it is in
        NUMBERS_AS_WORDS. Returns None (and logs) on any error.
        """
        try:
            if any(char in string.digits for char in sentence):
                return True

            words = {word.strip(string.punctuation) for word in sentence.lower().split()}
            return any(number in words for number in NUMBERS_AS_WORDS)

        except Exception as error:
            logger.exception("Error occurred in nlp method has_numbers: %s", error)
            return None

    @staticmethod
    def percentage_mentioned(sentence: str) -> Optional[bool]:
        """Return True if ``sentence`` contains a "%" sign; None (and logs) on error."""
        try:
            return "%" in sentence

        except Exception as error:
            logger.exception("Error occurred in nlp method percentage_mentioned: %s", error)
            return None

    @staticmethod
    def price_mentioned(sentence: str) -> Optional[bool]:
        """Return True if ``sentence`` contains $, €, £ or ¥; None (and logs) on error."""
        try:
            currency_symbols = ["$", "€", "£", "¥"]
            return any(symbol in sentence for symbol in currency_symbols)

        except Exception as error:
            logger.exception("Error occurred in nlp method price_mentioned: %s", error)
            return None

    @staticmethod
    def starts_with_fanboys_conjunction(sentence: str) -> Optional[bool]:
        """Return True if the first word of ``sentence`` is a FANBOYS conjunction.

        The conjunctions are for, and, nor, but, or, yet, so. The whole first word is compared
        case-insensitively, ignoring its surrounding punctuation. Words that merely begin with
        a conjunction ("Sometimes", "Order", "Butter") therefore do not count. Returns None
        (and logs) on any error.
        """
        try:
            fanboys_conjunctions = {"for", "and", "nor", "but", "or", "yet", "so"}
            words = sentence.lower().split(maxsplit=1)
            if not words:
                return False
            return words[0].strip(string.punctuation) in fanboys_conjunctions

        except Exception as error:
            logger.exception("Error occurred in nlp method starts_with_fanboys_conjunction: %s", error)
            return None

    def is_sentence_cut(self, start_time: float, end_time: float) -> bool:
        """Return True if the text in the window does not look like one complete sentence.

        A sentence counts as cut when any of these holds:
        - it contains an internal sentence boundary (".", "!" or "?" followed by whitespace);
        - it does not end with punctuation;
        - it starts with a FANBOYS conjunction.

        If get_sentence fails (it logs and returns None), the window is reported as cut.
        Completeness cannot be shown in that case.
        """
        result = self.get_sentence(start_time, end_time)
        if result is None:
            return True

        sentence, ends_with_punctuation = result
        # Drop the terminal punctuation so only boundaries *inside* the sentence are searched.
        body = sentence.rstrip(".!?")

        return bool(
            self._SENTENCE_BREAK_RE.search(body)
            or not ends_with_punctuation
            or self.starts_with_fanboys_conjunction(sentence)
        )
