"""Tag a sample of grants with MeSH terms using the WellcomeBertMesh model."""

import json
from itertools import islice

import srsly
import typer
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


def load_data(data_path, sample_size):
    """Load and return an entire JSON document from *data_path*.

    NOTE(review): ``sample_size`` is accepted but never used, and this helper
    is not called by ``tag`` — kept unchanged for backward compatibility.
    """
    with open(data_path) as f:
        return json.loads(f.read())


def tag(
    data_path: str,
    tagged_data_path: str,
    sample_size: int = 10,
    batch_size: int = 10,
):
    """Tag up to *sample_size* grants from a JSONL file with MeSH labels.

    Args:
        data_path: Path to a JSONL file whose records contain a
            ``title_and_description`` field.
        tagged_data_path: Path where the tagged JSONL output is written.
        sample_size: Maximum number of records to read and tag.
        batch_size: Number of texts sent to the model per forward pass.
    """
    # islice stops cleanly if the file holds fewer than sample_size records,
    # unlike repeated next(), which would raise StopIteration.
    data = list(islice(srsly.read_jsonl(data_path), sample_size))

    tokenizer = AutoTokenizer.from_pretrained("Wellcome/WellcomeBertMesh")
    model = AutoModel.from_pretrained(
        "Wellcome/WellcomeBertMesh", trust_remote_code=True
    )

    texts = [grant["title_and_description"] for grant in data]

    for batch_index in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[batch_index : batch_index + batch_size]
        # return_tensors="pt" is required for the model forward pass (plain
        # Python lists are not accepted); truncation guards against texts
        # longer than the model's maximum sequence length.
        inputs = tokenizer(
            batch_texts,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        labels = model(**inputs, return_labels=True)
        for i, tags in enumerate(labels):
            data[batch_index + i]["tags"] = tags

    srsly.write_jsonl(tagged_data_path, data)


if __name__ == "__main__":
    typer.run(tag)