Spaces:
Runtime error
Runtime error
from sentence_transformers import SentenceTransformer, util as st_util | |
from transformers import CLIPModel, CLIPProcessor | |
from PIL import Image | |
import requests | |
import os | |
import torch | |
torch.set_printoptions(precision=10) | |
from tqdm import tqdm | |
import s3fs | |
from io import BytesIO | |
import vector_db | |
"sentence-transformer-clip-ViT-L-14" | |
"openai-clip" | |
model_names = ["fashion"] | |
model_name_to_ids = { | |
"sentence-transformer-clip-ViT-L-14": "clip-ViT-L-14", | |
"fashion": "patrickjohncyh/fashion-clip", | |
"openai-clip": "openai/clip-vit-base-patch32", | |
} | |
AWS_ACCESS_KEY_ID = os.environ["AWS_ACCESS_KEY_ID"] | |
AWS_SECRET_ACCESS_KEY = os.environ["AWS_SECRET_ACCESS_KEY"] | |
# Define your bucket and dataset name. | |
S3_BUCKET = "s3://disco-io" | |
fs = s3fs.S3FileSystem( | |
key=AWS_ACCESS_KEY_ID, | |
secret=AWS_SECRET_ACCESS_KEY, | |
) | |
ROOT_DATA_PATH = os.path.join(S3_BUCKET, 'data') | |
def get_data_path(): | |
return os.path.join(ROOT_DATA_PATH, cur_dataset) | |
def get_image_path(): | |
return os.path.join(get_data_path(), 'images') | |
def get_metadata_path(): | |
return os.path.join(get_data_path(), 'metadata') | |
def get_embeddings_path(): | |
return os.path.join(get_metadata_path(), cur_dataset + '_embeddings.pq') | |
model_dict = dict() | |
def download_to_s3(url, s3_path): | |
# Download the file from the URL | |
response = requests.get(url, stream=True) | |
response.raise_for_status() | |
# Upload the file to the S3 path | |
with fs.open(s3_path, "wb") as s3_file: | |
for chunk in response.iter_content(chunk_size=8192): | |
s3_file.write(chunk) | |
def remove_all_files_from_s3_directory(s3_directory): | |
# List all objects in the S3 directory | |
objects = fs.ls(s3_directory) | |
# Remove each object | |
for obj in objects: | |
try: | |
fs.rm(obj) | |
except: | |
print('Error removing file: ' + obj) | |
def download_images(df, img_folder): | |
remove_all_files_from_s3_directory(img_folder) | |
for index, row in df.iterrows(): | |
try: | |
download_to_s3(row['IMG_URL'], os.path.join(img_folder, | |
row['title'].replace('/', '_').replace('\n', '') + '.jpg')) | |
except: | |
print('Error downloading image: ' + str(index) + row['title']) | |
def load_models(): | |
for model_name in model_name_to_ids: | |
if model_name not in model_dict: | |
model_dict[model_name] = dict() | |
if model_name.startswith('sentence-transformer'): | |
model_dict[model_name]['model'] = SentenceTransformer(model_name_to_ids[model_name]) | |
else: | |
model_dict[model_name]['hf_dir'] = model_name_to_ids[model_name] | |
model_dict[model_name]['model'] = CLIPModel.from_pretrained(model_name_to_ids[model_name]) | |
model_dict[model_name]['processor'] = CLIPProcessor.from_pretrained(model_name_to_ids[model_name]) | |
if len(model_dict) == 0: | |
print('Loading models...') | |
load_models() | |
def get_image_embedding(model_name, image): | |
""" | |
Takes an image as input and returns an embedding vector. | |
""" | |
model = model_dict[model_name]['model'] | |
if model_name.startswith('sentence-transformer'): | |
return model.encode(image) | |
else: | |
inputs = model_dict[model_name]['processor'](images=image, return_tensors="pt") | |
image_features = model.get_image_features(**inputs).detach().numpy()[0] | |
return image_features | |
def s3_path_to_image(fs, s3_path): | |
""" | |
Takes an S3 path as input and returns a PIL Image object. | |
Args: | |
s3_path (str): The path to the image in the S3 bucket, including the bucket name (e.g., "bucket_name/path/to/image.jpg"). | |
Returns: | |
Image: A PIL Image object. | |
""" | |
with fs.open(s3_path, "rb") as f: | |
image_data = BytesIO(f.read()) | |
img = Image.open(image_data) | |
return img | |
def generate_and_save_embeddings(): | |
# Get image embeddings | |
with torch.no_grad(): | |
for fp in tqdm(fs.ls(get_image_path()), desc="Generate embeddings for Images"): | |
if fp.endswith('.jpg'): | |
name = fp.split('/')[-1] | |
for model_name in model_name_to_ids.keys(): | |
s3_path = 's3://' + fp | |
vector_db.add_image_embedding_to_db( | |
embedding=get_image_embedding(model_name, s3_path_to_image(fs, s3_path)), | |
model_name=model_name, | |
dataset_name=cur_dataset, | |
path_to_image=s3_path, | |
image_name=name, | |
) | |
def get_immediate_subdirectories(s3_path): | |
return [obj.split('/')[-1] for obj in fs.glob(f"{s3_path}/*") if fs.isdir(obj)] | |
all_datasets = get_immediate_subdirectories(ROOT_DATA_PATH) | |
cur_dataset = all_datasets[0] | |
def set_cur_dataset(dataset): | |
refresh_all_datasets() | |
print(f"Setting current dataset to {dataset}") | |
global cur_dataset | |
cur_dataset = dataset | |
def refresh_all_datasets(): | |
global all_datasets | |
all_datasets = get_immediate_subdirectories(ROOT_DATA_PATH) | |
print(f"Refreshing all datasets: {all_datasets}") | |
def url_to_image(url): | |
try: | |
response = requests.get(url) | |
response.raise_for_status() | |
img = Image.open(BytesIO(response.content)) | |
return img | |
except requests.exceptions.RequestException as e: | |
print(f"Error fetching image from URL: {url}") | |
return None |