{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "collapsed_sections": [ "JCeTELmaSD9R", "61vkNVxlzzyW" ], "machine_shape": "hm", "gpuType": "A100" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "ac5a633935c04013a0a850fe39975116": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_61e818edc44c44e8a3d8303eee8ff91c", "IPY_MODEL_672ea5d57f04454a8b96ab13232d5237", "IPY_MODEL_49dbac60cb814000a5aa59985466fd81" ], "layout": "IPY_MODEL_949c68abb0784a5b911a5e03d1f429ad" } }, "61e818edc44c44e8a3d8303eee8ff91c": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_e4ffaccaa9184c83837e36be0beb30f5", "placeholder": "​", "style": "IPY_MODEL_f55ed7599ef546e2a02a2d83beae41d7", "value": "Map: 100%" } }, "672ea5d57f04454a8b96ab13232d5237": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_65be8f1e2cc042a98003cb68a99e46d0", "max": 12528, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_946454de1bf543f8911aea42cd0d64dc", "value": 12528 } }, "49dbac60cb814000a5aa59985466fd81": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_7435c281bf8f4480b3542eb34d8826c9", "placeholder": "​", "style": "IPY_MODEL_b7189a93516242edb18e739331bcb1d7", "value": " 12528/12528 [00:04<00:00, 2504.94 examples/s]" } }, "949c68abb0784a5b911a5e03d1f429ad": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e4ffaccaa9184c83837e36be0beb30f5": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f55ed7599ef546e2a02a2d83beae41d7": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "65be8f1e2cc042a98003cb68a99e46d0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "946454de1bf543f8911aea42cd0d64dc": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "7435c281bf8f4480b3542eb34d8826c9": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b7189a93516242edb18e739331bcb1d7": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "8adfee87f17445159c1fdced536dc85c": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_071989bc2f3b4ea8a0ca7ea871844aca", "IPY_MODEL_caa8dc7da70c4a91940f72647c0855ed", "IPY_MODEL_389215a286de468eb666908e98ce8c5d" ], "layout": "IPY_MODEL_7bb192ab8d8d4a4d87958e35c78e4fd8" } }, "071989bc2f3b4ea8a0ca7ea871844aca": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c6b087846a6a4b8abe1bc04593e2edc9", "placeholder": "​", "style": "IPY_MODEL_371a00a75b9745c1bc50d412731a5f8a", "value": "Map: 100%" } }, "caa8dc7da70c4a91940f72647c0855ed": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_de656fc10762480aa127286dc10d0c95", "max": 3132, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_66e71a906bfc4f679ef2e7c3361c7cac", "value": 3132 } }, "389215a286de468eb666908e98ce8c5d": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_01c50a06ef2f45f7b4ceabb7e7583416", "placeholder": "​", "style": "IPY_MODEL_328167df4a454a7fa65f12d91de86b6b", "value": " 3132/3132 [00:01<00:00, 2501.72 examples/s]" } }, "7bb192ab8d8d4a4d87958e35c78e4fd8": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c6b087846a6a4b8abe1bc04593e2edc9": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "371a00a75b9745c1bc50d412731a5f8a": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "de656fc10762480aa127286dc10d0c95": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "66e71a906bfc4f679ef2e7c3361c7cac": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "01c50a06ef2f45f7b4ceabb7e7583416": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "328167df4a454a7fa65f12d91de86b6b": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "source": [ "# Environment Setup" ], "metadata": { "id": "gPNw1aaGyjAl" } }, { "cell_type": "markdown", "source": [ "Install requirements" ], "metadata": { "id": "AHiYkD3MWL6H" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OABT3lobb-x4" }, "outputs": [], "source": [ "!pip install datasets\n", "!pip install transformers datasets torch\n", "!pip install evaluate accelerate transformers[torch] -U\n", "!pip install accelerate -U\n", "!pip install kaggle" ] }, { "cell_type": "markdown", "source": [ "Import essential libraries" ], "metadata": { "id": "rcxjr7gkWNq9" } }, { "cell_type": "code", "source": [ "import pandas as pd\n", "import ast\n", "import matplotlib.pyplot as plt\n", "import numpy as np" ], "metadata": { "id": "G1jOqd8AypnY" }, "execution_count": 2, "outputs": [] }, { "cell_type": "markdown", "source": [ "Important variables for reproducibility" ], "metadata": { "id": "mGa4k5BOrLZl" } }, { "cell_type": "code", "source": [ "# We will use this seed for all random functions to ensure that they are deterministic and reproducable\n", "RANDOM_SEED = 3109" ], "metadata": { "id": "PznmBW7WrLDK" }, "execution_count": 3, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Data Setup" ], "metadata": { "id": "JCeTELmaSD9R" } }, { "cell_type": "markdown", "source": [ "Import Dataset from Kaggle" ], "metadata": { "id": "MuHIubz0ZOhT" } }, { "cell_type": "code", "source": [ "from google.colab import files\n", "!pip install -q kaggle\n", "files.upload()\n", "!mkdir -p ~/.kaggle\n", "!cp kaggle.json ~/.kaggle/\n", "!chmod 600 /root/.kaggle/kaggle.json\n", "!kaggle datasets download -d rounakbanik/the-movies-dataset\n", "!ls\n", "!mkdir data\n", "!unzip the-movies-dataset.zip -d data" ], "metadata": { "id": "xxhY4MyKSF06", "colab": { "base_uri": "https://localhost:8080/", "height": 281 }, "outputId": "ee16925e-ccec-479b-90c5-19ddd22f3c30" }, "execution_count": 4, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " Upload widget is only available when the cell has been executed in the\n", " current browser session. Please rerun this cell to enable.\n", " \n", " " ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "Saving kaggle.json to kaggle.json\n", "Downloading the-movies-dataset.zip to /content\n", "100% 228M/228M [00:01<00:00, 210MB/s]\n", "100% 228M/228M [00:01<00:00, 189MB/s]\n", "kaggle.json sample_data the-movies-dataset.zip\n", "Archive: the-movies-dataset.zip\n", " inflating: data/credits.csv \n", " inflating: data/keywords.csv \n", " inflating: data/links.csv \n", " inflating: data/links_small.csv \n", " inflating: data/movies_metadata.csv \n", " inflating: data/ratings.csv \n", " inflating: data/ratings_small.csv \n" ] } ] }, { "cell_type": "markdown", "source": [ "# Data Integration" ], "metadata": { "id": "a_Rubw0-VO0T" } }, { "cell_type": "markdown", "source": [ "Create a new CSV with 'genres' from movies_metadata.csv and 'keywords' from keywords.csv, joined together on 'id'" ], "metadata": { "id": "hkzC-cvvV30G" } }, { "cell_type": "code", "source": [ "# Read in 'movies_metadata' and 'keywords'\n", "df1 = pd.read_csv('data/movies_metadata.csv' , low_memory=False)\n", "df2 = pd.read_csv('data/keywords.csv' , low_memory=False)\n", "\n", "# Convert the 'id' columns to strings\n", "df1['id'] = df1['id'].astype(str)\n", "df2['id'] = df2['id'].astype(str)\n", "\n", "# Merge the csv files on 'id' matching\n", "merged_df = pd.merge(df1, df2, on='id')\n", "# Select only the 'genres' and 'keywords' columns of the newly merged csv\n", "merged_df = merged_df.loc[:,['genres', 'keywords']]" ], "metadata": { "id": "vX8yN5TCVYgq" }, "execution_count": 52, "outputs": [] }, { "cell_type": "markdown", "source": [ "Convert the dictionary contents of both 'genres' and 'keywords' columns into tuples" ], "metadata": { "id": "xXRnYB0SZY3W" } }, { "cell_type": "code", "source": [ "evaluated_df = merged_df\n", "\n", "# Apply 'literal_eval' from ast library to each entry (converting string representation of dictionaries into python dictionaries)\n", "evaluated_df['genres'] = evaluated_df['genres'].apply(ast.literal_eval)\n", "evaluated_df['keywords'] = evaluated_df['keywords'].apply(ast.literal_eval)\n", "\n", "# We are only interested in the 'names' of keywords and genres\n", "def extract_names(dict_list):\n", " return [d['name'] for d in dict_list if 'name' in d]\n", "\n", "# Extract the 'name' values of the evaluated dictionaries\n", "evaluated_df['genres'] = evaluated_df['genres'].apply(extract_names)\n", "evaluated_df['keywords'] = evaluated_df['keywords'].apply(extract_names)\n", "\n", "# Convert the extracted 'name' lists into tuples (tuples work much better with Pandas dataframe functions)\n", "evaluated_df['genres'] = evaluated_df['genres'].apply(tuple)\n", "evaluated_df['keywords'] = evaluated_df['keywords'].apply(tuple)\n" ], "metadata": { "id": "sFDpPNtFVrsC" }, "execution_count": 53, "outputs": [] }, { "cell_type": "code", "source": [ "# After data selection, the final step should be the 'df' (what the next sections will work with)\n", "df = evaluated_df" ], "metadata": { "id": "3Q-YaxyDXgsU" }, "execution_count": 54, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Data Cleaning" ], "metadata": { "id": "wByCZNDfb5wm" } }, { "cell_type": "markdown", "source": [ "Filter out rows with missing values" ], "metadata": { "id": "kDGz7lq0Utqp" } }, { "cell_type": "code", "source": [ "no_empty_df = df.copy()" ], "metadata": { "id": "LVDpjecYnaMI" }, "execution_count": 55, "outputs": [] }, { "cell_type": "code", "source": [ "# Set 'no_empty_df' to be 'no_empty_df' where the 'genres' in each row has more than 0 list items\n", "no_empty_df = no_empty_df[no_empty_df['genres'].apply(lambda x: len(x) > 0)]\n", "\n", "# Set 'no_empty_df' to be 'no_empty_df' where the 'keywords' in each row has more than 0 list items\n", "no_empty_df = no_empty_df[no_empty_df['keywords'].apply(lambda x: len(x) > 0)]" ], "metadata": { "id": "PVf7m1aZoBBg" }, "execution_count": 56, "outputs": [] }, { "cell_type": "code", "source": [ "# Calculate how many rows were \"lost\"\n", "before = len(df)\n", "after = len(no_empty_df)\n", "\n", "print((before - after), 'rows with missing values removed')" ], "metadata": { "id": "lbo2wtXdpV-D", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "0fc6e67c-615d-4253-f7e4-9926d0ad07be" }, "execution_count": 57, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "15199 rows with missing values removed\n" ] } ] }, { "cell_type": "code", "source": [ "display(no_empty_df)" ], "metadata": { "id": "PaeSob9oL0mE", "colab": { "base_uri": "https://localhost:8080/", "height": 423 }, "outputId": "2b524bb4-db27-4a25-ddca-282e678bafeb" }, "execution_count": 58, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ " genres \\\n", "0 (Animation, Comedy, Family) \n", "1 (Adventure, Fantasy, Family) \n", "2 (Romance, Comedy) \n", "3 (Comedy, Drama, Romance) \n", "4 (Comedy,) \n", "... ... \n", "46472 (Horror, Mystery, Thriller) \n", "46473 (Mystery, Horror) \n", "46474 (Horror,) \n", "46477 (Drama, Family) \n", "46478 (Drama,) \n", "\n", " keywords \n", "0 (jealousy, toy, boy, friendship, friends, riva... \n", "1 (board game, disappearance, based on children'... \n", "2 (fishing, best friend, duringcreditsstinger, o... \n", "3 (based on novel, interracial relationship, sin... \n", "4 (baby, midlife crisis, confidence, aging, daug... \n", "... ... \n", "46472 (revenge, murder, serial killer, new york city... \n", "46473 (blair witch,) \n", "46474 (witch, mythology, legend, serial killer, mock... \n", "46477 (tragic love,) \n", "46478 (artist, play, pinoy) \n", "\n", "[31283 rows x 2 columns]" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
genreskeywords
0(Animation, Comedy, Family)(jealousy, toy, boy, friendship, friends, riva...
1(Adventure, Fantasy, Family)(board game, disappearance, based on children'...
2(Romance, Comedy)(fishing, best friend, duringcreditsstinger, o...
3(Comedy, Drama, Romance)(based on novel, interracial relationship, sin...
4(Comedy,)(baby, midlife crisis, confidence, aging, daug...
.........
46472(Horror, Mystery, Thriller)(revenge, murder, serial killer, new york city...
46473(Mystery, Horror)(blair witch,)
46474(Horror,)(witch, mythology, legend, serial killer, mock...
46477(Drama, Family)(tragic love,)
46478(Drama,)(artist, play, pinoy)
\n", "

31283 rows × 2 columns

\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "
\n", "
\n" ] }, "metadata": {} } ] }, { "cell_type": "markdown", "source": [ "Remove duplicate entries from the DataFrame" ], "metadata": { "id": "Q6nTDmakV2Px" } }, { "cell_type": "code", "source": [ "no_dupe_df = no_empty_df.drop_duplicates(keep='first')" ], "metadata": { "id": "ATgt4KQ0qND7" }, "execution_count": 59, "outputs": [] }, { "cell_type": "code", "source": [ "print(no_dupe_df)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aFbj8hykN0qi", "outputId": "3f67161d-3a4a-4562-ad31-5d835f000942" }, "execution_count": 60, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " genres \\\n", "0 (Animation, Comedy, Family) \n", "1 (Adventure, Fantasy, Family) \n", "2 (Romance, Comedy) \n", "3 (Comedy, Drama, Romance) \n", "4 (Comedy,) \n", "... ... \n", "46472 (Horror, Mystery, Thriller) \n", "46473 (Mystery, Horror) \n", "46474 (Horror,) \n", "46477 (Drama, Family) \n", "46478 (Drama,) \n", "\n", " keywords \n", "0 (jealousy, toy, boy, friendship, friends, riva... \n", "1 (board game, disappearance, based on children'... \n", "2 (fishing, best friend, duringcreditsstinger, o... \n", "3 (based on novel, interracial relationship, sin... \n", "4 (baby, midlife crisis, confidence, aging, daug... \n", "... ... \n", "46472 (revenge, murder, serial killer, new york city... \n", "46473 (blair witch,) \n", "46474 (witch, mythology, legend, serial killer, mock... \n", "46477 (tragic love,) \n", "46478 (artist, play, pinoy) \n", "\n", "[28251 rows x 2 columns]\n" ] } ] }, { "cell_type": "code", "source": [ "df = no_dupe_df" ], "metadata": { "id": "eC10c-4cYGQg" }, "execution_count": 61, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Data Balancing" ], "metadata": { "id": "9mDwq5AWcEcl" } }, { "cell_type": "code", "source": [ "category_analysis = df.copy()\n", "category_analysis['genre_lengths'] = category_analysis['genres'].apply(len)\n", "\n", "display(category_analysis)" ], "metadata": { "id": "JzB093uRcElc", "colab": { "base_uri": "https://localhost:8080/", "height": 423 }, "outputId": "a841737d-0fbb-4d15-a99c-02a9a9fcabb7" }, "execution_count": 62, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ " genres \\\n", "0 (Animation, Comedy, Family) \n", "1 (Adventure, Fantasy, Family) \n", "2 (Romance, Comedy) \n", "3 (Comedy, Drama, Romance) \n", "4 (Comedy,) \n", "... ... \n", "46472 (Horror, Mystery, Thriller) \n", "46473 (Mystery, Horror) \n", "46474 (Horror,) \n", "46477 (Drama, Family) \n", "46478 (Drama,) \n", "\n", " keywords genre_lengths \n", "0 (jealousy, toy, boy, friendship, friends, riva... 3 \n", "1 (board game, disappearance, based on children'... 3 \n", "2 (fishing, best friend, duringcreditsstinger, o... 2 \n", "3 (based on novel, interracial relationship, sin... 3 \n", "4 (baby, midlife crisis, confidence, aging, daug... 1 \n", "... ... ... \n", "46472 (revenge, murder, serial killer, new york city... 3 \n", "46473 (blair witch,) 2 \n", "46474 (witch, mythology, legend, serial killer, mock... 1 \n", "46477 (tragic love,) 2 \n", "46478 (artist, play, pinoy) 1 \n", "\n", "[28251 rows x 3 columns]" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
genreskeywordsgenre_lengths
0(Animation, Comedy, Family)(jealousy, toy, boy, friendship, friends, riva...3
1(Adventure, Fantasy, Family)(board game, disappearance, based on children'...3
2(Romance, Comedy)(fishing, best friend, duringcreditsstinger, o...2
3(Comedy, Drama, Romance)(based on novel, interracial relationship, sin...3
4(Comedy,)(baby, midlife crisis, confidence, aging, daug...1
............
46472(Horror, Mystery, Thriller)(revenge, murder, serial killer, new york city...3
46473(Mystery, Horror)(blair witch,)2
46474(Horror,)(witch, mythology, legend, serial killer, mock...1
46477(Drama, Family)(tragic love,)2
46478(Drama,)(artist, play, pinoy)1
\n", "

28251 rows × 3 columns

\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "
\n", "
\n" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "plt.ylabel('No. of Movies')\n", "plt.xlabel('No. of Genres')\n", "\n", "plt.title(\"Multi-Genre Distribution\")\n", "plt.hist(category_analysis['genre_lengths'], bins=10)\n", "plt.show()" ], "metadata": { "id": "zBYqWZPyeEPY", "colab": { "base_uri": "https://localhost:8080/", "height": 472 }, "outputId": "9827a695-92af-4f72-fed9-7c1c80dc24be" }, "execution_count": 63, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "combination_counts = category_analysis.groupby('genres').size()\n", "\n", "for genres_tuple, count in combination_counts.items():\n", " print(f\"Genres Tuple: {genres_tuple}, Count: {count}\")" ], "metadata": { "id": "726LDHbfd6k8" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "single_category = category_analysis[category_analysis['genre_lengths'] == 1]\n", "double_category = category_analysis[category_analysis['genre_lengths'] == 2]\n", "triple_category = category_analysis[category_analysis['genre_lengths'] == 3]\n", "\n", "#20 genres\n", "single_category_stats = single_category.groupby('genres').size()\n", "display(single_category_stats.mean())\n", "\n", "#20 nCr 2 = 190 / actual = 160\n", "double_category_stats = double_category.groupby('genres').size()\n", "display(double_category_stats.describe())\n", "\n", "#20 nCr 3 = 1140 / actual 468\n", "triple_category_stats = triple_category.groupby('genres').size()\n", "display(triple_category_stats.describe())\n", "\n", "category_analysis_stats = category_analysis.groupby('genres').size()" ], "metadata": { "id": "-mMEtUYChw67" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from collections import Counter\n", "\n", "double_combinations = set()\n", "unique_genres = set()\n", "\n", "for _, row in double_category.iterrows():\n", " double_combinations.add(tuple(sorted(row['genres'])))\n", " unique_genres.update(row['genres'])\n", "\n", "all_genres = []\n", "for genre_tuple in double_combinations:\n", " all_genres.extend(genre_tuple)\n", "\n", "genre_counts = Counter(all_genres)\n", "\n", "genres = list(genre_counts.keys())\n", "counts = list(genre_counts.values())\n", "\n", "plt.figure(figsize=(12, 6))\n", "plt.bar(genres, counts)\n", "\n", "plt.title('2 Genre Combinations')\n", "plt.xlabel('Genre Name')\n", "plt.ylabel('No. of Combinations')\n", "\n", "plt.xticks(rotation=90)\n", "\n", "plt.show()" ], "metadata": { "id": "gBO4cqH_f-fk", "colab": { "base_uri": "https://localhost:8080/", "height": 608 }, "outputId": "c2e60772-a38d-41c2-cfa9-28a09d4b6163" }, "execution_count": 66, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "from collections import Counter\n", "\n", "triple_combinations = set()\n", "unique_genres = set()\n", "\n", "for _, row in triple_category.iterrows():\n", " triple_combinations.add(tuple(sorted(row['genres'])))\n", " unique_genres.update(row['genres'])\n", "\n", "all_genres = []\n", "for genre_tuple in triple_combinations:\n", " all_genres.extend(genre_tuple)\n", "\n", "genre_counts = Counter(all_genres)\n", "\n", "genres = list(genre_counts.keys())\n", "counts = list(genre_counts.values())\n", "\n", "plt.figure(figsize=(12, 6))\n", "plt.bar(genres, counts)\n", "\n", "plt.title('3 Genre Combinations')\n", "plt.xlabel('Genre Name')\n", "plt.ylabel('No. of Combinations')\n", "\n", "plt.xticks(rotation=90)\n", "\n", "plt.show()" ], "metadata": { "id": "am3AlYNA7-cv", "colab": { "base_uri": "https://localhost:8080/", "height": 611 }, "outputId": "e1309d35-e520-49ea-b90b-cf68da125542" }, "execution_count": 67, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "##balancing single genres\n", "\n", "general_mean = single_category.groupby('genres').size()\n", "print(general_mean)\n", "\n", "sc_mean = int(single_category_stats.mean())\n", "\n", "balanced_df= pd.DataFrame()\n", "\n", "for genre in unique_genres:\n", " single_genre_df = single_category[single_category['genres'].apply(lambda x: genre in x)]\n", "\n", " count1 = len(single_genre_df)\n", "\n", " if count1 > sc_mean:\n", " # Undersample\n", " sampled_df = single_genre_df.sample(n=sc_mean, random_state=RANDOM_SEED)\n", " elif count1 < sc_mean:\n", " # Oversample\n", " sampled_df = single_genre_df.sample(n=sc_mean, replace=True, random_state=RANDOM_SEED)\n", " else:\n", " sampled_df = single_genre_df\n", "\n", " balanced_df = pd.concat([balanced_df, sampled_df], ignore_index=True)" ], "metadata": { "id": "6c4HRBY_TfzC", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "dfd9bc41-11b0-4e30-8cdb-c16ae9cce60f" }, "execution_count": 68, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "genres\n", "(Action,) 155\n", "(Adventure,) 81\n", "(Animation,) 111\n", "(Comedy,) 1702\n", "(Crime,) 101\n", "(Documentary,) 1294\n", "(Drama,) 2845\n", "(Family,) 37\n", "(Fantasy,) 42\n", "(Foreign,) 3\n", "(History,) 8\n", "(Horror,) 675\n", "(Music,) 61\n", "(Mystery,) 46\n", "(Romance,) 54\n", "(Science Fiction,) 139\n", "(TV Movie,) 3\n", "(Thriller,) 258\n", "(War,) 23\n", "(Western,) 252\n", "dtype: int64\n" ] } ] }, { "cell_type": "code", "source": [ "#Balancing Double-combinations\n", "\n", "dc_mean = int(double_category_stats.mean())\n", "\n", "for genre_combinations in double_combinations:\n", "\n", " double_genre_df = double_category[double_category['genres'].apply(lambda x: all(genre in x for genre in genre_combinations))]\n", "\n", " count = len(double_genre_df)\n", "\n", " if count > dc_mean:\n", " # Undersample\n", " sampled_df = double_genre_df.sample(n=dc_mean, random_state=RANDOM_SEED)\n", " elif count < dc_mean:\n", " # Oversample\n", " sampled_df = double_genre_df.sample(n=dc_mean, replace=True, random_state=RANDOM_SEED)\n", " else:\n", " sampled_df = double_genre_df\n", "\n", " balanced_df = pd.concat([balanced_df, sampled_df], ignore_index=True)" ], "metadata": { "id": "OYvSbbOH19NN" }, "execution_count": 69, "outputs": [] }, { "cell_type": "code", "source": [ "tc_mean = int(triple_category_stats.mean())\n", "\n", "for genre_combinations in triple_combinations:\n", "\n", " triple_genre_df = triple_category[triple_category['genres'].apply(lambda x: all(genre in x for genre in genre_combinations))]\n", "\n", " count = len(triple_genre_df)\n", "\n", " if count > tc_mean:\n", " # Undersample\n", " sampled_df = triple_genre_df.sample(n=tc_mean, random_state=RANDOM_SEED)\n", " elif count < tc_mean:\n", " # Oversample\n", " sampled_df = triple_genre_df.sample(n=tc_mean, replace=True, random_state=RANDOM_SEED)\n", " else:\n", " sampled_df = triple_genre_df\n", "\n", " balanced_df = pd.concat([balanced_df, sampled_df], ignore_index=True)" ], "metadata": { "id": "R8obi3NT7DiT" }, "execution_count": 70, "outputs": [] }, { "cell_type": "code", "source": [ "# Exploding the list_column\n", "exploded_df = balanced_df.explode('genres')\n", "\n", "# Group by the exploded column and count occurrences\n", "grouped = exploded_df.groupby('genres').size()\n", "print(grouped.index)\n", "\n", "y_pos = np.arange(len(grouped))\n", "\n", "# Calculate the mean and median (using numpy)\n", "mean_occurrences = np.mean(grouped)\n", "median_occurrences = np.median(grouped)\n", "\n", "# Create the bar chart\n", "plt.bar(y_pos, grouped, align='center', alpha=0.5)\n", "\n", "# Adding the labels and title\n", "plt.xticks(y_pos, grouped.index, rotation='vertical')\n", "plt.ylabel('No. of entries')\n", "plt.title('Movie Genre Distribution')\n", "\n", "# Add the mean and median lines (red and green respectively)\n", "plt.axhline(y=mean_occurrences, color='r', linestyle='-', label=f'Mean: {mean_occurrences:.2f}')\n", "plt.axhline(y=median_occurrences, color='g', linestyle='--', label=f'Median: {median_occurrences:.2f}')\n", "\n", "# Show illustration\n", "plt.legend()\n", "plt.show()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 628 }, "id": "-DTwfZTQ8umI", "outputId": "a05e91f3-ee1c-47a2-c8d6-f95f3369f044" }, "execution_count": 71, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Index(['Action', 'Adventure', 'Animation', 'Comedy', 'Crime', 'Documentary',\n", " 'Drama', 'Family', 'Fantasy', 'Foreign', 'History', 'Horror', 'Music',\n", " 'Mystery', 'Romance', 'Science Fiction', 'TV Movie', 'Thriller', 'War',\n", " 'Western'],\n", " dtype='object', name='genres')\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "df = balanced_df" ], "metadata": { "id": "ABM7lXP2kaU5" }, "execution_count": 72, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Data Normalization" ], "metadata": { "id": "dbdr0WMATGxg" } }, { "cell_type": "code", "source": [ "normalized_df = df.copy()" ], "metadata": { "id": "wLSm3rfOmlOG" }, "execution_count": 73, "outputs": [] }, { "cell_type": "markdown", "source": [ "Drop the 'genre_lengths' column (used for data balancing)" ], "metadata": { "id": "nniSshz5m0nt" } }, { "cell_type": "code", "source": [ "normalized_df.drop('genre_lengths', axis=1, inplace=True)" ], "metadata": { "id": "2cZ6GiCRm5EW" }, "execution_count": 74, "outputs": [] }, { "cell_type": "code", "source": [ "import torch" ], "metadata": { "id": "LL-aJtSloCxF" }, "execution_count": 75, "outputs": [] }, { "cell_type": "code", "source": [ "unique_genres = set(genre for sublist in df['genres'] for genre in sublist)\n", "unique_genres = sorted(list(unique_genres))\n", "\n", "genre_to_index = {genre: i for i, genre in enumerate(unique_genres)}\n", "\n", "import numpy as np\n", "\n", "def encode_genres(genres, genre_to_index):\n", " encoding = np.zeros(len(genre_to_index), dtype=int)\n", " for genre in genres:\n", " encoding[genre_to_index[genre]] = 1\n", " return encoding\n", "\n", "# Apply encoding to each row\n", "normalized_df['genres'] = normalized_df['genres'].apply(lambda x: encode_genres(x, genre_to_index))\n" ], "metadata": { "id": "ScuklEfAn_ad" }, "execution_count": 76, "outputs": [] }, { "cell_type": "markdown", "source": [ "Convert the keyword tuples to text strings" ], "metadata": { "id": "CFRkQBbZrW29" } }, { "cell_type": "code", "source": [ "normalized_df['keywords'] = normalized_df['keywords'].apply(lambda x: ', '.join(map(str, x)))" ], "metadata": { "id": "0okiGBSrrYjb" }, "execution_count": 77, "outputs": [] }, { "cell_type": "code", "source": [ "df = normalized_df" ], "metadata": { "id": "0UFbquTlsZjH" }, "execution_count": 78, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Dataset Splitting" ], "metadata": { "id": "UsvEz0a81eey" } }, { "cell_type": "code", "source": [ "from datasets import DatasetDict, Dataset\n", "from sklearn.model_selection import train_test_split\n", "\n", "df['labels'] = df['genres']\n", "df.drop('genres', axis=1, inplace=True)\n", "dataset = Dataset.from_pandas(df)\n", "# Split the dataset into training and testing sets\n", "split_dataset = dataset.train_test_split(test_size=0.2, seed=RANDOM_SEED)" ], "metadata": { "id": "lAFvmm-H8OYi" }, "execution_count": 79, "outputs": [] }, { "cell_type": "code", "source": [ "from transformers import DistilBertTokenizer\n", "\n", "\n", "# Initialize the tokenizer\n", "tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')\n", "\n", "# Function to tokenize the text data\n", "def tokenize_function(examples):\n", " return tokenizer(examples['keywords'], padding='max_length', truncation=True, max_length=512)\n", "\n", "# Apply the tokenization to the dataset\n", "tokenized_datasets = split_dataset.map(tokenize_function, batched=True)" ], "metadata": { "id": "HDPcn2hc90r-", "colab": { "base_uri": "https://localhost:8080/", "referenced_widgets": [ "ac5a633935c04013a0a850fe39975116", "61e818edc44c44e8a3d8303eee8ff91c", "672ea5d57f04454a8b96ab13232d5237", "49dbac60cb814000a5aa59985466fd81", "949c68abb0784a5b911a5e03d1f429ad", "e4ffaccaa9184c83837e36be0beb30f5", "f55ed7599ef546e2a02a2d83beae41d7", "65be8f1e2cc042a98003cb68a99e46d0", "946454de1bf543f8911aea42cd0d64dc", "7435c281bf8f4480b3542eb34d8826c9", "b7189a93516242edb18e739331bcb1d7", "8adfee87f17445159c1fdced536dc85c", "071989bc2f3b4ea8a0ca7ea871844aca", "caa8dc7da70c4a91940f72647c0855ed", "389215a286de468eb666908e98ce8c5d", "7bb192ab8d8d4a4d87958e35c78e4fd8", "c6b087846a6a4b8abe1bc04593e2edc9", "371a00a75b9745c1bc50d412731a5f8a", "de656fc10762480aa127286dc10d0c95", "66e71a906bfc4f679ef2e7c3361c7cac", "01c50a06ef2f45f7b4ceabb7e7583416", "328167df4a454a7fa65f12d91de86b6b" ], "height": 81 }, "outputId": "3a7259ae-623c-49c8-8ede-c1b45f400aaa" }, "execution_count": 80, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Map: 0%| | 0/12528 [00:00 0.5).astype(int)\n", " # Calculate metrics\n", " precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='samples')\n", " acc = accuracy_score(labels, predictions)\n", " hamming = hamming_loss(labels, predictions)\n", "\n", " return {\n", " 'accuracy': acc,\n", " 'f1': f1,\n", " 'precision': precision,\n", " 'recall': recall,\n", " 'hamming_loss': hamming\n", " }\n" ], "metadata": { "id": "y1etEIiDtfZR" }, "execution_count": 91, "outputs": [] }, { "cell_type": "code", "source": [ "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_dataset,\n", " eval_dataset=test_dataset,\n", " data_collator=custom_collator,\n", " compute_metrics=compute_metrics,\n", ")" ], "metadata": { "id": "pDDtLfuBB0vT" }, "execution_count": 92, "outputs": [] }, { "cell_type": "code", "source": [ "display(train_dataset)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 86 }, "id": "EXohfDa_zYFA", "outputId": "9ef26bc9-d514-40bc-e727-2c66f826ce7f" }, "execution_count": 93, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Dataset({\n", " features: ['keywords', 'labels', 'input_ids', 'attention_mask'],\n", " num_rows: 12528\n", "})" ] }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "trainer.train()" ], "metadata": { "id": "9Zzlf8j4B24L", "colab": { "base_uri": "https://localhost:8080/", "height": 248 }, "outputId": "796be6d8-d0c0-4eb6-b148-d2e8c320f6eb" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ ":8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch['labels'] = torch.tensor(batch['labels'], dtype=torch.float)\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "
\n", " \n", " \n", " [ 299/1176 03:43 < 11:00, 1.33 it/s, Epoch 1.52/6]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyF1PrecisionRecallHamming LossRuntimeSamples Per SecondSteps Per Second
1No log0.2772960.0000000.0000000.0000000.0000000.08183312.513300250.2940003.916000

" ] }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n", ":8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch['labels'] = torch.tensor(batch['labels'], dtype=torch.float)\n" ] } ] }, { "cell_type": "code", "source": [ "results = trainer.evaluate()" ], "metadata": { "id": "Qns4wSQxtrR5", "colab": { "base_uri": "https://localhost:8080/", "height": 126 }, "outputId": "d0ca8abe-5a92-420a-f35f-ea38efaac91c" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ ":8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch['labels'] = torch.tensor(batch['labels'], dtype=torch.float)\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "

\n", " \n", " \n", " [98/98 00:12]\n", "
\n", " " ] }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n" ] } ] }, { "cell_type": "code", "source": [ "print(results)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uU3d40U04cdy", "outputId": "152c3201-ba21-4d96-8674-ff08c29519ac" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{'eval_loss': 0.11773193627595901, 'eval_accuracy': 0.6669859514687101, 'eval_f1': 0.7603630724320378, 'eval_precision': 0.7963761174968071, 'eval_recall': 0.7502128565346956, 'eval_hamming_loss': 0.03204022988505747, 'eval_runtime': 12.7661, 'eval_samples_per_second': 245.338, 'eval_steps_per_second': 7.677, 'epoch': 6.0}\n" ] } ] }, { "cell_type": "code", "source": [ "model.save_pretrained('./saved_model')\n", "tokenizer.save_pretrained('./saved_model')" ], "metadata": { "id": "vNxaJzgMBx4R", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "b2c626aa-78bc-434d-8893-68627527f313" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "('./saved_model/tokenizer_config.json',\n", " './saved_model/special_tokens_map.json',\n", " './saved_model/vocab.txt',\n", " './saved_model/added_tokens.json',\n", " './saved_model/tokenizer.json')" ] }, "metadata": {}, "execution_count": 176 } ] }, { "cell_type": "code", "source": [ "!zip -r /content/model.zip /content/saved_model" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "PoSveDHL4uCH", "outputId": "033de9d3-34c0-44be-95fb-38a09f870366" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " adding: content/saved_model/ (stored 0%)\n", " adding: content/saved_model/config.json (deflated 62%)\n", " adding: content/saved_model/vocab.txt (deflated 53%)\n", " adding: content/saved_model/model.safetensors (deflated 8%)\n", " adding: content/saved_model/tokenizer_config.json (deflated 76%)\n", " adding: content/saved_model/tokenizer.json (deflated 71%)\n", " adding: content/saved_model/special_tokens_map.json (deflated 42%)\n" ] } ] } ] }