Overview
This repository contains the bert_base_uncased_rxnorm_babbage model. It is a bert-base-uncased model continually pretrained with masked language modeling on drugs, diseases, and their relationships from RxNorm. We hypothesize that this augmentation can improve the model's understanding of medical terminology and context.
The training corpus contains approximately 8.8 million tokens, synthesized from drug and disease relations harvested from RxNorm. A few examples are shown below.
ferrous fumarate 191 MG is contraindicated with Hemochromatosis.
24 HR metoprolol succinate 50 MG Extended Release Oral Capsule [Kapspargo] contains the ingredient Metoprolol.
Genvoya has the established pharmacologic class Cytochrome P450 3A Inhibitor.
cefprozil 250 MG Oral Tablet may be used to treat Haemophilus Infections.
mecobalamin 1 MG Sublingual Tablet contains the ingredient Vitamin B 12.
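The card does not describe how the RxNorm relations were turned into sentences. The sketch below shows one plausible approach, a hypothetical template-based verbalizer. The relation keys (`may_treat`, `has_epc`, and so on) and the `verbalize` function are illustrative names, and the templates are read off the example sentences above rather than taken from the actual corpus-building code.

```python
# Hypothetical relation templates, inferred only from the example sentences above.
# The real corpus-generation templates and relation names are not documented here.
TEMPLATES = {
    "contraindicated_with": "{drug} is contraindicated with {target}.",
    "has_ingredient": "{drug} contains the ingredient {target}.",
    "has_epc": "{drug} has the established pharmacologic class {target}.",
    "may_treat": "{drug} may be used to treat {target}.",
}


def verbalize(triples):
    """Turn (drug, relation, target) triples into sentences.

    Triples whose relation has no template are skipped.
    """
    sentences = []
    for drug, relation, target in triples:
        template = TEMPLATES.get(relation)
        if template is not None:
            sentences.append(template.format(drug=drug, target=target))
    return sentences


if __name__ == "__main__":
    triples = [
        ("cefprozil 250 MG Oral Tablet", "may_treat", "Haemophilus Infections"),
        ("Genvoya", "has_epc", "Cytochrome P450 3A Inhibitor"),
    ]
    for sentence in verbalize(triples):
        print(sentence)
```

A template approach keeps every sentence grounded in a single RxNorm relation, which is consistent with the short, formulaic examples shown above.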
The dataset is hosted at this commit. Note that this is the babbage version of the corpus, which uses all drug and disease relations. Don't confuse it with the ada version, which uses only a fraction of the relationships (see the repo for more information).
Training
15% of the tokens were masked for prediction. The model was trained for 20 epochs on 4 NVIDIA A40 (48 GB) GPUs using Python 3.8 (we tried to match the dependencies specified in requirements.txt), with a batch size of 16 and a learning rate of 5e-5. See GitHub for more configuration details and WandB for training curves.
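To make the 15% masking step concrete, here is a minimal pure-Python sketch of BERT-style token masking. It assumes the standard BERT recipe (of the selected tokens, 80% become [MASK], 10% become a random token, 10% stay unchanged), which is also the default behavior of the Hugging Face `DataCollatorForLanguageModeling`; the card does not say which masking implementation was actually used. The `mask_tokens` helper is illustrative. The ids 103 ([MASK]), 0/101/102 ([PAD]/[CLS]/[SEP]) and the vocabulary size 30522 are those of bert-base-uncased.

```python
import random

MASK_ID = 103            # [MASK] in the bert-base-uncased vocabulary
SPECIAL_IDS = {0, 101, 102}  # [PAD], [CLS], [SEP]: never selected for masking
VOCAB_SIZE = 30522       # bert-base-uncased vocabulary size
IGNORE_INDEX = -100      # label value ignored by the MLM cross-entropy loss


def mask_tokens(token_ids, mlm_probability=0.15, rng=None):
    """Return (inputs, labels) for masked language modeling.

    Each non-special token is selected with probability mlm_probability.
    Selected tokens keep their original id as the label; all other labels
    are IGNORE_INDEX so they do not contribute to the loss.
    """
    rng = rng or random.Random()
    inputs = list(token_ids)
    labels = [IGNORE_INDEX] * len(inputs)
    for i, tok in enumerate(token_ids):
        if tok in SPECIAL_IDS or rng.random() >= mlm_probability:
            continue
        labels[i] = tok  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = MASK_ID                   # 80%: replace with [MASK]
        elif roll < 0.9:
            inputs[i] = rng.randrange(VOCAB_SIZE)  # 10%: random token
        # remaining 10%: keep the original token unchanged
    return inputs, labels


if __name__ == "__main__":
    ids = [101, 2023, 2003, 1037, 7099, 102]
    print(mask_tokens(ids, rng=random.Random(0)))
```

Keeping some selected tokens unchanged or randomized means the model cannot rely on [MASK] always marking the positions it must predict, which matters because [MASK] never appears at inference time on ordinary text.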
Usage
You can use this model for masked language modeling tasks to predict missing words in a given text. Below are the instructions and examples to get you started.
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
# load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("Su-informatics-lab/bert_base_uncased_rxnorm_babbage")
model = AutoModelForMaskedLM.from_pretrained("Su-informatics-lab/bert_base_uncased_rxnorm_babbage")
# prepare the input
text = "0.05 ML aflibercept 40 MG/ML Injection is contraindicated with [MASK]."
inputs = tokenizer(text, return_tensors="pt")
# get model predictions
with torch.no_grad():
    outputs = model(**inputs)
# decode the prediction at the [MASK] position
predictions = outputs.logits
masked_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
predicted_token = tokenizer.decode(predicted_token_id)
print(predicted_token)
License
Apache 2.0.