Spaces:
Running
Running
import unittest | |
import numpy as np | |
import torch | |
from sentence_transformers import SentenceTransformer | |
from .encoder_models import SBertEncoder, get_encoder, get_sbert_encoder | |
from .semncg import ( | |
RankedGains, | |
compute_cosine_similarity, | |
compute_gain, | |
score_ncg, | |
compute_ncg, | |
_validate_input_format, | |
SemNCG | |
) | |
from .utils import ( | |
get_gpu, | |
slice_embeddings, | |
is_nested_list_of_type, | |
flatten_list, | |
prep_sentences, | |
tokenize_and_prep_document | |
) | |
class TestUtils(unittest.TestCase):
    """Unit tests for the helpers imported from ``.utils``."""

    def _assert_nested_arrays_equal(self, actual, expected):
        """Recursively assert that ``actual`` matches ``expected``.

        ``expected`` is either an array or an arbitrarily nested list of
        arrays. ``actual`` must have the same nesting and the same lengths at
        every level, and every leaf must be element-wise equal to the
        corresponding expected array (``np.testing.assert_array_equal``).
        """
        if isinstance(expected, (list, tuple)):
            actual_items = list(actual)
            self.assertEqual(len(actual_items), len(expected))
            for actual_item, expected_item in zip(actual_items, expected):
                self._assert_nested_arrays_equal(actual_item, expected_item)
        else:
            np.testing.assert_array_equal(actual, expected)

    def test_get_gpu(self):
        gpu_count = torch.cuda.device_count()
        gpu_available = torch.cuda.is_available()
        # Test single boolean input
        self.assertEqual(get_gpu(True), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu(False), "cpu")
        # Test single string input
        self.assertEqual(get_gpu("cpu"), "cpu")
        self.assertEqual(get_gpu("gpu"), 0 if gpu_available else "cpu")
        self.assertEqual(get_gpu("cuda"), 0 if gpu_available else "cpu")
        # Test single integer input
        self.assertEqual(get_gpu(0), 0 if gpu_available else "cpu")
        # Device index 1 only exists with two or more GPUs. With exactly one
        # GPU it is out of range (checked at the end of this test).
        if not gpu_available:
            self.assertEqual(get_gpu(1), "cpu")
        elif gpu_count > 1:
            self.assertEqual(get_gpu(1), 1)
        # Test list input with unique elements
        self.assertEqual(get_gpu([True, "cpu", 0]), [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"])
        # Test list input with duplicate elements
        self.assertEqual(get_gpu([0, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])
        # Test list input with duplicate elements of different types
        self.assertEqual(get_gpu([True, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])
        # Test list input but only one element
        self.assertEqual(get_gpu([True]), 0 if gpu_available else "cpu")
        # Test list input with all integers
        self.assertEqual(get_gpu(list(range(gpu_count))),
                         list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"])
        with self.assertRaises(ValueError):
            get_gpu("invalid")
        # An out-of-range device index is only an error when CUDA is present.
        # Without CUDA, integer inputs fall back to "cpu" (asserted above for
        # 0 and 1), so get_gpu(device_count()) == get_gpu(0) must not raise.
        if gpu_available:
            with self.assertRaises(ValueError):
                get_gpu(gpu_count)

    def test_prep_sentences(self):
        # Normal case: surrounding whitespace is stripped and the
        # punctuation-only entry "!!!" is dropped.
        self.assertEqual(prep_sentences(["Hello, world!", " This is a test. ", "!!!"]),
                         ['Hello, world!', 'This is a test.'])
        # Only punctuation left -> nothing usable -> ValueError
        with self.assertRaises(ValueError):
            prep_sentences(["!!!", "..."])
        # Empty list -> ValueError
        with self.assertRaises(ValueError):
            prep_sentences([])

    def test_tokenize_and_prep_document(self):
        # tokenize=True splits a raw string into sentences
        self.assertEqual(tokenize_and_prep_document("Hello, world! This is a test.", True),
                         ['Hello, world!', 'This is a test.'])
        # tokenize=False accepts an already-split list of sentences
        self.assertEqual(tokenize_and_prep_document(["Hello, world!", "This is a test."], False),
                         ['Hello, world!', 'This is a test.'])
        # tokenize=True on a punctuation-only document -> ValueError
        with self.assertRaises(ValueError):
            tokenize_and_prep_document("!!! ...", True)

    def test_slice_embeddings(self):
        embeddings = np.random.rand(10, 5)

        # Case 1: a flat list of counts yields one consecutive slice per count.
        num_sentences = [3, 2, 5]
        expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
        self._assert_nested_arrays_equal(slice_embeddings(embeddings, num_sentences), expected_output)

        # Case 2: nested counts yield slices with the same nesting.
        # NOTE: this used to be assertTrue(actual, expected). That only checks
        # that `actual` is truthy (the second argument is the failure
        # message), so the slices are now compared element-wise.
        num_sentences_nested = [[2, 1], [3, 4]]
        expected_output_nested = [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
        self._assert_nested_arrays_equal(
            slice_embeddings(embeddings, num_sentences_nested), expected_output_nested
        )

        # Case 3: documents, references and predictions stacked in one matrix
        # (25 + 10 + 5 rows) are sliced section by section.
        document_sentences_count = [10, 8, 7]
        reference_sentences_count = [5, 3, 2]
        pred_sentences_count = [2, 2, 1]
        all_embeddings = np.random.rand(
            sum(document_sentences_count + reference_sentences_count + pred_sentences_count), 5,
        )
        doc_end = sum(document_sentences_count)                 # 25
        ref_end = doc_end + sum(reference_sentences_count)      # 35
        doc_rows = all_embeddings[:doc_end]
        ref_rows = all_embeddings[doc_end:ref_end]
        pred_rows = all_embeddings[ref_end:]
        expected_doc_embeddings = [doc_rows[:10], doc_rows[10:18], doc_rows[18:25]]
        expected_ref_embeddings = [ref_rows[:5], ref_rows[5:8], ref_rows[8:10]]
        expected_pred_embeddings = [pred_rows[:2], pred_rows[2:4], pred_rows[4:5]]
        # The inputs below mirror the original calls: each section is sliced
        # from the remaining tail of the stacked matrix.
        doc_embeddings = slice_embeddings(all_embeddings, document_sentences_count)
        ref_embeddings = slice_embeddings(all_embeddings[doc_end:], reference_sentences_count)
        pred_embeddings = slice_embeddings(all_embeddings[ref_end:], pred_sentences_count)
        self._assert_nested_arrays_equal(doc_embeddings, expected_doc_embeddings)
        self._assert_nested_arrays_equal(ref_embeddings, expected_ref_embeddings)
        self._assert_nested_arrays_equal(pred_embeddings, expected_pred_embeddings)

        # Counts given as a string are rejected.
        with self.assertRaises(TypeError):
            slice_embeddings(pred_rows, "invalid")

    def test_is_nested_list_of_type(self):
        # Depth 0, single element matching element_type
        self.assertEqual(is_nested_list_of_type("test", str, 0), (True, ""))
        # Depth 0, single element not matching element_type
        is_valid, _ = is_nested_list_of_type("test", int, 0)
        self.assertEqual(is_valid, False)
        # Depth 1, list of elements matching element_type
        self.assertEqual(is_nested_list_of_type(["apple", "banana"], str, 1), (True, ""))
        # Depth 1, list of elements not matching element_type
        is_valid, _ = is_nested_list_of_type([1, 2, 3], str, 1)
        self.assertEqual(is_valid, False)
        # Depth 0 (wrong depth) for a list
        is_valid, _ = is_nested_list_of_type([1, 2, 3], str, 0)
        self.assertEqual(is_valid, False)
        # Depth 2
        self.assertEqual(is_nested_list_of_type([[1, 2], [3, 4]], int, 2), (True, ""))
        self.assertEqual(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2), (True, ""))
        is_valid, _ = is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)
        self.assertEqual(is_valid, False)
        # Depth 3
        is_valid, _ = is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3)
        self.assertEqual(is_valid, False)
        self.assertEqual(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3), (True, ""))
        # Negative depth -> ValueError
        with self.assertRaises(ValueError):
            is_nested_list_of_type([1, 2], int, -1)

    def test_flatten_list(self):
        self.assertEqual(flatten_list([1, [2, 3], [[4], 5]]), [1, 2, 3, 4, 5])
        self.assertEqual(flatten_list([]), [])
        self.assertEqual(flatten_list([1, 2, 3]), [1, 2, 3])
        self.assertEqual(flatten_list([[[[1]]]]), [1])
class TestSBertEncoder(unittest.TestCase):
    """Tests for ``SBertEncoder``, ``get_encoder`` and ``get_sbert_encoder``."""

    def setUp(self) -> None:
        # A small SentenceTransformer model, wrapped in a CPU encoder.
        self.model_name = "paraphrase-distilroberta-base-v1"
        self.sbert_model = get_sbert_encoder(self.model_name)
        self.device = "cpu"  # For testing on CPU
        self.batch_size = 32
        self.verbose = False
        self.encoder = SBertEncoder(self.sbert_model, self.device, self.batch_size, self.verbose)
        # Read the embedding width from the model rather than hard-coding it,
        # so the shape checks stay valid if the test model is swapped.
        self.embedding_dim = self.sbert_model.get_sentence_embedding_dimension()

    def test_encode_single_sentence(self):
        embeddings = self.encoder.encode(["Hello, world!"])
        self.assertEqual(embeddings.shape, (1, self.embedding_dim))

    def test_encode_multiple_sentences(self):
        sentences = ["Hello, world!", "This is a test."]
        embeddings = self.encoder.encode(sentences)
        self.assertEqual(embeddings.shape, (2, self.embedding_dim))

    def test_get_sbert_encoder(self):
        sbert_model = get_sbert_encoder("paraphrase-distilroberta-base-v1")
        self.assertIsInstance(sbert_model, SentenceTransformer)

    def test_encode_with_gpu(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available, skipping GPU test.")
        encoder = get_encoder(self.sbert_model, "cuda", self.batch_size, self.verbose)
        sentences = ["Hello, world!", "This is a test."]
        embeddings = encoder.encode(sentences)
        self.assertEqual(embeddings.shape, (2, self.embedding_dim))

    def test_encode_multi_device(self):
        if torch.cuda.device_count() < 2:
            self.skipTest("Multi-GPU test requires at least 2 GPUs.")
        devices = ["cuda:0", "cuda:1"]
        encoder = get_encoder(self.sbert_model, devices, self.batch_size, self.verbose)
        sentences = ["This is a test sentence.", "Here is another sentence.", "This is a test sentence."]
        embeddings = encoder.encode(sentences)
        self.assertIsInstance(embeddings, np.ndarray)
        # One row per input sentence (duplicates included), model-sized columns.
        self.assertEqual(embeddings.shape[0], 3)
        self.assertEqual(embeddings.shape[1], self.embedding_dim)
class TestGetEncoder(unittest.TestCase):
    """Checks ``get_encoder`` / ``get_sbert_encoder`` across several model names."""

    def setUp(self):
        self.batch_size = 8
        self.verbose = False
        self.device = "cpu"
        if torch.cuda.is_available():
            self.device = "cuda"

    def _base_test(self, model_name):
        """Load ``model_name`` and check the wrapper keeps the requested settings."""
        wrapped = get_encoder(
            get_sbert_encoder(model_name), self.device, self.batch_size, self.verbose
        )
        self.assertIsInstance(wrapped, SBertEncoder)
        expected_settings = (
            ("device", self.device),
            ("batch_size", self.batch_size),
            ("verbose", self.verbose),
        )
        for attr_name, expected_value in expected_settings:
            self.assertEqual(getattr(wrapped, attr_name), expected_value)

    def test_get_sbert_encoder(self):
        self._base_test("stsb-roberta-large")

    def test_sbert_model(self):
        self._base_test("all-mpnet-base-v2")

    def test_huggingface_model(self):
        """Test Huggingface models which work with SBert library"""
        self._base_test("roberta-base")

    def test_get_encoder_environment_error(self):
        # An invalid model name must surface as EnvironmentError.
        with self.assertRaises(EnvironmentError):
            get_sbert_encoder("abc")

    def test_get_encoder_other_exception(self):
        # A model the SentenceTransformer library cannot load must surface as RuntimeError.
        with self.assertRaises(RuntimeError):
            get_sbert_encoder("apple/OpenELM-270M")
class TestRankedGainsDataclass(unittest.TestCase):
    """``RankedGains`` must expose its positional constructor arguments unchanged."""

    def test_ranked_gains_dataclass(self):
        field_values = {
            "gt_gains": [("doc1", 0.8), ("doc2", 0.6)],
            "pred_gains": [("doc2", 0.7), ("doc1", 0.5)],
            "k": 2,
            "ncg": 0.75,
        }
        instance = RankedGains(
            field_values["gt_gains"],
            field_values["pred_gains"],
            field_values["k"],
            field_values["ncg"],
        )
        for attr_name, expected_value in field_values.items():
            self.assertEqual(getattr(instance, attr_name), expected_value)
class TestComputeCosineSimilarity(unittest.TestCase):
    """Tests for ``compute_cosine_similarity``."""

    def test_compute_cosine_similarity(self):
        """One score per document row, compared against hand-computed values."""
        doc_embeds = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
        ref_embeds = np.array([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]])
        similarity_scores = compute_cosine_similarity(doc_embeds, ref_embeds)
        # Hand-computed pairwise cosines:
        #   doc0 vs (ref0, ref1) = (0.9926, 0.9683)
        #   doc1 vs (ref0, ref1) = (0.9946, 0.9997)
        # The expected values match the per-document-row mean of these
        # (0.9805 and 0.9971) to 3 decimal places.
        expected_scores = [0.980, 0.997]
        self.assertEqual(len(similarity_scores), len(expected_scores))
        for score, expected in zip(similarity_scores, expected_scores):
            self.assertAlmostEqual(score, expected, places=3)
class TestComputeGain(unittest.TestCase):
    """Tests for ``compute_gain``."""

    def test_compute_gain(self):
        """Gains come back as (index, gain) pairs ordered by descending score."""
        sim_scores = [0.8, 0.6, 0.7]
        gains = compute_gain(sim_scores)
        # The expected values are consistent with rank-based gains: with
        # n = 3 scores the best gets 3, the next 2 and the last 1, each
        # normalised by 3 + 2 + 1 = 6.
        expected_gains = [(0, 3 / 6), (2, 2 / 6), (1, 1 / 6)]
        self.assertEqual(len(gains), len(expected_gains))
        for (index, gain), (expected_index, expected_gain) in zip(gains, expected_gains):
            self.assertEqual(index, expected_index)
            # Compare floats with a tolerance instead of exact equality.
            self.assertAlmostEqual(gain, expected_gain, places=9)
class TestScoreNcg(unittest.TestCase):
    """Tests for ``score_ncg``."""

    def test_score_ncg(self):
        gt_relevance = [1.0, 0.9, 0.8]
        model_relevance = [0.8, 0.7, 0.6]
        # 0.778 is consistent with sum(model) / sum(gt) = 2.1 / 2.7; presumably
        # that is the definition -- confirm against score_ncg.
        self.assertAlmostEqual(
            score_ncg(model_relevance, gt_relevance), 0.778, places=3
        )
class TestComputeNcg(unittest.TestCase):
    """Tests for ``compute_ncg``."""

    def test_compute_ncg(self):
        k = 3
        gt_gains = [(0, 1.0), (1, 0.9), (2, 0.8)]
        pred_gains = [(0, 0.8), (2, 0.7), (1, 0.6)]
        # k covers every item, so the predicted top-k and the ideal top-k hold
        # the same indices; that is presumably why a perfect 1.0 is expected.
        # TODO: Confirm this with Dr. Santu
        self.assertAlmostEqual(compute_ncg(pred_gains, gt_gains, k), 1.0, places=6)
class TestValidateInputFormat(unittest.TestCase):
    """Tests for ``_validate_input_format``."""

    def test_validate_input_format(self):
        """With tokenize_sentences=True, flat string lists pass and nested lists raise."""
        tokenize_sentences = True
        predictions = ["Prediction 1", "Prediction 2"]
        references = ["Reference 1", "Reference 2"]
        documents = ["Document 1", "Document 2"]
        try:
            _validate_input_format(tokenize_sentences, predictions, references, documents)
        except ValueError as e:
            self.fail(f"_validate_input_format raised ValueError unexpectedly: {str(e)}")

        nested_predictions = [["Sentence 1 in prediction 1.", "Sentence 2 in prediction 1."],
                              ["Sentence 1 in prediction 2.", "Sentence 2 in prediction 2."]]
        nested_references = [["Sentences in reference 1."], ["Sentences in reference 2."]]
        nested_documents = [["Sentence 1 in document 1.", "Sentence 2 in document 1."],
                            ["Sentence 1 in document 2.", "Sentence 2 in document 2."]]
        # Swap in one pre-split (nested) argument at a time.
        bad_argument_sets = (
            (nested_predictions, references, documents),
            (predictions, nested_references, documents),
            (predictions, references, nested_documents),
        )
        for preds, refs, docs in bad_argument_sets:
            with self.assertRaises(ValueError):
                _validate_input_format(tokenize_sentences, preds, refs, docs)
class TestSemNCG(unittest.TestCase):
    """End-to-end tests for the ``SemNCG`` metric."""

    def setUp(self):
        self.model_name = "stsb-distilbert-base"
        self.metric = SemNCG(self.model_name)

    def _basic_assertion(self, result, debug: bool = False):
        """Check the ``(score, per_sample)`` structure returned by ``compute``.

        ``result[0]`` must be a float in [0, 1] and ``result[1]`` a list. With
        ``debug`` its items must be ``RankedGains`` whose ``ncg`` is in [0, 1];
        otherwise they must be floats in [0, 1].
        """
        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 2)
        score, per_sample = result
        self.assertIsInstance(score, float)
        self.assertTrue(0.0 <= score <= 1.0)
        self.assertIsInstance(per_sample, list)
        for item in per_sample:
            if debug:
                self.assertTrue(isinstance(item, RankedGains))
                self.assertTrue(0.0 <= item.ncg <= 1.0)
            else:
                self.assertTrue(isinstance(item, float))
                self.assertTrue(0.0 <= item <= 1.0)

    def test_compute_basic(self):
        predictions = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
        references = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
        documents = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]
        result = self.metric.compute(predictions=predictions, references=references, documents=documents)
        self._basic_assertion(result)

    def test_compute_with_tokenization(self):
        """Inputs are already split into sentences, so tokenize_sentences=False."""
        predictions = [["The cat sat on the mat."], ["The quick brown fox jumps over the lazy dog."]]
        references = [["A cat was sitting on a mat."], ["A quick brown fox jumped over a lazy dog."]]
        documents = [["There was a cat on a mat."], ["The quick brown fox jumped over the lazy dog."]]
        result = self.metric.compute(
            predictions=predictions, references=references, documents=documents, tokenize_sentences=False
        )
        self._basic_assertion(result)

    def test_compute_with_pre_compute_embeddings(self):
        predictions = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
        references = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
        documents = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]
        result = self.metric.compute(
            predictions=predictions, references=references, documents=documents, pre_compute_embeddings=True
        )
        self._basic_assertion(result)

    def test_compute_with_debug(self):
        predictions = ["The cat sat on the mat.", "The quick brown fox jumps over the lazy dog."]
        references = ["A cat was sitting on a mat.", "A quick brown fox jumped over a lazy dog."]
        documents = ["There was a cat on a mat.", "The quick brown fox jumped over the lazy dog."]
        result = self.metric.compute(
            predictions=predictions, references=references, documents=documents, debug=True
        )
        self._basic_assertion(result, debug=True)

    def test_compute_invalid_input_format(self):
        # A bare string instead of a list of predictions is rejected.
        predictions = "The cat sat on the mat."
        references = ["A cat was sitting on a mat."]
        documents = ["There was a cat on a mat."]
        with self.assertRaises(ValueError):
            self.metric.compute(predictions=predictions, references=references, documents=documents)

    def test_bad_inputs(self):
        """Malformed inputs (None entries, empty list, empty string) must raise.

        Only the fact that *some* exception is raised is checked; the concrete
        exception types are not pinned here. Each case runs in its own
        ``subTest``, so one case that unexpectedly succeeds does not hide the
        results of the others.
        """
        ref = "A cat was sitting on a mat."
        doc = "There was a cat on a mat."
        cases = [
            # (label, tokenize_sentences, predictions, references, documents)
            ("Case I", True, [None], [ref], [doc]),
            ("Case II", False, [[ref, None]], [[ref, ref]], [[doc, doc]]),
            ("Case: Empty Input", True, [], [ref], [doc]),
            ("Case: Empty String Input", True, [""], [ref], [doc]),
        ]
        for label, tokenize_sentences, predictions, references, documents in cases:
            with self.subTest(label):
                with self.assertRaises(Exception) as ctx:
                    self.metric.compute(
                        predictions=predictions,
                        references=references,
                        documents=documents,
                        tokenize_sentences=tokenize_sentences,
                        pre_compute_embeddings=True,
                    )
                # Print the message so the error text can be reviewed in verbose runs.
                print(f"{label}\nRaised Exception with message: {ctx.exception}\n")

    def _test_check_verbose(self):
        """UNUSED: previously used to manually check the progress bar.

        This test should not be used because it relies on files that are not
        kept in version control. It is left here purely for historical
        purposes; the leading '_' in its name keeps unittest from collecting it.
        """
        import sqlite3
        import string
        con = sqlite3.connect('sem_ncg_samples.db')
        try:
            rows = con.cursor().execute(
                'SELECT * FROM sem_ncg_samples').fetchmany(100)
        finally:
            # Close the connection even if the query fails.
            con.close()
        # Drop rows whose first column is empty once punctuation is removed.
        strip_punctuation = str.maketrans('', '', string.punctuation)
        data = [row for row in rows if row[0].translate(strip_punctuation).strip() != '']
        preds, refs, docs = list(zip(*data))
        result = self.metric.compute(
            predictions=preds, references=refs,
            documents=docs, verbose=True,
            gpu=2
        )
        breakpoint()  # intentional: inspect `result` interactively
if __name__ == "__main__":
    # Run every TestCase in this module, reporting each test by name.
    unittest.main(verbosity=2)