# Author: rifatramadhani
# Commit 4edc781: feat: hate speech detection
# (1.3 kB; header text captured from the Hugging Face file viewer)
import datetime
import json
import logging
import os
import time

import gradio as gr
import pandas as pd
import spaces
import torch
from detoxify import Detoxify
def _parse_query(query):
    """Split a raw request into the texts to score and a batch flag.

    Returns ``(texts, is_batch)``. A query whose JSON decodes to a list is a
    batch: every element is scored and the caller receives a JSON array.
    Any other query (plain text, or JSON that is not a list) is scored as a
    single text -- the raw ``query`` itself, not its decoded value.
    """
    # A caller handing over a real list (rather than JSON text) is also
    # answered as a batch, with the whole list scored as one text.
    is_batch = isinstance(query, list)
    try:
        decoded = json.loads(query)
    except (ValueError, TypeError, RecursionError) as err:
        # Plain text is the normal case here, not an error. JSONDecodeError
        # is a ValueError, non-str/bytes input raises TypeError, and very
        # deeply nested JSON raises RecursionError.
        logging.debug("query is not JSON, scoring it as one text: %s", err)
        return [query], is_batch
    if isinstance(decoded, list):
        return decoded, True
    return [query], is_batch


@spaces.GPU
def classify(query):
    """Score one text, or a JSON array of texts, with Detoxify.

    Parameters
    ----------
    query : str
        Text from the Gradio textbox. If it parses as a JSON array, each
        element is converted with ``str()`` and scored separately.

    Returns
    -------
    str or dict
        For a batch (JSON array) request, a JSON string holding one result
        per element (``"[]"`` for an empty array). Otherwise the single
        result dict. Each result maps the model's labels to scores rounded
        to 3 decimals, plus a ``"time"`` entry holding the elapsed
        prediction time formatted as ``str(datetime.timedelta)``.
    """
    # NOTE(review): the model is rebuilt on every request. Hoisting it to
    # module level would avoid that cost, but how ZeroGPU treats a CUDA
    # model created outside @spaces.GPU should be confirmed first.
    model = Detoxify("unbiased-small", device="cuda")
    texts, is_batch = _parse_query(query)

    all_result = []
    for text in texts:
        # perf_counter is monotonic, so clock adjustments cannot skew timing.
        start = time.perf_counter()
        scores = model.predict(str(text))
        # Convert to a Python float *before* rounding. Rounding a numpy
        # float32 score and then widening it reintroduces digits such as
        # 0.0010000000474974513. Numpy scalars also render as
        # "np.float64(...)" (NumPy >= 2) in the str() Gradio shows for
        # single results.
        result = {
            label: round(float(score), 3) for label, score in scores.items()
        }
        elapsed_time = datetime.timedelta(seconds=time.perf_counter() - start)
        result["time"] = str(elapsed_time)
        logging.debug("elapsed predict time: %s", elapsed_time)
        print("elapsed predict time:", str(elapsed_time))
        all_result.append(result)

    return json.dumps(all_result) if is_batch else all_result[0]
# Gradio UI: a single textbox in, the classifier's result rendered as text out.
demo = gr.Interface(fn=classify, inputs=["text"], outputs="text")

# Start the server only when the file is executed as a script, so merely
# importing this module (e.g. from tests) does not block on launch().
# NOTE(review): assumes the Space runs this file as a script
# (`python <file>`); confirm for this Space's SDK configuration.
if __name__ == "__main__":
    demo.launch()