File size: 922 Bytes
c88f06c
 
 
c8709eb
769cc58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8709eb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
---
license: openrail
---
```python
import torch
from transformers import AutoTokenizer, MobileBertForSequenceClassification

# Load the fine-tuned model and its tokenizer from the Hugging Face Hub.
model_name = 'harshith20/Emotion_predictor'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MobileBertForSequenceClassification.from_pretrained(model_name)

# Class names, indexed by the position of the winning logit.
# NOTE(review): assumed to match the label order used during fine-tuning;
# if the uploaded config defines meaningful labels, model.config.id2label
# would be the authoritative mapping — confirm.
emotion_labels = ['anger', 'fear', 'joy', 'love', 'sadness', 'surprise']

# Tokenize input text. Calling the tokenizer directly replaces the
# deprecated `encode_plus`; it returns input_ids and attention_mask tensors.
input_text = "I am feeling happy today"
encoded_text = tokenizer(
    input_text,
    max_length=128,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt',
)

# Predict emotion; gradients are not needed for inference.
with torch.no_grad():
    # Per the transformers docs, logits are (batch_size, num_labels).
    logits = model(**encoded_text).logits
    # Take argmax over the label axis so each row gets its own prediction.
    # A bare argmax flattens the tensor and would return an out-of-range
    # index as soon as more than one sentence is batched.
    predicted_emotion = torch.argmax(logits, dim=-1)[0].item()
    predicted_emotion_label = emotion_labels[predicted_emotion]
print(f"Predicted emotion: {predicted_emotion_label}")
```