Generalized Contrastive Learning for Multi-Modal Retrieval and Ranking

This work aims to improve and measure the ranking performance of information retrieval models, especially for retrieving relevant products given a search query.

Blog post: https://www.marqo.ai/blog/generalized-contrastive-learning-for-multi-modal-retrieval-and-ranking

Paper: https://arxiv.org/pdf/2404.08535.pdf

Text-only

Method     Model          nDCG   ERR    RBP
BM25       -              0.071  0.028  0.052
E5         e5-large-v2    0.335  0.095  0.289
E5 (GCL)   e5-large-v2    0.470  0.457  0.374
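The table reports three ranking metrics. The card does not state the gain function, rank cutoff, or RBP persistence used in the evaluation, so the code below is only a sketch of common textbook definitions. It assumes graded integer relevance labels listed in ranked order and uses nDCG with exponential gain 2^rel - 1 and a log2 discount, ERR under the cascade model (the user stops at rank r with probability R_r = (2^g - 1) / 2^g_max), and RBP with persistence p. Several choices are assumptions and do not come from the paper's evaluation code: the function names, the default p = 0.9, normalizing nDCG against the ideal ordering of the retrieved labels themselves, and binarizing grades for RBP.

```python
import math
from typing import Optional, Sequence


def dcg(rels: Sequence[float], k: Optional[int] = None) -> float:
    # Exponential gain (2^rel - 1) discounted by log2(rank + 1), ranks from 1.
    return sum((2 ** r - 1) / math.log2(i + 2) for i, r in enumerate(list(rels)[:k]))


def ndcg(rels: Sequence[float], k: Optional[int] = None) -> float:
    # Normalize by the DCG of the same labels sorted into the ideal order
    # (assumption: the ideal list is built from the retrieved labels only).
    ideal = dcg(sorted(rels, reverse=True), k)
    return dcg(rels, k) / ideal if ideal > 0 else 0.0


def err(rels: Sequence[int], max_grade: int, k: Optional[int] = None) -> float:
    # Cascade model: the user stops at rank r with probability
    # R_r = (2^g - 1) / 2^max_grade, given they did not stop at an earlier rank.
    p_reach = 1.0  # probability the user reaches the current rank
    total = 0.0
    for r, g in enumerate(list(rels)[:k], start=1):
        stop = (2 ** g - 1) / 2 ** max_grade
        total += p_reach * stop / r
        p_reach *= 1.0 - stop
    return total


def rbp(rels: Sequence[float], p: float = 0.9, k: Optional[int] = None) -> float:
    # Rank-Biased Precision: (1 - p) * sum_i rel_i * p^(i - 1), with rel_i in [0, 1].
    return (1 - p) * sum(r * p ** i for i, r in enumerate(list(rels)[:k]))


# Example: one query's top-3 results graded 2 (exact), 0 (irrelevant), 1 (partial).
ranked = [2, 0, 1]
# RBP expects relevance in [0, 1], so grades are binarized here (assumption).
print(ndcg(ranked), err(ranked, max_grade=2), rbp([1, 0, 1], p=0.9))
```

Each function scores a single query. A benchmark figure would typically average these values over all queries.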

Usage

import torch
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    # Zero out padding positions, then average over the real tokens only.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


# Each input text should start with "query: " or "passage: ".
# For tasks other than retrieval, you can simply use the "query: " prefix.
input_texts = ['query: Espresso Pitcher with Handle',
               'query: Women’s designer handbag sale',
               "passage: Dianoo Espresso Steaming Pitcher, Espresso Milk Frothing Pitcher Stainless Steel",
               "passage: Coach Outlet Eliza Shoulder Bag - Black - One Size"]

tokenizer = AutoTokenizer.from_pretrained('Marqo/marqo-gcl-e5-large-v2-130')
model_new = AutoModel.from_pretrained('Marqo/marqo-gcl-e5-large-v2-130')

# Tokenize the input texts
batch_dict = tokenizer(input_texts, max_length=77, padding=True, truncation=True, return_tensors='pt')

# Inference only: disable gradient tracking
with torch.no_grad():
    outputs = model_new(**batch_dict)
embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

# Normalize embeddings to unit length so dot products are cosine similarities
embeddings = F.normalize(embeddings, p=2, dim=1)
# Rows: the 2 queries; columns: the 2 passages; values: cosine similarity x 100
scores = (embeddings[:2] @ embeddings[2:].T) * 100
print(scores.tolist())
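The last three steps of the snippet are masked mean pooling, L2 normalization, and the scaled dot product. They can be traced without downloading the model. The sketch below mirrors that arithmetic in pure Python on made-up numbers, which are not model outputs. `mean_pool`, `l2_normalize`, and `rank_passages` are hypothetical helpers. Sorting each score row to get a per-query retrieval order is an assumption about how the scores would be used for product search.

```python
import math
from typing import List, Tuple

Vector = List[float]


def mean_pool(hidden: List[List[Vector]], mask: List[List[int]]) -> List[Vector]:
    # Mean of each sequence's token vectors, skipping padding (mask == 0).
    pooled = []
    for tokens, keep in zip(hidden, mask):
        real = [t for t, k in zip(tokens, keep) if k]
        pooled.append([sum(col) / len(real) for col in zip(*real)])
    return pooled


def l2_normalize(v: Vector) -> Vector:
    # Scale to unit length so a dot product equals cosine similarity.
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def rank_passages(queries: List[Vector],
                  passages: List[Vector]) -> Tuple[List[Vector], List[List[int]]]:
    q = [l2_normalize(v) for v in queries]
    p = [l2_normalize(v) for v in passages]
    # Same scaling as the snippet: 100 * cosine similarity, one row per query.
    scores = [[100 * sum(a * b for a, b in zip(qi, pj)) for pj in p] for qi in q]
    # Passage indices per query, from most to least similar.
    rankings = [sorted(range(len(p)), key=lambda j: row[j], reverse=True) for row in scores]
    return scores, rankings


# Toy "hidden states" (made-up numbers): 2 sequences, 3 tokens, 2 dims. The last
# token of the first sequence is padding and must not affect its mean.
hidden = [[[1.0, 0.0], [3.0, 0.0], [9.0, 9.0]],
          [[0.0, 2.0], [0.0, 4.0], [0.0, 6.0]]]
mask = [[1, 1, 0], [1, 1, 1]]
passages = mean_pool(hidden, mask)  # [[2.0, 0.0], [0.0, 4.0]]
scores, rankings = rank_passages([[2.0, 0.0], [1.0, 3.0]], passages)
print(rankings)  # [[0, 1], [1, 0]]
```

The padded token [9.0, 9.0] is excluded, which is what the `masked_fill` in `average_pool` achieves with tensors. Without the mask, the first pooled vector would be pulled toward the padding values.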

Model details: 335M parameters, tensor type F32, weights in Safetensors format.