---
library_name: transformers
tags: []
---

# **TriageTron-large**

TriageTron-large is a model initialized from [xlnet-large-cased](https://huggingface.co/xlnet/xlnet-large-cased) and fine-tuned on pairs of patient questions and medical specialties to automate referrals in hospital services. The dataset is publicly available: [MedDialog EN](https://arxiv.org/abs/2004.03329).

The model was trained following the standard workflow of the [transformers](https://github.com/huggingface/transformers) library on an NVIDIA P100, with a batch size of 4, a learning rate of 2e-5, 3 epochs, and a weight decay of 0.015, logging results every 100 iterations.

## **Usage**

Using the Hugging Face `pipeline`:

```python
from transformers import pipeline

model_id = "eperezs/TriageTron-large"
# Replace the placeholder with your Hugging Face access token.
pipe = pipeline("text-classification", model=model_id, token="<your_token_here>")

text = "I have pain in the back"
result = pipe(text)

print(result)
```

Using `AutoModelForSequenceClassification` and `AutoTokenizer`:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "eperezs/TriageTron-large"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

text = "I have pain in the back"
inputs = tokenizer(text, return_tensors="pt")

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    logits = model(**inputs).logits

# id2label lives on the model config; argmax returns a tensor, hence .item().
predicted_id = logits.argmax(dim=-1).item()
print(f"The predicted class is {model.config.id2label[predicted_id]}")
```
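
The manual route yields raw logits, whereas the `pipeline` also reports a `score`, which for single-label classifiers is by default a softmax over those logits. The sketch below shows that conversion in plain Python. The logits and the `id2label` mapping are made-up values for illustration, not outputs or labels of this model.

```python
import math

def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    # Subtract the max for numerical stability before exponentiating.
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical values for illustration only.
id2label = {0: "cardiology", 1: "neurology", 2: "orthopedics"}
logits = [0.3, -1.2, 2.1]

probs = softmax(logits)
# Pair each label with its probability and sort from most to least likely.
ranked = sorted(
    ((id2label[i], p) for i, p in enumerate(probs)),
    key=lambda pair: pair[1],
    reverse=True,
)
for label, prob in ranked:
    print(f"{label}: {prob:.3f}")
```

With PyTorch tensors, `torch.softmax(logits, dim=-1)` performs the same computation.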

## **Evaluation**

- **Accuracy**: 89.3%
- **F1**: 89.4%
- **Precision**: 90%
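
For readers who want to reproduce metrics of this kind on their own predictions, here is a minimal sketch in plain Python. This card does not state which averaging was used for F1 and precision, so macro averaging is an assumption here, and the specialty labels are hypothetical.

```python
def classification_metrics(y_true, y_pred):
    """Return accuracy plus macro-averaged precision and F1."""
    labels = sorted(set(y_true) | set(y_pred))
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)

    precisions, f1s = [], []
    for label in labels:
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        precisions.append(precision)
        f1s.append(f1)

    # Macro averaging (an assumption): every class counts equally.
    n = len(labels)
    return {
        "accuracy": accuracy,
        "precision": sum(precisions) / n,
        "f1": sum(f1s) / n,
    }

# Hypothetical gold labels and predictions, for illustration only.
y_true = ["cardiology", "cardiology", "neurology", "neurology"]
y_pred = ["cardiology", "neurology", "neurology", "neurology"]

metrics = classification_metrics(y_true, y_pred)
print(metrics)
```

If the reported figures were weighted or micro averaged instead, only the final aggregation step would change.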

## **Training Hyperparameters**

- **learning_rate**: 2e-5
- **batch_size**: 4
- **num_train_epochs**: 3
- **weight_decay**: 0.015
- **optimizer**: AdamW
- **test_size**: 0.2
- **logging_steps**: 100
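
To make the interplay of these values concrete, the sketch below collects them under `TrainingArguments` argument names and estimates how many optimizer steps and log entries a run would produce. Mapping `batch_size` to `per_device_train_batch_size` is an assumption (single GPU, no gradient accumulation), and the training-set size of 8000 is a hypothetical figure, not the size of the actual split.

```python
# Values from the list above, keyed by TrainingArguments argument names.
# per_device_train_batch_size is an assumed mapping for batch_size.
training_kwargs = {
    "learning_rate": 2e-5,
    "per_device_train_batch_size": 4,
    "num_train_epochs": 3,
    "weight_decay": 0.015,
    "logging_steps": 100,
}

def schedule(num_train_examples, kwargs=training_kwargs):
    """Estimate steps per epoch, total steps, and number of log entries."""
    batch_size = kwargs["per_device_train_batch_size"]
    # Ceiling division: the last, possibly smaller, batch still counts as a step.
    steps_per_epoch = -(-num_train_examples // batch_size)
    total_steps = steps_per_epoch * kwargs["num_train_epochs"]
    log_events = total_steps // kwargs["logging_steps"]
    return steps_per_epoch, total_steps, log_events

# Hypothetical training-set size, for illustration only.
steps_per_epoch, total_steps, log_events = schedule(8000)
print(f"{steps_per_epoch} steps/epoch, {total_steps} steps, {log_events} log entries")
```

The same dictionary could be unpacked into `transformers.TrainingArguments(output_dir="out", **training_kwargs)`, where `output_dir="out"` is a placeholder. AdamW is the `Trainer` default optimizer, so it needs no extra argument.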