{"cells":[{"cell_type":"markdown","metadata":{},"source":["Note: This code is run on Kaggle environment."]},{"cell_type":"markdown","metadata":{},"source":["# Importing the required libraries"]},{"cell_type":"code","execution_count":1,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:20:55.138911Z","iopub.status.busy":"2023-12-25T09:20:55.138540Z","iopub.status.idle":"2023-12-25T09:21:13.497198Z","shell.execute_reply":"2023-12-25T09:21:13.496367Z","shell.execute_reply.started":"2023-12-25T09:20:55.138880Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["/opt/conda/lib/python3.10/site-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n"," warnings.warn(_BETA_TRANSFORMS_WARNING)\n","/opt/conda/lib/python3.10/site-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. 
You can silence this warning by calling torchvision.disable_beta_transforms_warning().\n"," warnings.warn(_BETA_TRANSFORMS_WARNING)\n","/opt/conda/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.24.3\n"," warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"]}],"source":["import torch\n","import pandas as pd\n","import numpy as np\n","import os\n","import warnings\n","import matplotlib.pyplot as plt\n","\n","from transformers import AutoTokenizer, AutoModelForSequenceClassification, DistilBertForSequenceClassification, AutoModelForSeq2SeqLM\n","from tqdm import tqdm\n","from torchvision import models\n","from torchvision.transforms import v2\n","from torch.utils.data import Dataset, DataLoader\n","from keras.preprocessing import image\n","from torchmetrics.classification import MultilabelF1Score\n","from sklearn.metrics import average_precision_score, ndcg_score"]},{"cell_type":"markdown","metadata":{},"source":["### Setting up the environment\n","***"]},{"cell_type":"code","execution_count":2,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:21:13.499630Z","iopub.status.busy":"2023-12-25T09:21:13.499089Z","iopub.status.idle":"2023-12-25T09:21:13.503998Z","shell.execute_reply":"2023-12-25T09:21:13.502923Z","shell.execute_reply.started":"2023-12-25T09:21:13.499602Z"},"trusted":true},"outputs":[],"source":["warnings.filterwarnings(\"ignore\")"]},{"cell_type":"markdown","metadata":{},"source":["***"]},{"cell_type":"markdown","metadata":{},"source":["# Data Preprocessing"]},{"cell_type":"markdown","metadata":{},"source":["18 Genres from file 
genres.txt"]},{"cell_type":"code","execution_count":3,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:21:13.505808Z","iopub.status.busy":"2023-12-25T09:21:13.505429Z","iopub.status.idle":"2023-12-25T09:21:13.538838Z","shell.execute_reply":"2023-12-25T09:21:13.537882Z","shell.execute_reply.started":"2023-12-25T09:21:13.505773Z"},"trusted":true},"outputs":[{"data":{"text/plain":["{0: 'Crime',\n"," 1: 'Thriller',\n"," 2: 'Fantasy',\n"," 3: 'Horror',\n"," 4: 'Sci-Fi',\n"," 5: 'Comedy',\n"," 6: 'Documentary',\n"," 7: 'Adventure',\n"," 8: 'Film-Noir',\n"," 9: 'Animation',\n"," 10: 'Romance',\n"," 11: 'Drama',\n"," 12: 'Western',\n"," 13: 'Musical',\n"," 14: 'Action',\n"," 15: 'Mystery',\n"," 16: 'War',\n"," 17: \"Children's\"}"]},"execution_count":3,"metadata":{},"output_type":"execute_result"}],"source":[
"genres = [\"Crime\", \"Thriller\", \"Fantasy\", \"Horror\", \"Sci-Fi\", \"Comedy\", \"Documentary\", \"Adventure\", \"Film-Noir\", \"Animation\", \"Romance\", \"Drama\", \"Western\", \"Musical\", \"Action\", \"Mystery\", \"War\", \"Children's\"]\n",
"# Class index -> genre name (assumed to follow the position order of the 0/1 label vectors).\n",
"mapping = dict(enumerate(genres))\n",
"mapping"
]},{"cell_type":"markdown","metadata":{},"source":["***"]},{"cell_type":"code","execution_count":4,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:21:13.541348Z","iopub.status.busy":"2023-12-25T09:21:13.541044Z","iopub.status.idle":"2023-12-25T09:21:13.754109Z","shell.execute_reply":"2023-12-25T09:21:13.753302Z","shell.execute_reply.started":"2023-12-25T09:21:13.541322Z"},"trusted":true},"outputs":[],"source":[
"import ast\n",
"\n",
"DATA_DIR = '/kaggle/input/ml-dataset-2023s1'\n",
"trainset = pd.read_csv(f'{DATA_DIR}/trainset.csv')\n",
"testset = pd.read_csv(f'{DATA_DIR}/testset.csv')\n",
"\n",
"for split in (trainset, testset):\n",
"    # `label` is stored as a list literal such as \"[0, 1, 0, ...]\"; ast.literal_eval parses it\n",
"    # without executing arbitrary code, unlike eval().\n",
"    split['label'] = split['label'].apply(ast.literal_eval)\n",
"    # Normalise backslash separators to '/' so the image paths resolve on Linux.\n",
"    split['img_path'] = split['img_path'].str.replace('\\\\', '/', regex=False)"
]},{"cell_type":"markdown","metadata":{},"source":[
"This dataset was provided by our lecturer; we uploaded it as a private Kaggle dataset so that this notebook can run on Kaggle.\n",
"The training set has 3106 rows and the test set has 777 rows.\n",
"Generating the movie plots takes over 2 hours per run, so we generated them once and saved the results (the `plot` column) to the two .csv files loaded above."
]},{"cell_type":"code","execution_count":5,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:21:13.755783Z","iopub.status.busy":"2023-12-25T09:21:13.755489Z","iopub.status.idle":"2023-12-25T09:21:13.760684Z","shell.execute_reply":"2023-12-25T09:21:13.759728Z","shell.execute_reply.started":"2023-12-25T09:21:13.755757Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["3106 777\n"]}],"source":["print(len(trainset), len(testset))"]},{"cell_type":"code","execution_count":6,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:21:13.762059Z","iopub.status.busy":"2023-12-25T09:21:13.761750Z","iopub.status.idle":"2023-12-25T09:21:21.423701Z","shell.execute_reply":"2023-12-25T09:21:21.422523Z","shell.execute_reply.started":"2023-12-25T09:21:13.762035Z"},"trusted":true},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"2b19aad246a6453daf826a28d6fc8b66","version_major":2,"version_minor":0},"text/plain":["tokenizer_config.json: 0%| | 0.00/2.50k [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"6b2da802377e42b397a49608b79423a1","version_major":2,"version_minor":0},"text/plain":["spiece.model: 0%| | 0.00/792k [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"c89684248ea14925800ba5b41c2209ca","version_major":2,"version_minor":0},"text/plain":["tokenizer.json: 0%| | 0.00/2.42M [00:00, 
?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"ca7d547bb02a4f0dbd0296ccde13e17b","version_major":2,"version_minor":0},"text/plain":["special_tokens_map.json: 0%| | 0.00/2.20k [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"0f6dc7ead35e496f8bc9ffe386bb4045","version_major":2,"version_minor":0},"text/plain":["config.json: 0%| | 0.00/1.53k [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"8f2524006375450ab80d6d52b4ce50ba","version_major":2,"version_minor":0},"text/plain":["pytorch_model.bin: 0%| | 0.00/990M [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"0d86467317c9498fa2606fe9e87eaa89","version_major":2,"version_minor":0},"text/plain":["generation_config.json: 0%| | 0.00/142 [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"}],"source":["tokenizer_gen = AutoTokenizer.from_pretrained(\"MBZUAI/LaMini-Flan-T5-248M\")\n","model_gen = AutoModelForSeq2SeqLM.from_pretrained(\"MBZUAI/LaMini-Flan-T5-248M\")"]},{"cell_type":"code","execution_count":7,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:21:21.425553Z","iopub.status.busy":"2023-12-25T09:21:21.425168Z","iopub.status.idle":"2023-12-25T09:21:21.435369Z","shell.execute_reply":"2023-12-25T09:21:21.434216Z","shell.execute_reply.started":"2023-12-25T09:21:21.425523Z"},"trusted":true},"outputs":[],"source":[
"def generate_plot(df: pd.DataFrame, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer, device) -> pd.DataFrame:\n",
"    \"\"\"Generate a plot summary for every title in `df` and store it in a `plot` column.\n",
"\n",
"    Each title is inserted into the prompt 'What is the story of the movie {}?', passed to\n",
"    `model.generate`, and the decoded text becomes that row's `plot` value.\n",
"\n",
"    Args:\n",
"        df: DataFrame with a `title` column. Modified in place (its `plot` column is set).\n",
"        model: Seq2seq model used for generation.\n",
"        tokenizer: Tokenizer matching `model`.\n",
"        device: torch device on which generation runs.\n",
"\n",
"    Returns:\n",
"        The same `df`, with the `plot` column filled in.\n",
"    \"\"\"\n",
"    quote = 'What is the story of the movie {}?'\n",
"    # Prepare the model that was passed in, not the global `model_gen`.\n",
"    model.to(device)\n",
"    model.eval()\n",
"\n",
"    plots = []\n",
"    with torch.no_grad():\n",
"        # Iterate over the values, not index labels, so a non-default index cannot break lookups.\n",
"        for title in tqdm(df['title'].tolist()):\n",
"            input_ids = tokenizer(quote.format(title), return_tensors='pt').input_ids.to(device)\n",
"            output = model.generate(input_ids, max_length=256, do_sample=True, temperature=0.09)\n",
"            plots.append(tokenizer.decode(output[0], skip_special_tokens=True))\n",
"    # Assign all generated plots positionally in one step.\n",
"    df['plot'] = plots\n",
"    return df"
]},{"cell_type":"markdown","metadata":{},"source":["`generate_plot` is not run in this notebook. The commented-out cell below shows how it was used to generate the `plot` column that is stored in the CSV files."]},{"cell_type":"code","execution_count":8,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:21:21.437298Z","iopub.status.busy":"2023-12-25T09:21:21.436868Z","iopub.status.idle":"2023-12-25T09:21:22.492901Z","shell.execute_reply":"2023-12-25T09:21:22.491470Z","shell.execute_reply.started":"2023-12-25T09:21:21.437264Z"},"trusted":true},"outputs":[],"source":["device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"]},{"cell_type":"code","execution_count":9,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:21:22.495575Z","iopub.status.busy":"2023-12-25T09:21:22.494953Z","iopub.status.idle":"2023-12-25T09:21:22.612963Z","shell.execute_reply":"2023-12-25T09:21:22.611570Z","shell.execute_reply.started":"2023-12-25T09:21:22.495522Z"},"trusted":true},"outputs":[],"source":["# trainset = generate_plot(trainset, model_gen, tokenizer_gen, device)\n","# testset = generate_plot(testset, model_gen, tokenizer_gen, device)"]},{"cell_type":"markdown","metadata":{},"source":["# Model Implementation"]},{"cell_type":"markdown","metadata":{},"source":["### Sub-models\n","***"]},{"cell_type":"code","execution_count":10,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:21:22.620906Z","iopub.status.busy":"2023-12-25T09:21:22.620534Z","iopub.status.idle":"2023-12-25T09:21:34.001085Z","shell.execute_reply":"2023-12-25T09:21:34.000036Z","shell.execute_reply.started":"2023-12-25T09:21:22.620875Z"},"trusted":true},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"ffc2604f01a1458e874f6228924fbc11","version_major":2,"version_minor":0},"text/plain":["tokenizer_config.json: 0%| | 0.00/28.0 [00:00, 
?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"7c177b5a2b444c8cb7b52f06d3fca821","version_major":2,"version_minor":0},"text/plain":["config.json: 0%| | 0.00/483 [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"8f0a74a4449542328defea509acd8e24","version_major":2,"version_minor":0},"text/plain":["vocab.txt: 0%| | 0.00/232k [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"d2468b0ed5434e8a85be2f2faf987ee7","version_major":2,"version_minor":0},"text/plain":["tokenizer.json: 0%| | 0.00/466k [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"6898208cae67445a961631cc37fda87d","version_major":2,"version_minor":0},"text/plain":["model.safetensors: 0%| | 0.00/268M [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"name":"stderr","output_type":"stream","text":["Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']\n","You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"]},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"f7ab182a4cfe475f8a0d2b393f974206","version_major":2,"version_minor":0},"text/plain":["tokenizer_config.json: 0%| | 0.00/1.20k [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"84f17500b602407bae355f01437c82a8","version_major":2,"version_minor":0},"text/plain":["vocab.txt: 0%| | 0.00/232k [00:00, 
?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"f9434012204e4523b7899a947ae12266","version_major":2,"version_minor":0},"text/plain":["tokenizer.json: 0%| | 0.00/712k [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"8cc89f3d59094688ac76c7144c3b5a5d","version_major":2,"version_minor":0},"text/plain":["special_tokens_map.json: 0%| | 0.00/125 [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"058575f187e24b52b0d2efe549fba600","version_major":2,"version_minor":0},"text/plain":["config.json: 0%| | 0.00/1.36k [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"53956b30d7be449cb2aecdfa032b8f1d","version_major":2,"version_minor":0},"text/plain":["model.safetensors: 0%| | 0.00/268M [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/plain":["device(type='cuda')"]},"execution_count":10,"metadata":{},"output_type":"execute_result"}],"source":[
"tokenizer1 = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
"model1 = DistilBertForSequenceClassification.from_pretrained(\"distilbert-base-uncased\", problem_type=\"multi_label_classification\", num_labels=len(genres))\n",
"model1.config.id2label = mapping\n",
"\n",
"tokenizer2 = AutoTokenizer.from_pretrained(\"dduy193/plot-classification\")\n",
"model2 = AutoModelForSequenceClassification.from_pretrained(\"dduy193/plot-classification\")\n",
"model2.config.id2label = mapping\n",
"\n",
"# `pretrained=` is deprecated in torchvision; `weights=None` is the equivalent (randomly initialised).\n",
"model3 = models.resnet101(weights=None)\n",
"# Replace the classification head with one output per genre; read its input width instead of hard-coding 2048.\n",
"model3.fc = torch.nn.Linear(model3.fc.in_features, len(genres))\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"model1.to(device)\n",
"model2.to(device)\n",
"model3.to(device)\n",
"device"
]},{"cell_type":"markdown","metadata":{},"source":["### Deep Fusion Multimodal 
Model\n","***"]},{"cell_type":"code","execution_count":11,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:21:34.003002Z","iopub.status.busy":"2023-12-25T09:21:34.002690Z","iopub.status.idle":"2023-12-25T09:21:34.011225Z","shell.execute_reply":"2023-12-25T09:21:34.010061Z","shell.execute_reply.started":"2023-12-25T09:21:34.002960Z"},"trusted":true},"outputs":[],"source":[
"class Multimodal(torch.nn.Module):\n",
"    \"\"\"Multimodal genre classifier that fuses title, plot and poster sub-models.\n",
"\n",
"    Each sub-model yields one score per genre; every score vector goes through its own\n",
"    linear layer and the three results are summed into the final logits.\n",
"\n",
"    Args:\n",
"        model1: Title classifier called with input_ids/attention_mask; must return `.logits`.\n",
"        model2: Plot classifier with the same calling convention as `model1`.\n",
"        model3: Image model returning `num_labels` scores per sample.\n",
"        num_labels: Number of genres each sub-model predicts (default 18).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, model1, model2, model3, num_labels=18):\n",
"        super().__init__()\n",
"        self.model1 = model1\n",
"        self.model2 = model2\n",
"        self.model3 = model3\n",
"        self.fc1 = torch.nn.Linear(num_labels, num_labels)\n",
"        self.fc2 = torch.nn.Linear(num_labels, num_labels)\n",
"        self.fc3 = torch.nn.Linear(num_labels, num_labels)\n",
"\n",
"    def forward(self,\n",
"                title_input_ids, title_attention_mask,\n",
"                plot_input_ids, plot_attention_mask,\n",
"                image_input):\n",
"        \"\"\"Return fused per-genre logits (no sigmoid applied).\"\"\"\n",
"        # Pass by keyword so the call does not depend on each model's positional argument order.\n",
"        title_logits = self.model1(input_ids=title_input_ids, attention_mask=title_attention_mask).logits\n",
"        plot_logits = self.model2(input_ids=plot_input_ids, attention_mask=plot_attention_mask).logits\n",
"        image_logits = self.model3(image_input)\n",
"\n",
"        return self.fc1(title_logits) + self.fc2(plot_logits) + self.fc3(image_logits)"
]},{"cell_type":"markdown","metadata":{},"source":["# Custom Datasets & Data Loaders"]},{"cell_type":"markdown","metadata":{},"source":["***\n","### Custom Dataset\n","***"]},{"cell_type":"code","execution_count":12,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:21:34.012715Z","iopub.status.busy":"2023-12-25T09:21:34.012432Z","iopub.status.idle":"2023-12-25T09:21:34.031474Z","shell.execute_reply":"2023-12-25T09:21:34.030400Z","shell.execute_reply.started":"2023-12-25T09:21:34.012691Z"},"trusted":true},"outputs":[],"source":[
"class Poroset(torch.utils.data.Dataset):\n",
"    \"\"\"Dataset yielding tokenized title/plot text, a poster tensor and a multi-hot label per movie.\n",
"\n",
"    Args:\n",
"        df: DataFrame with `title`, `plot`, `img_path` and `label` columns\n",
"            (`label` already parsed into a list of 0/1 values).\n",
"        tokenizer1: Tokenizer for titles.\n",
"        tokenizer2: Tokenizer for plots.\n",
"        max_len1: Token length titles are padded/truncated to.\n",
"        max_len2: Token length plots are padded/truncated to.\n",
"        device: Stored as `self.device` but not used by this class; tensors are returned on the CPU.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, df,\n",
"                 tokenizer1, tokenizer2,\n",
"                 max_len1=64, max_len2=256,\n",
"                 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):\n",
"        self.df = df\n",
"        self.tokenizer1 = tokenizer1\n",
"        self.tokenizer2 = tokenizer2\n",
"        self.max_len1 = max_len1\n",
"        self.max_len2 = max_len2\n",
"        self.device = device\n",
"        # Resize to 224x224 and normalise with the ImageNet channel statistics.\n",
"        self.transform = v2.Compose([\n",
"            v2.Resize((224, 224)),\n",
"            v2.ToTensor(),\n",
"            v2.Normalize(mean=[0.485, 0.456, 0.406],\n",
"                         std=[0.229, 0.224, 0.225])\n",
"        ])\n",
"\n",
"    def __len__(self):\n",
"        return len(self.df)\n",
"\n",
"    @staticmethod\n",
"    def _as_text(value):\n",
"        \"\"\"Return `value` as a string, mapping missing entries (NaN/None) to ''.\"\"\"\n",
"        if isinstance(value, str):\n",
"            return value\n",
"        return '' if pd.isna(value) else str(value)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        row = self.df.iloc[idx]\n",
"\n",
"        # max_len1 / max_len2 are *token* limits and are enforced by the tokenizers below.\n",
"        # Slicing the raw strings to that many *characters* would throw away most of a plot\n",
"        # long before the token limit is reached, so the text is passed through whole.\n",
"        title = self._as_text(row['title'])\n",
"        plot = self._as_text(row['plot'])\n",
"        label = row['label']\n",
"\n",
"        title_encoding = self.tokenizer1(title, truncation=True, padding='max_length', max_length=self.max_len1, return_tensors='pt')\n",
"        plot_encoding = self.tokenizer2(plot, truncation=True, padding='max_length', max_length=self.max_len2, return_tensors='pt')\n",
"\n",
"        image_path = '/kaggle/input/ml-dataset-2023s1/ml1m/' + row['img_path']\n",
"        if os.path.exists(image_path):\n",
"            image_input = self.transform(image.load_img(image_path))\n",
"        else:\n",
"            # Missing poster: use an all-zero 3x224x224 image so the sample can still be batched.\n",
"            image_input = torch.zeros((3, 224, 224))\n",
"\n",
"        return {\n",
"            'title': title,\n",
"            'plot': plot,\n",
"            # squeeze(0) drops only the batch dimension added by return_tensors='pt'.\n",
"            'title_input_ids': title_encoding['input_ids'].squeeze(0),\n",
"            'title_attention_mask': title_encoding['attention_mask'].squeeze(0),\n",
"            'plot_input_ids': plot_encoding['input_ids'].squeeze(0),\n",
"            'plot_attention_mask': plot_encoding['attention_mask'].squeeze(0),\n",
"            'image_input': image_input,\n",
"            'label': torch.FloatTensor(label)\n",
"        }"
]},{"cell_type":"code","execution_count":13,"metadata":{"execution":{"iopub.execute_input":"2023-12-25T09:21:34.033238Z","iopub.status.busy":"2023-12-25T09:21:34.032817Z","iopub.status.idle":"2023-12-25T09:21:34.058795Z","shell.execute_reply":"2023-12-25T09:21:34.057816Z","shell.execute_reply.started":"2023-12-25T09:21:34.033210Z"},"trusted":true},"outputs":[{"data":{"text/html":["
\n"," | title | \n","img_path | \n","label | \n","plot | \n","
---|---|---|---|---|
0 | \n","Washington Square (1997) | \n","ml1m/content/dataset/ml1m-images/1650.jpg | \n","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ... | \n","Washington Square is a 1997 American film abou... | \n","
1 | \n","Net, The (1995) | \n","ml1m/content/dataset/ml1m-images/185.jpg | \n","[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | \n","Net is a 1995 American film directed by James ... | \n","
2 | \n","Batman Returns (1992) | \n","ml1m/content/dataset/ml1m-images/1377.jpg | \n","[1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, ... | \n","Batman returns to the Batman universe after a ... | \n","
3 | \n","Boys from Brazil, The (1978) | \n","ml1m/content/dataset/ml1m-images/3204.jpg | \n","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | \n","The movie Boys from Brazil, The (1978) is a ro... | \n","
4 | \n","Dear Jesse (1997) | \n","ml1m/content/dataset/ml1m-images/1901.jpg | \n","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ... | \n","Dear Jesse is a 1997 American drama film about... | \n","