# NOTE(review): the original file began with "Spaces: / Running / Running" —
# residue from a hosting-page scrape, not code; removed.
import os
import torch
import numpy as np
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
class ColBERT:
    """ColBERT late-interaction reranker.

    Scores (query, document) pairs with the MaxSim operator: for every query
    token, take the maximum dot-product similarity over all document tokens,
    then sum those maxima to get one relevance score per document.
    """

    def __init__(self, name, **kwargs) -> None:
        """Load the ColBERT checkpoint ``name`` onto GPU if available, else CPU.

        Args:
            name: model name / path passed to ``Checkpoint`` and ``ColBERTConfig``.
            **kwargs: ``env="docker"`` enables a docker-specific workaround
                (see inline comment); other keys are ignored.
        """
        print("ColBERT: Loading model", name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if kwargs.get("env") == "docker":
            # Workaround for the docker container: a stale lock file left by an
            # interrupted build of the segmented_maxsim torch extension makes
            # loading fail with:
            # "segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory"
            lock_file = (
                "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
            )
            if os.path.exists(lock_file):
                os.remove(lock_file)
        self.ckpt = Checkpoint(
            name,
            colbert_config=ColBERTConfig(model_name=name),
        ).to(self.device)

    def calculate_similarity_scores(self, query_embeddings, document_embeddings):
        """Compute softmax-normalized MaxSim scores of documents against queries.

        Args:
            query_embeddings: tensor of shape (1 or n_docs, q_len, dim).
            document_embeddings: tensor of shape (n_docs, d_len, dim).

        Returns:
            np.ndarray of float32 with shape (n_docs,): per-document scores,
            softmax-normalized over the document axis.

        Raises:
            ValueError: if either tensor is not 3-D, or the query batch size is
                neither 1 nor equal to the number of documents.
        """
        # Validate before paying for any device transfer.
        if query_embeddings.dim() != 3:
            raise ValueError(
                f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
            )
        if document_embeddings.dim() != 3:
            raise ValueError(
                f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
            )
        if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
            raise ValueError(
                "There should be either one query or queries equal to the number of documents."
            )
        query_embeddings = query_embeddings.to(self.device)
        document_embeddings = document_embeddings.to(self.device)
        # (n_docs, d_len, dim) @ (1 or n_docs, dim, q_len) -> (n_docs, d_len, q_len);
        # a size-1 query batch broadcasts across all documents.
        transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
        computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)
        # MaxSim: for each query token, the best-matching document token.
        maximum_scores = torch.max(computed_scores, dim=1).values
        # Sum the per-query-token maxima -> one relevance score per document.
        final_scores = maximum_scores.sum(dim=1)
        # Softmax over documents so the scores form a distribution.
        normalized_scores = torch.softmax(final_scores, dim=0)
        return normalized_scores.detach().cpu().numpy().astype(np.float32)

    def predict(self, sentences):
        """Score (query, document) pairs.

        Args:
            sentences: non-empty list of (query, document) string pairs.
                NOTE: only the query of the FIRST pair is used — all pairs are
                assumed to share the same query.

        Returns:
            np.ndarray of float32 scores, one per input pair, softmaxed over
            the documents.

        Raises:
            ValueError: if ``sentences`` is empty.
        """
        if not sentences:
            raise ValueError("predict() requires at least one (query, document) pair.")
        query = sentences[0][0]
        docs = [pair[1] for pair in sentences]
        # docFromText returns (embeddings, doclens); keep only the embeddings.
        embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
        embedded_query = self.ckpt.queryFromText([query], bsize=32)[0]
        # Add a batch dimension so the single query broadcasts over all docs.
        return self.calculate_similarity_scores(
            embedded_query.unsqueeze(0), embedded_docs
        )