import re
import nltk
from collections import Counter
from cleantweet import CleanTweet
from nltk.tokenize import word_tokenize, sent_tokenize
# Fetch the NLTK data packages when the module is imported.
# NOTE(review): 'punkt'/'punkt_tab' presumably back sent_tokenize/word_tokenize;
# 'stopwords' is not referenced by the code visible in this file - confirm it is needed.
for _resource in ('punkt', 'stopwords', 'punkt_tab'):
    nltk.download(_resource, quiet=True)


class InferenceTweet(CleanTweet):
    """
    Class for text summarization and inference tasks
    """
    
    def _fbs(self, amount_of_sentences: int = 5):
        """
        Frequency-Based Summarization (protected method)
        :param amount_of_sentences: number of sentences to extract
        :return: list of top sentences
        """
        sentences = sent_tokenize(self.clean())
        words = word_tokenize(self.clean().lower())
        word_frequency = Counter(words)
        sentence_scores = {sent: sum(word_frequency[word] for word in word_tokenize(sent.lower()))
                           for sent in sentences}
        ranked_sentences = sorted(sentence_scores, key=sentence_scores.get, reverse=True)
        return ranked_sentences[:amount_of_sentences]

    def frequency_based_summarization(self, amount_of_sentences: int = 5):
        """
        Frequency-Based Summarization - extracts most important sentences based on word frequency
        :param amount_of_sentences: number of sentences to extract for summary, default is 5
        :return: a list of summary sentences
        """
        summary_sentences = self._fbs(amount_of_sentences)
        return summary_sentences

    def textrank_summarization(self, amount_of_sentences: int = 5):
        """
        Graph-Based Summarization using TextRank algorithm
        :param amount_of_sentences: number of sentences to extract for summary, default is 5
        :return: a list of summary sentences
        """
        sentences = sent_tokenize(self.clean())
        if len(sentences) < amount_of_sentences:
            return sentences
        
        # Build similarity matrix
        similarity_matrix = self._build_similarity_matrix(sentences)
        
        # Calculate sentence scores using PageRank-like algorithm
        scores = self._calculate_textrank_scores(similarity_matrix)
        
        # Rank sentences by score
        ranked_sentences = sorted(zip(sentences, scores), key=lambda x: x[1], reverse=True)
        
        # Return top sentences in original order
        top_indices = sorted([sentences.index(sent) for sent, _ in ranked_sentences[:amount_of_sentences]])
        return [sentences[i] for i in top_indices]

    def _build_similarity_matrix(self, sentences):
        """
        Build similarity matrix between sentences
        :param sentences: list of sentences
        :return: similarity matrix
        """
        n = len(sentences)
        similarity_matrix = [[0.0] * n for _ in range(n)]
        
        for i in range(n):
            for j in range(n):
                if i != j:
                    similarity_matrix[i][j] = self._sentence_similarity(sentences[i], sentences[j])
        
        return similarity_matrix

    def _sentence_similarity(self, sent1, sent2):
        """
        Word-overlap similarity of two sentences.

        Both sentences are lower-cased and tokenized into sets of distinct
        tokens; the score is the number of shared tokens divided by the number
        of tokens appearing in either sentence.
        :param sent1: first sentence
        :param sent2: second sentence
        :return: similarity score, 0.0 when either sentence has no tokens
        """
        tokens_a = set(word_tokenize(sent1.lower()))
        tokens_b = set(word_tokenize(sent2.lower()))

        if not tokens_a or not tokens_b:
            return 0.0

        shared = tokens_a & tokens_b
        combined = tokens_a | tokens_b

        # Both sets are non-empty here, so the union cannot be empty.
        return len(shared) / len(combined)

    def _calculate_textrank_scores(self, similarity_matrix, damping=0.85, max_iter=100, tol=1e-6):
        """
        Calculate TextRank scores using iterative algorithm
        :param similarity_matrix: similarity matrix between sentences
        :param damping: damping factor (default 0.85)
        :param max_iter: maximum iterations
        :param tol: tolerance for convergence
        :return: list of scores
        """
        n = len(similarity_matrix)
        scores = [1.0] * n
        
        # Normalize similarity matrix
        normalized_matrix = []
        for i in range(n):
            row_sum = sum(similarity_matrix[i])
            if row_sum > 0:
                normalized_matrix.append([sim / row_sum for sim in similarity_matrix[i]])
            else:
                normalized_matrix.append([0.0] * n)
        
        for _ in range(max_iter):
            prev_scores = scores[:]
            for i in range(n):
                score = 1 - damping
                for j in range(n):
                    if normalized_matrix[j][i] > 0:
                        score += damping * normalized_matrix[j][i] * prev_scores[j]
                scores[i] = score
            
            if sum(abs(scores[i] - prev_scores[i]) for i in range(n)) < tol:
                break
        
        return scores

    def summarize(self, method: str = 'frequency', amount_of_sentences: int = 5):
        """
        Main summarization method that supports different algorithms
        :param method: summarization method - 'frequency' or 'textrank', default is 'frequency'
        :param amount_of_sentences: number of sentences in summary, default is 5
        :return: summary as a string
        """
        if method.lower() == 'frequency':
            summary_sentences = self.frequency_based_summarization(amount_of_sentences)
        elif method.lower() == 'textrank':
            summary_sentences = self.textrank_summarization(amount_of_sentences)
        else:
            raise ValueError("Method must be 'frequency' or 'textrank'")
        
        return ' '.join(summary_sentences)

