{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "3e7b6247", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2023-06-29 09:08:24,868] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n", "\n", "===================================BUG REPORT===================================\n", "Welcome to bitsandbytes. For bug reports, please run\n", "\n", "python -m bitsandbytes\n", "\n", " and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", "================================================================================\n", "bin /home/sourab/miniconda3/envs/ml/lib/python3.11/site-packages/bitsandbytes/libbitsandbytes_cuda118.so\n", "CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n", "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n", "CUDA SETUP: Detected CUDA version 118\n", "CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.11/site-packages/bitsandbytes/libbitsandbytes_cuda118.so...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/sourab/miniconda3/envs/ml/lib/python3.11/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: Found duplicate ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] files: {PosixPath('/home/sourab/miniconda3/envs/ml/lib/libcudart.so'), PosixPath('/home/sourab/miniconda3/envs/ml/lib/libcudart.so.11.0')}.. 
We'll flip a coin and try one of these, in order to fail forward.\n", "Either way, this might cause trouble in the future:\n", "If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.\n", " warn(msg)\n" ] } ], "source": [ "import argparse\n", "import json\n", "import logging\n", "import math\n", "import os\n", "import random\n", "from pathlib import Path\n", "from tqdm import tqdm\n", "\n", "import datasets\n", "from datasets import load_dataset, DatasetDict\n", "\n", "import evaluate\n", "import pandas as pd\n", "import torch\n", "from torch import nn\n", "from torch.utils.data import DataLoader\n", "\n", "import transformers\n", "from transformers import AutoTokenizer, AutoModel, default_data_collator, SchedulerType, get_scheduler\n", "from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry\n", "from transformers.utils.versions import require_version\n", "\n", "from huggingface_hub import Repository, create_repo\n", "\n", "from accelerate import Accelerator\n", "from accelerate.logging import get_logger\n", "from accelerate.utils import set_seed\n", "\n", "from peft import PeftModel\n", "\n", "import hnswlib" ] },
{ "cell_type": "code", "execution_count": 2, "id": "c939b4fd", "metadata": {}, "outputs": [], "source": [
"class AutoModelForSentenceEmbedding(nn.Module):\n",
"    \"\"\"Wrap a Hugging Face encoder so it returns one embedding per input sequence.\n",
"\n",
"    Token embeddings from the wrapped ``AutoModel`` are mean-pooled using the\n",
"    ``attention_mask`` and, optionally, L2-normalized.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, model_name, tokenizer, normalize=True, **model_kwargs):\n",
"        \"\"\"\n",
"        Args:\n",
"            model_name: Hub id or local path passed to ``AutoModel.from_pretrained``.\n",
"            tokenizer: Stored as ``self.tokenizer``; this class does not use it itself.\n",
"            normalize: If True, ``forward`` L2-normalizes each pooled embedding.\n",
"            **model_kwargs: Extra keyword arguments forwarded to ``AutoModel.from_pretrained``\n",
"                (e.g. ``load_in_8bit=True, device_map={\"\": 0}``).\n",
"        \"\"\"\n",
"        super().__init__()\n",
"\n",
"        self.model = AutoModel.from_pretrained(model_name, **model_kwargs)\n",
"        self.normalize = normalize\n",
"        self.tokenizer = tokenizer\n",
"\n",
"    def forward(self, **kwargs):\n",
"        \"\"\"Return pooled embeddings for a tokenized batch.\n",
"\n",
"        ``kwargs`` are passed unchanged to the wrapped model and must include\n",
"        ``attention_mask``, which is also used for pooling.\n",
"        \"\"\"\n",
"        model_output = self.model(**kwargs)\n",
"        embeddings = self.mean_pooling(model_output, kwargs[\"attention_mask\"])\n",
"        if self.normalize:\n",
"            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)\n",
"\n",
"        return embeddings\n",
"\n",
"    def mean_pooling(self, model_output, attention_mask):\n",
"        \"\"\"Average token embeddings weighted by ``attention_mask`` (mask-0 positions are excluded).\"\"\"\n",
"        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings\n",
"        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
"        # clamp keeps the division finite for rows whose mask is all zeros\n",
"        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
"\n",
"    def __getattr__(self, name: str):\n",
"        \"\"\"Forward missing attributes to the wrapped module.\"\"\"\n",
"        try:\n",
"            return super().__getattr__(name)  # defer to nn.Module's logic\n",
"        except AttributeError:\n",
"            return getattr(self.model, name)\n",
"\n",
"\n",
"def get_cosing_embeddings(query_embs, product_embs):\n",
"    \"\"\"Row-wise dot product between two equally shaped embedding batches.\n",
"\n",
"    This equals cosine similarity only when both inputs are L2-normalized\n",
"    (e.g. produced by AutoModelForSentenceEmbedding with normalize=True).\n",
"    \"\"\"\n",
"    return torch.sum(query_embs * product_embs, dim=1)\n",
"\n",
"\n",
"# Correctly spelled alias; the original name is kept for existing callers.\n",
"get_cosine_embeddings = get_cosing_embeddings" ] },
{ "cell_type": "code", "execution_count": 3, "id": "8b5d9256", "metadata": {}, "outputs": [], "source": [ "model_name_or_path = \"intfloat/e5-large-v2\"\n", "peft_model_id = \"smangrul/peft_lora_e5_semantic_search\"\n", "dataset_name = \"smangrul/amazon_esci\"\n", "max_length = 70\n", "batch_size = 256" ] },
{ "cell_type": "code", "execution_count": 4, "id": "f190e1ee", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset parquet (/raid/sourab/.cache/huggingface/datasets/smangrul___parquet/smangrul--amazon_esci-321288cabf0cc045/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "43b84641575e4ce6899a3e6f61d7e126", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/2 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
"dataset = load_dataset(dataset_name)\n",
"train_product_dataset = dataset[\"train\"].to_pandas()[[\"product_title\"]]\n",
"val_product_dataset = dataset[\"validation\"].to_pandas()[[\"product_title\"]]\n",
"\n",
"# Unique product titles across train + validation. The first reset_index gives a\n",
"# contiguous 0..N-1 RangeIndex after de-duplication; the second copies it into an\n",
"# explicit \"index\" column.\n",
"product_dataset_for_indexing = (\n",
"    pd.concat([train_product_dataset, val_product_dataset])\n",
"    .drop_duplicates()\n",
"    .reset_index(drop=True)\n",
"    .reset_index()\n",
")" ] },
{ "cell_type": "code", "execution_count": 5, "id": "7e52e425", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | index | \n", "product_title | \n", "
---|---|---|
0 | \n", "0 | \n", "RamPro 10\" All Purpose Utility Air Tires/Wheel... | \n", "
1 | \n", "1 | \n", "MaxAuto 2-Pack 13x5.00-6 2PLY Turf Mower Tract... | \n", "
2 | \n", "2 | \n", "NEIKO 20601A 14.5 inch Steel Tire Spoon Lever ... | \n", "
3 | \n", "3 | \n", "2PK 13x5.00-6 13x5.00x6 13x5x6 13x5-6 2PLY Tur... | \n", "
4 | \n", "4 | \n", "(Set of 2) 15x6.00-6 Husqvarna/Poulan Tire Whe... | \n", "
... | \n", "... | \n", "... | \n", "
476273 | \n", "476273 | \n", "Chanel No.5 Eau Premiere Spray 50ml/1.7oz | \n", "
476274 | \n", "476274 | \n", "Steve Madden Designer 15 Inch Carry on Suitcas... | \n", "
476275 | \n", "476275 | \n", "CHANEL Le Lift Creme Yeux, Black, 0.5 Ounce | \n", "
476276 | \n", "476276 | \n", "Coco Mademoiselle by Chanel for Women - 3.4 oz... | \n", "
476277 | \n", "476277 | \n", "Chânél No. 5 by Chânél Eau De Parfum Premiere ... | \n", "
476278 rows × 2 columns
\n", "\n", " | index | \n", "product_title | \n", "
---|---|---|
34710 | \n", "34710 | \n", "ROK 4-1/2 inch Diamond Saw Blade Set, Pack of 3 | \n", "
277590 | \n", "277590 | \n", "WSGG Medical Goggles, FDA registered, Safety Goggles, Fit Over Glasses, Anti-Fog, Anti-Splash (1 pack) | \n", "
474000 | \n", "474000 | \n", "iJDMTOY 15W CREE High Power LED Angel Eye Bulbs Compatible With BMW 5 6 7 Series X3 X5 (E39 E60 E63 E65 E53), 7000K Xenon White Headlight Ring Marker Lights | \n", "
18997 | \n", "18997 | \n", "USB Charger, Anker Elite Dual Port 24W Wall Charger, PowerPort 2 with PowerIQ and Foldable Plug, for iPhone 11/Xs/XS Max/XR/X/8/7/6/Plus, iPad Pro/Air 2/Mini 3/Mini 4, Samsung S4/S5, and More | \n", "
208666 | \n", "208666 | \n", "AOGGY Compatible with MacBook Air 13 inch Case A1466/A1369 (2010-2017 Release) Glitter Fluorescent Color Plastic Hard Case, with Older Version MacBook Air 13 inch Keyboard Cover - Gold | \n", "
326614 | \n", "326614 | \n", "CUTE STONE Little Kitchen Playset, Kitchen Toy Set with Realistic Sound &Light, Play Sink, Cooking Stove with Steam, Play Food and Kitchen Accessories, Great Kitchen Toys for Toddlers Kids | \n", "
105637 | \n", "105637 | \n", "Milwaukee Electric Tool 2470-21 M12 Cordless Shear Kit, 12 V, Li-Ion | \n", "
342392 | \n", "342392 | \n", "chouyatou Women's Short Sleeve/Strap Open Bust Bodysuit Shapewear Firm Control Body Shaper (X-Small, Nude Sleeve) | \n", "
319970 | \n", "319970 | \n", "AMT 256 Hz Medical-Grade Tuning Fork Instrument with Fixed Weights, Non-Magnetic Aluminum Alloy (C 256) | \n", "
416956 | \n", "416956 | \n", "Timberland HIKER-ROUND 54 BROWN | \n", "