MediFlow / README.md
eperezs's picture
Update README.md
8878172 verified
metadata
library_name: transformers
tags:
  - medical
license: cc-by-nc-nd-4.0
language:
  - en
  - es
metrics:
  - accuracy
  - f1
  - precision
base_model:
  - xlnet/xlnet-large-cased

MediFlow

MediFlow se trata de un modelo inicializado con xlnet-large-cased y adaptado con preguntas y especialidades para poder realizar Derivaciones Automatizadas en Servicios Hospitalarios. El dataset se puede encontrar de manera p煤blica y se trata de MedDialog EN.

Este modelo toma como input una descripci贸n, en ingl茅s, prove铆da por el paciente y devuelve las siguientes especialidades (model.config.label2id): Cardiology, Traumatology, Mental Health y Pneumology. Se puede encontrar m谩s informaci贸n del modelo aqu铆.

Para el entrenamiento de este modelo hemos seguidos los est谩ndares de la librer铆a transformers y se ha utilizado una NVIDIA P100. Adem谩s, ha sido entrenado con un batch-size de 4, un learning rate de 2e-5, X epochs y un weigth decay de 0.015, loggeando los resultados cada 100 iteraciones.

Utilizaci贸n

Mediante el pipeline de Hugging Face:

from transformers import pipeline

model_id = "digitalhealth-healthyliving/MediFlow"
pipe = pipeline("text-classification", model_id)

text = "I have pain in the back"
result = pipe(text)

print(result)

Mediante AutoModelForSequenceClassification y AutoTokenizer:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "digitalhealth-healthyliving/MediFlow"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

text = "I have pain in the back"
inputs = tokenizer(text, return_tensors = "pt")
logits = model(**inputs)

print(f"The predicted class is {model.id2label[logits.argmax()]}")

print(result)

Evaluaci贸n

  • Accuracy : 89,3%
  • F1: 89,4%
  • Precision: 90%

Training Hyperparameters

  • learning_rate: 2e-5
  • batch_size: 4
  • num_train_epochs: 3
  • weight_decay: 0.015
  • optimizer: AdamW
  • test_size: 0.2
  • logging_steps: 100