{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sentiment Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import re\n",
    "from functools import cache\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: load the dataset. Later cells read a `text` column (strings passed to\n",
    "# `extract_words`) and a `sentiment` column (the label whose distribution is plotted).\n",
    "data: pd.DataFrame = None  # TODO: load dataset\n",
    "# TODO: load stopwords. They are matched against the lowercased tokens produced\n",
    "# by `extract_words`, so entries should be lowercase.\n",
    "stopwords: set[str] = None  # TODO: load stopwords"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explore the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the class distribution. Bars are ordered by label value rather than by\n",
    "# frequency (the `value_counts()` default) so the hard-coded tick labels stay\n",
    "# aligned with their bars regardless of which class is more common.\n",
    "# NOTE(review): assumes sentiment values sort negative-first (e.g. 0/1) - TODO confirm.\n",
    "sentiment_counts = data[\"sentiment\"].value_counts().sort_index()\n",
    "assert len(sentiment_counts) == 2, \"expected exactly two sentiment classes\"\n",
    "\n",
    "_, ax = plt.subplots(figsize=(6, 4))\n",
    "sentiment_counts.plot(kind=\"bar\", ax=ax)\n",
    "ax.set_xticklabels([\"Negative\", \"Positive\"], rotation=0)\n",
    "ax.set_title(\"Sentiment distribution\")\n",
    "ax.set_xlabel(\"Sentiment\")\n",
    "ax.set_ylabel(\"Count\")\n",
    "ax.grid(False)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "WORD_PATTERN = re.compile(r\"\\b\\S+\\b\")\n",
    "\n",
    "\n",
    "# NOTE(review): functools.cache is unbounded and keeps every distinct text alive;\n",
    "# it only saves work when the same text is tokenized more than once.\n",
    "@cache\n",
    "def _extract_words_cached(text: str) -> tuple[str, ...]:\n",
    "    \"\"\"Tokenize each distinct text once; a tuple keeps the cached value immutable.\"\"\"\n",
    "    return tuple(WORD_PATTERN.findall(text.lower()))\n",
    "\n",
    "\n",
    "def extract_words(text: str) -> list[str]:\n",
    "    \"\"\"Split `text` into lowercase tokens.\n",
    "\n",
    "    A token is a maximal run of non-whitespace characters, trimmed so that it\n",
    "    starts and ends on a word boundary: leading/trailing punctuation is dropped\n",
    "    ('hello,' -> 'hello') while inner punctuation is kept (\"don't\" stays whole).\n",
    "\n",
    "    Results are memoized per distinct text, but each call returns a fresh list,\n",
    "    so callers may mutate it without corrupting later results.\n",
    "    \"\"\"\n",
    "    return list(_extract_words_cached(text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize every text; `explode` gives one row per word occurrence (texts with\n",
    "# no words become NaN, which `value_counts` drops by default).\n",
    "words = data[\"text\"].apply(extract_words).explode()\n",
    "\n",
    "# Tally occurrences, most frequent first, into a `word` / `count` frame.\n",
    "word_counts = words.value_counts().rename_axis(\"word\").reset_index(name=\"count\")\n",
    "word_counts.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the most common words, with and without stopwords.\n",
    "TOP_N_WORDS = 10\n",
    "\n",
    "\n",
    "def plot_top_words(\n",
    "    counts: pd.DataFrame, ax: plt.Axes, title: str, n: int = TOP_N_WORDS\n",
    ") -> plt.Axes:\n",
    "    \"\"\"Draw a horizontal bar chart of the first `n` rows of a word/count frame.\n",
    "\n",
    "    `counts` needs `word` and `count` columns and should already be sorted\n",
    "    most-frequent first (as `word_counts` is), so `head(n)` picks the top words.\n",
    "    Returns `ax` for further customization.\n",
    "    \"\"\"\n",
    "    sns.barplot(data=counts.head(n), x=\"count\", y=\"word\", ax=ax)\n",
    "    ax.set_title(title)\n",
    "    ax.grid(False)\n",
    "    ax.tick_params(axis=\"x\", rotation=45)\n",
    "    return ax\n",
    "\n",
    "\n",
    "_, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))\n",
    "\n",
    "plot_top_words(word_counts, ax1, \"Most common words\")\n",
    "\n",
    "content_word_counts = word_counts[~word_counts[\"word\"].isin(stopwords)]\n",
    "plot_top_words(content_word_counts, ax2, \"Most common words (excluding stopwords)\")\n",
    "ax2.set_ylabel(\"\")  # the left panel already carries the y-axis label\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Find best classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Find best hyperparameters"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}