
An ONNX model: a fine-tuned version of DistilBERT that classifies text into one of four classes:

  • neutral
  • offensive_language
  • harmful_behaviour
  • hate_speech

The model was trained with the csfy tool on the dataset seanius/toxic-or-neutral-text-labelled.

The base model, distilbert-base-uncased, is also required: its tokenizer is used to encode the input text.

For an example of how to run the model, see the Usage section below or the csfy tool.

The model outputs one score per class. The index of the highest score is the predicted class, which is decoded into a label name via the label_mapping.json file.
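As a sketch of that decoding step, the snippet below turns one row of class scores into a label plus a softmax confidence. It assumes label_mapping.json has the shape `{"labels": [...]}` (which is how the Usage code reads it) and that the scores are raw logits. The label order in `EXAMPLE_MAPPING` is hypothetical and may differ from the shipped file, and `decode_scores` is an illustrative helper, not part of the csfy tool.

```python
import json
import math

# Hypothetical contents of label_mapping.json; the real file's label order may differ.
EXAMPLE_MAPPING = '{"labels": ["neutral", "offensive_language", "harmful_behaviour", "hate_speech"]}'


def decode_scores(scores, labels):
    """Return (label, softmax probability) for one row of per-class scores.

    Assumes `scores` are raw logits and `labels[i]` names the class at index i.
    """
    # Softmax: subtract the max score first for numerical stability
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    # The predicted class is the index of the highest score (equivalent to argmax)
    index = max(range(len(scores)), key=lambda i: scores[i])
    return labels[index], probs[index]


labels = json.loads(EXAMPLE_MAPPING)["labels"]
label, confidence = decode_scores([0.2, 2.5, -1.0, 0.3], labels)
print(label, round(confidence, 3))
```

The argmax alone is enough to pick the label; the softmax value is only useful if you want a rough confidence to threshold on.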

Usage

# Loading the label mappings
import json

def load_label_mappings(path="./label_mapping.json"):
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    # 'labels' lists the class names in the order of the model's output scores
    return data['labels']

label_mappings = load_label_mappings()

# Loading the model
import numpy as np
import onnxruntime as ort
from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
ort_session = ort.InferenceSession("./toxic-or-neutral-text-labelled.onnx")

# Predicting label for given text
def predict_via_onnx(text, ort_session, tokenizer, label_mappings):
    model_input = ort_session.get_inputs()[0]
    model_expected_input_shape = model_input.shape
    print("Model expects input shape:", model_expected_input_shape)

    # A dynamic sequence axis is reported as a name (str) or None rather than an int;
    # in that case fall back to the tokenizer's maximum length
    max_length = model_expected_input_shape[1]
    if not isinstance(max_length, int):
        max_length = tokenizer.model_max_length

    # Pad/truncate to the sequence length the ONNX graph expects
    inputs = tokenizer(text, return_tensors="np", padding="max_length",
                       truncation=True, max_length=max_length)
    print("input shape", inputs['input_ids'].shape)

    # ONNX Runtime needs int64 token IDs with a leading batch dimension
    input_ids = inputs['input_ids'].astype(np.int64)
    if input_ids.ndim == 1:
        input_ids = input_ids[np.newaxis, :]
    ort_inputs = {model_input.name: input_ids}

    # run() returns a list of outputs; the first holds one score per class
    ort_outputs = ort_session.run(None, ort_inputs)
    predicted_index = np.argmax(ort_outputs[0], axis=-1).item()

    return label_mappings[predicted_index]

predicted_label = predict_via_onnx("How do I get to the beach?", ort_session, tokenizer, label_mappings)
print(predicted_label)
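The prediction function pads or truncates every input to the sequence length declared by the first input of the ONNX graph (`shape[1]`). The simplified sketch below mimics what `padding="max_length", truncation=True` produces for a single sequence: it keeps room for DistilBERT's `[CLS]` and `[SEP]` tokens and fills the remainder with padding. The helper `pad_or_truncate` is hypothetical and not the tokenizer's real implementation; it takes WordPiece token IDs that have already been produced, and the default IDs 101, 102 and 0 assume the standard distilbert-base-uncased vocabulary.

```python
def pad_or_truncate(token_ids, max_length, cls_id=101, sep_id=102, pad_id=0):
    """Hypothetical sketch: frame token IDs with [CLS]/[SEP] at a fixed length.

    Returns (input_ids, attention_mask), both exactly `max_length` long.
    Assumes max_length >= 2 and bert-base-uncased style special-token IDs.
    """
    # Reserve two positions for the [CLS] and [SEP] special tokens
    body = list(token_ids[: max_length - 2])
    ids = [cls_id] + body + [sep_id]
    # 1 marks a real token, 0 marks padding
    attention_mask = [1] * len(ids)
    padding = max_length - len(ids)
    ids += [pad_id] * padding
    attention_mask += [0] * padding
    return ids, attention_mask


short_ids, short_mask = pad_or_truncate([11, 12], 6)      # padded up to 6
long_ids, long_mask = pad_or_truncate([1, 2, 3, 4, 5], 4)  # truncated down to 4
print(short_ids, short_mask)
print(long_ids, long_mask)
```

The Usage example only feeds `input_ids` to the session; the attention mask is shown here because the tokenizer also returns it alongside the IDs.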

License: MIT
