Can I use SentenceTransformers instead of the presented code?
I wanted to use SentenceTransformers instead of the authors' sample code because it fits my current setup, but first I had to check whether the results would differ.
It turns out that even though both use mean pooling with transformers as the backend, the results are slightly different.
from typing import List

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel

def load_finetuned_model():
    sentence_encoder = AutoModel.from_pretrained("biu-nlp/abstract-sim-sentence")
    query_encoder = AutoModel.from_pretrained("biu-nlp/abstract-sim-query")
    tokenizer = AutoTokenizer.from_pretrained("biu-nlp/abstract-sim-sentence")
    return tokenizer, query_encoder, sentence_encoder

def encode_batch(model, tokenizer, sentences: List[str], device: str):
    input_ids = tokenizer(
        sentences,
        padding=True,
        max_length=512,
        truncation=True,
        return_tensors="pt",
        add_special_tokens=True,
    ).to(device)
    features = model(**input_ids)[0]
    # Masked mean pooling that skips position 0 (the CLS token).
    features = torch.sum(
        features[:, 1:, :] * input_ids["attention_mask"][:, 1:].unsqueeze(-1), dim=1
    ) / torch.clamp(
        torch.sum(input_ids["attention_mask"][:, 1:], dim=1, keepdims=True), min=1e-9
    )
    return features

tokenizer, query_encoder, sentence_encoder = load_finetuned_model()
relevant_sentences = [
    "Fingersoft's parent company is the Finger Group.",
    "WHIRC – a subsidiary company of Wright-Hennepin",
    "CK Life Sciences International (Holdings) Inc. (), or CK Life Sciences, is a subsidiary of CK Hutchison Holdings",
    "EM Microelectronic-Marin (subsidiary of The Swatch Group).",
    "The company is currently a division of the corporate group Jam Industries.",
    "Volt Technical Resources is a business unit of Volt Workforce Solutions, a subsidiary of Volt Information Sciences (currently trading over-the-counter as VISI.).",
]
irrelevant_sentences = [
    "The second company is deemed to be a subsidiary of the parent company.",
    "The company has gone through more than one incarnation.",
    "The company is owned by its employees.",
    "Larger companies compete for market share by acquiring smaller companies that may own a particular market sector.",
    "A parent company is a company that owns 51% or more voting stock in another firm (or subsidiary).",
    "It is a holding company that provides services through its subsidiaries in the following areas: oil and gas, industrial and infrastructure, government and power.",
    "RXVT Technologies is no longer a subsidiary of the parent company.",
]
all_sentences = relevant_sentences + irrelevant_sentences
embeddings = (
    encode_batch(sentence_encoder, tokenizer, all_sentences, "cpu")
    .detach()
    .cpu()
    .numpy()
)
sentence_transformer = SentenceTransformer("biu-nlp/abstract-sim-sentence")
sentence_transformer_embeddings = sentence_transformer.encode(all_sentences, normalize_embeddings=False)
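print("length of embeddings authors method", np.linalg.norm(embeddings, axis=1))
print("length of embeddings sentence transformers", np.linalg.norm(sentence_transformer_embeddings, axis=1))
# Note: this is the norm of the difference vector, not the difference of the two norms.
print("difference in length", np.linalg.norm(embeddings - sentence_transformer_embeddings, axis=1))
# cosine_similarity comes from sklearn.metrics.pairwise; the diagonal pairs each sentence with itself.
print("cosine similarity between the same inputs", np.diag(cosine_similarity(embeddings, sentence_transformer_embeddings)))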
length of embeddings authors method [3.191418 3.2790325 3.283971 3.1165273 3.1817975 3.1611388 3.0702376
2.8644533 3.1743984 3.2016773 3.086787 3.2246523 3.2988307]
length of embeddings sentence transformers [3.1486273 3.2397876 3.2651744 3.088796 3.1495209 3.1484847 3.0348575
2.8282485 3.1296148 3.1790628 3.0645113 3.2098253 3.2714853]
difference in length [0.09617972 0.082573 0.04128565 0.0691056 0.08182564 0.03538151
0.07103401 0.09275693 0.11129349 0.05690338 0.04796569 0.03957536
0.07060497]
cosine similarity between the same inputs [0.99963075 0.99975157 0.99993694 0.999792 0.99971795 0.9999451
0.99979633 0.99954975 0.9994776 0.99986607 0.99990445 0.9999351
0.99980366]
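As the diagnostics above show, the two methods produce slightly different embeddings: the cosine similarity between matching inputs stays above 0.9994, while the norm of the difference vector (printed as "difference in length") ranges from about 0.035 to 0.11. Hard to say whether this will affect performance.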
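The row printed as "difference in length" is the norm of the difference vector, which is not the same quantity as the difference between the two norms. A minimal sketch of the three row-wise diagnostics, using only NumPy and small made-up vectors (the helper name `compare_embeddings` and the toy data are illustrative assumptions, not part of the authors' code):

```python
import numpy as np


def compare_embeddings(a, b):
    """Row-wise diagnostics for two (n, dim) embedding matrices."""
    norm_a = np.linalg.norm(a, axis=1)
    norm_b = np.linalg.norm(b, axis=1)
    return {
        # How much the vector lengths differ.
        "difference_of_norms": norm_a - norm_b,
        # Length of the vector a - b (what the thread prints as "difference in length").
        "norm_of_difference": np.linalg.norm(a - b, axis=1),
        # Cosine similarity of each row of a with the same row of b.
        "cosine_similarity": np.sum(a * b, axis=1) / (norm_a * norm_b),
    }


# Toy data: row 0 has equal norms but different directions,
# row 1 has the same direction but different norms.
a = np.array([[3.0, 4.0], [1.0, 0.0]])
b = np.array([[4.0, 3.0], [2.0, 0.0]])
report = compare_embeddings(a, b)
for name, values in report.items():
    print(name, values)
```

Row 0 shows that two vectors can have identical lengths and still differ, which is why the cosine similarity is the more informative number here.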
Hi,
This probably stems from special tokens / the fact we do not mean-pool over the CLS token. We hope to support SentenceTransformers in the future.
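The effect of leaving the CLS position out of the mean can be seen on a toy example without loading any model. The sketch below uses NumPy rather than torch, and the function name `masked_mean_pool`, the `skip_first_token` flag, and the numbers are assumptions made up for illustration:

```python
import numpy as np


def masked_mean_pool(token_embeddings, attention_mask, skip_first_token=False):
    """Mean-pool token embeddings over non-padding positions.

    token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len).
    With skip_first_token=True, position 0 (the CLS slot) is left out.
    """
    start = 1 if skip_first_token else 0
    emb = token_embeddings[:, start:, :]
    mask = attention_mask[:, start:, None].astype(emb.dtype)
    summed = (emb * mask).sum(axis=1)
    # Guard against division by zero for fully masked rows.
    counts = np.clip(mask.sum(axis=1), 1e-9, None)
    return summed / counts


# One sequence with 4 positions: CLS, two real tokens, one padding slot.
tokens = np.array([[[4.0, 0.0], [1.0, 1.0], [3.0, 1.0], [9.0, 9.0]]])
mask = np.array([[1, 1, 1, 0]])

with_cls = masked_mean_pool(tokens, mask)
without_cls = masked_mean_pool(tokens, mask, skip_first_token=True)
print("with CLS:   ", with_cls)
print("without CLS:", without_cls)
```

The padding slot is ignored in both cases; only the treatment of position 0 changes the pooled vector.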
Thank you for the quick response. I have confirmed that including the CLS token does give almost identical results to SentenceTransformers. Below is a modified version of the authors' encoding method that includes the CLS token in the mean pooling.
def encode_batch_include_cls(model, tokenizer, sentences: List[str], device: str):
    input_ids = tokenizer(
        sentences,
        padding=True,
        max_length=512,
        truncation=True,
        return_tensors="pt",
        add_special_tokens=True,
    ).to(device)
    features = model(**input_ids)[0]
    # Masked mean pooling over all non-padding positions, CLS included.
    features = torch.sum(
        features * input_ids["attention_mask"].unsqueeze(-1), dim=1
    ) / torch.clamp(
        torch.sum(input_ids["attention_mask"], dim=1, keepdims=True), min=1e-9
    )
    return features
With the following results:
difference in length with cls [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
cosine similarity between the same inputs with cls [0.99999964 0.9999998 0.9999999 1.0000002 0.99999994 0.9999999 1. 1.0000001 0.9999999 1.0000001 0.9999999 1.0000001 1. ]
Hard to say what effect including or excluding the CLS token will have on performance.
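One rough way to probe this would be to check whether the ranking of sentences for a query changes between the two pooling variants. The sketch below is only an illustration under assumptions: the vectors are made-up placeholders standing in for real embeddings (in practice the query vector would come from the query encoder and the sentence vectors from each pooling variant), and the helper names are hypothetical.

```python
import numpy as np


def cosine_scores(query, sentences):
    """Cosine similarity of one query vector against each sentence vector."""
    q = query / np.linalg.norm(query)
    s = sentences / np.linalg.norm(sentences, axis=1, keepdims=True)
    return s @ q


def ranking(scores):
    """Sentence indices ordered from most to least similar."""
    return np.argsort(-scores, kind="stable")


# Placeholder vectors, NOT real model outputs.
query = np.array([1.0, 0.0])
sentences_without_cls = np.array([[0.9, 0.1], [0.5, 0.5], [0.1, 0.9]])
sentences_with_cls = np.array([[0.88, 0.12], [0.52, 0.48], [0.1, 0.9]])

rank_without = ranking(cosine_scores(query, sentences_without_cls))
rank_with = ranking(cosine_scores(query, sentences_with_cls))
same_order = bool(np.array_equal(rank_without, rank_with))
print("ranking without CLS:", rank_without)
print("ranking with CLS:   ", rank_with)
print("same order:", same_order)
```

An unchanged ranking on a handful of sentences would not settle the question; a labelled retrieval set would be needed to measure any real effect.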