feat: add config file and inference script
Browse files- config.json +91 -0
- inference.py +93 -0
config.json
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"description": "Binary classifier on harmful text in Singapore context",
|
3 |
+
"embedding": {
|
4 |
+
"tokenizer": "BAAI/bge-large-en-v1.5",
|
5 |
+
"model": "BAAI/bge-large-en-v1.5",
|
6 |
+
"max_length": 512,
|
7 |
+
"batch_size": 32
|
8 |
+
},
|
9 |
+
"classifier": {
|
10 |
+
"binary": {
|
11 |
+
"calibrated": true,
|
12 |
+
"threshold": {
|
13 |
+
"high_recall": 0.2,
|
14 |
+
"balanced": 0.5,
|
15 |
+
"high_precision": 0.8
|
16 |
+
},
|
17 |
+
"model_type": "ridge_classifier",
|
18 |
+
"model_fp": "models/lionguard-binary.onnx"
|
19 |
+
},
|
20 |
+
"hateful": {
|
21 |
+
"calibrated": false,
|
22 |
+
"threshold": {
|
23 |
+
"high_recall": 0.516,
|
24 |
+
"balanced": 0.827,
|
25 |
+
"high_precision": 1.254
|
26 |
+
},
|
27 |
+
"model_type": "ridge_classifier",
|
28 |
+
"model_fp": "models/lionguard-hateful.onnx"
|
29 |
+
},
|
30 |
+
"harassment": {
|
31 |
+
"calibrated": false,
|
32 |
+
"threshold": {
|
33 |
+
"high_recall": 1.326,
|
34 |
+
"balanced": 1.326,
|
35 |
+
"high_precision": 1.955
|
36 |
+
},
|
37 |
+
"model_type": "ridge_classifier",
|
38 |
+
"model_fp": "models/lionguard-harassment.onnx"
|
39 |
+
},
|
40 |
+
"public_harm": {
|
41 |
+
"calibrated": false,
|
42 |
+
"threshold": {
|
43 |
+
"high_recall": 0.953,
|
44 |
+
"balanced": 0.953,
|
45 |
+
"high_precision": 0.953
|
46 |
+
},
|
47 |
+
"model_type": "ridge_classifier",
|
48 |
+
"model_fp": "models/lionguard-public_harm.onnx"
|
49 |
+
},
|
50 |
+
"self_harm": {
|
51 |
+
"calibrated": false,
|
52 |
+
"threshold": {
|
53 |
+
"high_recall": 0.915,
|
54 |
+
"balanced": 0.915,
|
55 |
+
"high_precision": 0.915
|
56 |
+
},
|
57 |
+
"model_type": "ridge_classifier",
|
58 |
+
"model_fp": "models/lionguard-self_harm.onnx"
|
59 |
+
},
|
60 |
+
"sexual": {
|
61 |
+
"calibrated": false,
|
62 |
+
"threshold": {
|
63 |
+
"high_recall": 0.388,
|
64 |
+
"balanced": 0.500,
|
65 |
+
"high_precision": 0.702
|
66 |
+
},
|
67 |
+
"model_type": "ridge_classifier",
|
68 |
+
"model_fp": "models/lionguard-sexual.onnx"
|
69 |
+
},
|
70 |
+
"toxic": {
|
71 |
+
"calibrated": false,
|
72 |
+
"threshold": {
|
73 |
+
"high_recall": -0.089,
|
74 |
+
"balanced": 0.136,
|
75 |
+
"high_precision": 0.327
|
76 |
+
},
|
77 |
+
"model_type": "ridge_classifier",
|
78 |
+
"model_fp": "models/lionguard-toxic.onnx"
|
79 |
+
},
|
80 |
+
"violent": {
|
81 |
+
"calibrated": false,
|
82 |
+
"threshold": {
|
83 |
+
"high_recall": 0.317,
|
84 |
+
"balanced": 0.981,
|
85 |
+
"high_precision": 0.981
|
86 |
+
},
|
87 |
+
"model_type": "ridge_classifier",
|
88 |
+
"model_fp": "models/lionguard-violent.onnx"
|
89 |
+
}
|
90 |
+
}
|
91 |
+
}
|
inference.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
import torch
|
4 |
+
from transformers import AutoTokenizer, AutoModel
|
5 |
+
from huggingface_hub import hf_hub_download
|
6 |
+
import sys
|
7 |
+
import json
|
8 |
+
import onnxruntime as rt
|
9 |
+
|
10 |
+
# Download the model configuration from the Hugging Face Hub.
# `repo_path` and `config` are module-level names read by the functions in this file.
repo_path = "govtech/lionguard-v1"
config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
# JSON text is UTF-8; pass the encoding explicitly instead of relying on the
# platform default used by open().
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)
|
15 |
+
|
16 |
+
def get_embeddings(device, data):
    """Embed a sequence of texts with the encoder named in the module config.

    Args:
        device: Torch device (or device string) that the encoder and the
            tokenized inputs are moved to.
        data: Sliceable sequence of input texts; it is consumed in slices of
            ``config['embedding']['batch_size']`` items.

    Returns:
        ``np.array`` built from one L2-normalised embedding vector per text.
    """
    embedding_cfg = config['embedding']

    # Load the tokenizer and the encoder, then put the encoder in eval mode
    tokenizer = AutoTokenizer.from_pretrained(embedding_cfg['tokenizer'])
    model = AutoModel.from_pretrained(embedding_cfg['model'])
    model.eval()
    model.to(device)

    # Walk over the data one batch-sized slice at a time
    step = embedding_cfg['batch_size']
    vectors = []
    for start in range(0, len(data), step):
        chunk = data[start:start + step]
        encoded_input = tokenizer(
            chunk,
            max_length=embedding_cfg['max_length'],
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        encoded_input.to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
            # Keep position 0 of the first model output for every text
            # (presumably the [CLS] token of the last hidden state - TODO confirm)
            pooled = model_output[0][:, 0]
            # Rescale every vector to unit L2 norm
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
        vectors.extend(pooled.cpu().numpy())

    return np.array(vectors)
|
39 |
+
|
40 |
+
def predict(batch_text):
    """Score a batch of texts with every classifier listed in the config.

    The texts are embedded once with ``get_embeddings`` and the same
    embedding matrix is fed to each ONNX classifier under
    ``config['classifier']``.

    Args:
        batch_text: Sequence of input texts.

    Returns:
        Dict keyed by category name. Each value is a dict with:
            ``'scores'``: one score per text (a list for calibrated
                classifiers, a flattened array otherwise).
            ``'predictions'``: dict with keys ``'high_recall'``,
                ``'balanced'`` and ``'high_precision'``, each a list of 0/1
                flags (1 when the score is >= that category's threshold).
    """
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    embeddings = get_embeddings(device, batch_text)

    # Prepare input data: the classifiers are run on a float32 array
    X_input = np.asarray(embeddings, dtype=np.float32)

    threshold_names = ('high_recall', 'balanced', 'high_precision')

    results = {}
    for category, details in config['classifier'].items():

        # Download the classifier from HuggingFace hub.
        # NOTE: read the path from `details`; the previous lookup used the
        # misspelled key 'classifer' and raised KeyError for every category.
        local_model_fp = hf_hub_download(repo_id=repo_path, filename=details['model_fp'])

        # Run the inference
        session = rt.InferenceSession(local_model_fp)
        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: X_input})

        if details['calibrated']:
            # Calibrated: keep only the entry for class 1 (the unsafe class)
            # from each per-text element of the second output
            scores = [output[1] for output in outputs[1]]
        else:
            # Not calibrated: the second output already holds only the
            # unsafe-class scores, so flatten it to 1D
            scores = outputs[1].flatten()

        # Generate the predictions for each recommended threshold
        thresholds = details['threshold']
        results[category] = {
            'scores': scores,
            'predictions': {
                name: [1 if score >= thresholds[name] else 0 for score in scores]
                for name in threshold_names
            },
        }

    return results
|
80 |
+
|
81 |
+
if __name__ == "__main__":

    # The batch arrives as a single JSON-encoded command-line argument
    if len(sys.argv) < 2:
        sys.exit("Usage: python inference.py '[\"first text\", \"second text\"]'")
    batch_text = json.loads(sys.argv[1])

    # A bare JSON string would otherwise be iterated character by character
    if isinstance(batch_text, str):
        batch_text = [batch_text]

    # Generate the scores and predictions, then print them text by text
    results = predict(batch_text)
    for i, text in enumerate(batch_text):
        print(f"Text: '{text}'")
        for category, result in results.items():
            predictions = result['predictions']
            print(f"[Text {i+1}] {category} score: {result['scores'][i]:.3f} | HR: {predictions['high_recall'][i]}, B: {predictions['balanced'][i]}, HP: {predictions['high_precision'][i]}")
        print('---------------------------------------------')
|