RoBERTa for German Text Classification
This is an XLM-RoBERTa model fine-tuned on a German discourse dataset of 60 discourses, containing a total of over 10k sentences.
Understanding the labels
Externalization: Emphasizes situational factors beyond the speaker's control as the cause of behavior. For example: "I had a really tough day at work, and then when I got home, my cat got sick. It's just been one thing after another, and it's really getting to me."
Elicitation: Emphasizes the role of the listener by asking questions or providing prompts. For example: "Can you tell me more about what it feels like when you're anxious?"
Conflict: Attributes the other person's behavior to dispositional factors (such as being short-sighted or inflexible). For example: "You're not thinking about the big picture here!"
Acceptance: Accepts the perspectives or experiences of others. For example: "It sounds like you had a really hard day."
Integration: Combines multiple perspectives to build a more comprehensive understanding of the behavior of others. For example: "What if we combined elements of both proposals to create something that incorporates the best of both worlds?"
In addition to these five labels, the classifier has a sixth output class, None.
How to use the model
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = 'RashidNLP/German-Text-Classification'
MODEL_DIR = 'model'
CHECKPOINT_DIR = 'checkpoints'
OUTPUTS = 6  # five discourse labels plus 'None'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the fine-tuned classifier and its tokenizer from the Hugging Face Hub
bert_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=OUTPUTS).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Label names, in the order of the model's output indices
keys = ['Externalization', 'Elicitation', 'Conflict', 'Acceptance', 'Integration', 'None']

def get_label(sentence):
    """Return a dict mapping each label to its predicted probability (rounded to 3 decimals)."""
    vectors = tokenizer(sentence, return_tensors='pt', truncation=True).to(device)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = bert_model(**vectors).logits
    probs = torch.nn.functional.softmax(outputs, dim=1)[0]
    return {key: round(probs[i].item(), 3) for i, key in enumerate(keys)}

print(get_label("Gehst du zum Oktoberfest?"))
```
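`get_label` returns a probability for every label, one sentence at a time. Because the training data is organized as discourses made up of many sentences, it can be useful to reduce each dictionary to its single most likely label and tally those labels across a whole discourse. The minimal sketch below assumes only the `{label: probability}` format returned by `get_label`; the helper names `top_label` and `label_distribution` are hypothetical and are not part of the model or of the `transformers` library.

```python
from collections import Counter
from typing import Dict, Iterable, Tuple

# Hypothetical helpers (not part of the model or of transformers).
# They assume the dict format returned by get_label(): {label: probability}.


def top_label(scores: Dict[str, float]) -> Tuple[str, float]:
    """Return the (label, probability) pair with the highest probability.

    On ties, max() keeps the first label in the dict's insertion order.
    """
    label = max(scores, key=scores.get)
    return label, scores[label]


def label_distribution(all_scores: Iterable[Dict[str, float]]) -> Counter:
    """Count how often each label is the top prediction across many sentences."""
    return Counter(top_label(scores)[0] for scores in all_scores)
```

With the model loaded as in the usage code, a discourse given as a list of German sentences could then be summarized with `label_distribution(get_label(s) for s in sentences)`.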