---
library_name: transformers
license: apache-2.0
datasets:
- jaeyong2/Thai-emb-PreView
language:
- th
base_model:
- Alibaba-NLP/gte-multilingual-base
---

# Model Card for jaeyong2/gte-multilingual-base-Thai-embedding

## Model Details

This is a Thai text-embedding model fine-tuned from
[Alibaba-NLP/gte-multilingual-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-base)
on [jaeyong2/Thai-emb-PreView](https://huggingface.co/datasets/jaeyong2/Thai-emb-PreView).
It is trained with a triplet objective: each context (anchor) is pulled toward its
true title (positive) and pushed away from a fake title (negative). The embedding
of a text is the last hidden state of its first token.

## Training

- Hardware: Google Colab, A100 (40 GB)
- Data: jaeyong2/Thai-emb-PreView, `train` split (columns `context`, `Title`, `Fake Title`)
- Objective: triplet margin loss (margin 1.0)
- Optimizer: AdamW, learning rate 5e-5
- Batch size: 8
- Epochs: 3
- Weights: bfloat16

```python
import datasets
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "Alibaba-NLP/gte-multilingual-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# gte-multilingual-base ships custom modeling code, so trust_remote_code is required.
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(torch.bfloat16)
# Move the model before building the optimizer so the optimizer tracks the device parameters.
model.to(device)

dataset = datasets.load_dataset("jaeyong2/Thai-emb-PreView")
train_dataloader = DataLoader(dataset["train"], batch_size=8, shuffle=True)

# Euclidean triplet loss, averaged over the batch: max(d(a, p) - d(a, n) + margin, 0).
triplet_loss = torch.nn.TripletMarginLoss(margin=1.0)
optimizer = AdamW(model.parameters(), lr=5e-5)


def encode(texts, max_length):
    """Return the first-token embedding for each string in `texts`."""
    # padding=True pads only to the longest text in the batch, not to max_length.
    encodings = tokenizer(
        texts, truncation=True, padding=True, max_length=max_length, return_tensors="pt"
    ).to(device)
    return model(**encodings)[0][:, 0, :]


for epoch in range(3):
    model.train()
    total_loss = 0.0
    count = 0
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        anchor_output = encode(batch["context"], max_length=4096)
        positive_output = encode(batch["Title"], max_length=256)
        negative_output = encode(batch["Fake Title"], max_length=256)
        # TripletMarginLoss averages over the batch, so one call covers every example.
        loss = triplet_loss(anchor_output, positive_output, negative_output)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        count += 1
    print(f"epoch {epoch + 1}: mean loss {total_loss / count:.4f}")
```

## Evaluation

The metric is accuracy on the first 1,000 examples of the `test` split. An example
counts as correct when its context embedding is closer, by cosine distance, to the
true title than to the fake title.

Set `model_name` to the model you want to score.

```python
import datasets
import torch
from sklearn.metrics import pairwise_distances
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "jaeyong2/gte-multilingual-base-Thai-embedding"  # or "Alibaba-NLP/gte-multilingual-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(torch.bfloat16)
model.to(device)
model.eval()

dataset = datasets.load_dataset("jaeyong2/Thai-emb-PreView")
validation_dataset = dataset["test"].select(range(1000))


def get_embedding(text, model, tokenizer):
    """Return the first-token embedding of `text` as a float32 NumPy array of shape (1, hidden)."""
    encodings = tokenizer([text], truncation=True, max_length=4096, return_tensors="pt").to(device)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        output = model(**encodings)[0][:, 0, :]
    return output.float().cpu().numpy()


def evaluate(validation_dataset):
    """Return the fraction of items whose context is closer to its title than to its fake title."""
    correct_count = 0
    for item in tqdm(validation_dataset):
        query_embedding = get_embedding(item["context"], model, tokenizer)
        document_embedding = get_embedding(item["Title"], model, tokenizer)
        negative_embedding = get_embedding(item["Fake Title"], model, tokenizer)
        # Each embedding is a single row, so each distance matrix is 1x1; take the scalar.
        positive_distance = pairwise_distances(query_embedding, document_embedding, metric="cosine")[0, 0]
        negative_distance = pairwise_distances(query_embedding, negative_embedding, metric="cosine")[0, 0]
        if positive_distance < negative_distance:
            correct_count += 1
    return correct_count / len(validation_dataset)


results = evaluate(validation_dataset)
print(f"Validation Results: {results}")
```

### Accuracy

| Model | Accuracy |
|---|---|
| Alibaba-NLP/gte-multilingual-base | 0.953 |
| jaeyong2/gte-multilingual-base-Thai-embedding | 0.991 |

## License

- Alibaba-NLP/gte-multilingual-base: [Apache-2.0](https://choosealicense.com/licenses/apache-2.0/)