from __future__ import annotations import bz2 import re import warnings from typing import Literal import pandas as pd from joblib import Memory from nltk.stem import WordNetLemmatizer from sklearn.base import BaseEstimator, TransformerMixin from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split from sklearn.pipeline import Pipeline from app.constants import ( AMAZONREVIEWS_PATH, AMAZONREVIEWS_URL, CACHE_DIR, EMOTICON_MAP, IMDB50K_PATH, IMDB50K_URL, SENTIMENT140_PATH, SENTIMENT140_URL, URL_REGEX, ) __all__ = ["load_data", "create_model", "train_model"] class TextCleaner(BaseEstimator, TransformerMixin): def __init__( self, *, replace_url: bool = True, replace_hashtag: bool = True, replace_emoticon: bool = True, replace_emoji: bool = True, lowercase: bool = True, character_threshold: int = 2, remove_special_characters: bool = True, remove_extra_spaces: bool = True, ): self.replace_url = replace_url self.replace_hashtag = replace_hashtag self.replace_emoticon = replace_emoticon self.replace_emoji = replace_emoji self.lowercase = lowercase self.character_threshold = character_threshold self.remove_special_characters = remove_special_characters self.remove_extra_spaces = remove_extra_spaces def fit(self, _data: list[str], _labels: list[int] | None = None) -> TextCleaner: return self def transform(self, data: list[str], _labels: list[int] | None = None) -> list[str]: # Replace URLs, hashtags, emoticons, and emojis data = [re.sub(URL_REGEX, "URL", text) for text in data] if self.replace_url else data data = [re.sub(r"#\w+", "HASHTAG", text) for text in data] if self.replace_hashtag else data # Replace emoticons if self.replace_emoticon: for word, emoticons in EMOTICON_MAP.items(): for emoticon in emoticons: data = [text.replace(emoticon, f"EMOTE_{word}") for text in data] # Basic text cleaning data = [text.lower() for text in data] if self.lowercase else data # Lowercase threshold_pattern = re.compile(rf"\b\w{{1,{self.character_threshold}}}\b") data = ( [re.sub(threshold_pattern, "", text) for text in data] if self.character_threshold > 0 else data ) # Remove short words data = ( [re.sub(r"[^a-zA-Z0-9\s]", "", text) for text in data] if self.remove_special_characters else data ) # Remove special characters data = [re.sub(r"\s+", " ", text) for text in data] if self.remove_extra_spaces else data # Remove extra spaces # Remove leading and trailing whitespace return [text.strip() for text in data] class TextLemmatizer(BaseEstimator, TransformerMixin): def __init__(self): self.lemmatizer = WordNetLemmatizer() def fit(self, _data: list[str], _labels: list[int] | None = None) -> TextLemmatizer: return self def transform(self, data: list[str], _labels: list[int] | None = None) -> list[str]: return [self.lemmatizer.lemmatize(text) for text in data] def load_sentiment140(include_neutral: bool = False) -> tuple[list[str], list[int]]: """Load the sentiment140 dataset and make it suitable for use. Args: include_neutral: Whether to include neutral sentiment Returns: Text and label data Raises: FileNotFoundError: If the dataset is not found """ # Check if the dataset exists if not SENTIMENT140_PATH.exists(): msg = ( f"Sentiment140 dataset not found at: '{SENTIMENT140_PATH}'\n" "Please download the dataset from:\n" f"{SENTIMENT140_URL}" ) raise FileNotFoundError(msg) # Load the dataset data = pd.read_csv( SENTIMENT140_PATH, encoding="ISO-8859-1", names=[ "target", # 0 = negative, 2 = neutral, 4 = positive "id", # The id of the tweet "date", # The date of the tweet "flag", # The query, NO_QUERY if not present "user", # The user that tweeted "text", # The text of the tweet ], ) # Ignore rows with neutral sentiment if not include_neutral: data = data[data["target"] != 2] # Map sentiment values data["sentiment"] = data["target"].map( { 0: 0, # Negative 4: 1, # Positive 2: 2, # Neutral }, ) # Return as lists return data["text"].tolist(), data["sentiment"].tolist() def load_amazonreviews(merge: bool = True) -> tuple[list[str], list[int]]: """Load the amazonreviews dataset and make it suitable for use. Args: merge: Whether to merge the test and train datasets (otherwise ignore test) Returns: Text and label data Raises: FileNotFoundError: If the dataset is not found """ # Check if the dataset exists test_exists = AMAZONREVIEWS_PATH[0].exists() or not merge train_exists = AMAZONREVIEWS_PATH[1].exists() if not (test_exists and train_exists): msg = ( f"Amazonreviews dataset not found at: '{AMAZONREVIEWS_PATH[0]}' and '{AMAZONREVIEWS_PATH[1]}'\n" "Please download the dataset from:\n" f"{AMAZONREVIEWS_URL}" ) raise FileNotFoundError(msg) # Load the datasets with bz2.BZ2File(AMAZONREVIEWS_PATH[1]) as train_file: train_data = [line.decode("utf-8") for line in train_file] test_data = [] if merge: with bz2.BZ2File(AMAZONREVIEWS_PATH[0]) as test_file: test_data = [line.decode("utf-8") for line in test_file] # Merge the datasets data = train_data + test_data # Split the data into labels and text labels, texts = zip(*(line.split(" ", 1) for line in data)) # Map sentiment values sentiments = [int(label.split("__label__")[1]) - 1 for label in labels] # Return as lists return texts, sentiments def load_imdb50k() -> tuple[list[str], list[int]]: """Load the imdb50k dataset and make it suitable for use. Returns: Text and label data Raises: FileNotFoundError: If the dataset is not found """ # Check if the dataset exists if not IMDB50K_PATH.exists(): msg = ( f"IMDB50K dataset not found at: '{IMDB50K_PATH}'\n" "Please download the dataset from:\n" f"{IMDB50K_URL}" ) # fmt: off raise FileNotFoundError(msg) # Load the dataset data = pd.read_csv(IMDB50K_PATH) # Map sentiment values data["sentiment"] = data["sentiment"].map( { "positive": 1, "negative": 0, }, ) # Return as lists return data["review"].tolist(), data["sentiment"].tolist() def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k"]) -> tuple[list[str], list[int]]: """Load and preprocess the specified dataset. Args: dataset: Dataset to load Returns: Text and label data Raises: ValueError: If the dataset is not recognized """ match dataset: case "sentiment140": return load_sentiment140(include_neutral=False) case "amazonreviews": return load_amazonreviews(merge=True) case "imdb50k": return load_imdb50k() case _: msg = f"Unknown dataset: {dataset}" raise ValueError(msg) def create_model( max_features: int, seed: int | None = None, ) -> Pipeline: """Create a sentiment analysis model. Args: max_features: Maximum number of features seed: Random seed (None for random seed) Returns: Untrained model """ return Pipeline( [ # Text preprocessing ("clean", TextCleaner()), ("lemma", TextLemmatizer()), # Preprocess (NOTE: Can be replaced with TfidfVectorizer, but left for clarity) ("vectorize", CountVectorizer(stop_words="english", ngram_range=(1, 2), max_features=max_features)), ("tfidf", TfidfTransformer()), # Classifier ("clf", LogisticRegression(max_iter=1000, random_state=seed)), ], memory=Memory(CACHE_DIR, verbose=0), ) def train_model( model: Pipeline, text_data: list[str], label_data: list[int], seed: int = 42, ) -> float: """Train the sentiment analysis model. Args: model: Untrained model text_data: Text data label_data: Label data seed: Random seed (None for random seed) Returns: Accuracy score """ text_train, text_test, label_train, label_test = train_test_split( text_data, label_data, test_size=0.2, random_state=seed, ) with warnings.catch_warnings(): warnings.simplefilter("ignore") model.fit(text_train, label_train) return model.score(text_test, label_test)