import onnxruntime as ort
from transformers import AutoTokenizer
import gradio as gr

# Define available models with their ONNX file paths and tokenizer names
models = {
    "DistilBERT": {
        "onnx_model_path": "distilbert.onnx",
        "tokenizer_name": "distilbert-base-multilingual-cased",
    },
    "BERT": {
        "onnx_model_path": "bert.onnx",
        "tokenizer_name": "bert-base-multilingual-cased",
    },
    "MuRIL": {
        "onnx_model_path": "muril.onnx",
        "tokenizer_name": "google/muril-base-cased",
    },
    "RoBERTa": {
        "onnx_model_path": "roberta.onnx",
        "tokenizer_name": "cardiffnlp/twitter-roberta-base-emotion",
    },
}
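
# The .onnx files referenced above are assumed to already exist alongside
# this script. One way to produce them (a hedged sketch, not part of this
# app) is Hugging Face Optimum's export CLI, e.g. for DistilBERT:
#
#   optimum-cli export onnx --model distilbert-base-multilingual-cased distilbert_onnx/
#
# then move/rename the exported model.onnx to match onnx_model_path.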

# Load models and tokenizers into memory
model_sessions = {}
tokenizers = {}

for model_name, config in models.items():
    print(f"Loading {model_name}...")
    model_sessions[model_name] = ort.InferenceSession(config["onnx_model_path"])
    tokenizers[model_name] = AutoTokenizer.from_pretrained(config["tokenizer_name"])

print("All models loaded!")

# Prediction function
def predict_with_model(text, model_name):
    # Select the appropriate ONNX session and tokenizer
    ort_session = model_sessions[model_name]
    tokenizer = tokenizers[model_name]

    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True)

    # Run ONNX inference; build the feed dict from the session's declared
    # inputs so models whose exports also expect token_type_ids
    # (e.g. BERT, MuRIL) work alongside those that do not.
    ort_inputs = {
        inp.name: inputs[inp.name]
        for inp in ort_session.get_inputs()
        if inp.name in inputs
    }
    outputs = ort_session.run(None, ort_inputs)

    # Post-process the output (assumes a binary classification head where
    # index 1 is the "hate speech" logit and index 0 is "not hate speech")
    logits = outputs[0]
    label = "Hate Speech" if logits[0][1] > logits[0][0] else "Not Hate Speech"
    return label
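
# Minimal sketch of calling the predictor directly, outside Gradio
# (assumes the ONNX files above exist locally; the sample text is a
# hypothetical placeholder, and this function is never called by the app).
def smoke_test():
    sample = "This is a harmless example sentence."
    for name in models:
        print(f"{name}: {predict_with_model(sample, name)}")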

# Define Gradio interface
interface = gr.Interface(
    fn=predict_with_model,
    inputs=[
        gr.Textbox(label="Enter text to classify"),
        gr.Dropdown(
            choices=list(models.keys()),
            label="Select a model",
        ),
    ],
    outputs="text",
    title="Multi-Model Hate Speech Detection",
    description="Choose a model and enter text to classify whether it's hate speech.",
)

# Launch the app
if __name__ == "__main__":
    interface.launch()