Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
import os | |
from detoxify import Detoxify | |
import pandas as pd | |
import json | |
import spaces | |
import logging | |
import datetime | |
# Detoxify models keyed by device string; filled lazily by _get_model().
_DETOXIFY_MODELS = {}


def _get_model(device="cuda"):
    """Return the Detoxify ``unbiased-small`` model for *device*, building it once.

    Constructing ``Detoxify`` loads the model, so one instance is reused across
    requests instead of being rebuilt on every call.
    """
    # NOTE(review): `spaces` is imported at the top of the file but not used.
    # If this Space runs on ZeroGPU, classify presumably needs @spaces.GPU, and
    # HF recommends loading models at startup -- confirm that caching a model
    # created inside a request is safe in that environment.
    model = _DETOXIFY_MODELS.get(device)
    if model is None:
        model = Detoxify("unbiased-small", device=device)
        _DETOXIFY_MODELS[device] = model
    return model


def _parse_query(query):
    """Split *query* into the items to score; return ``(items, is_batch)``.

    A query that decodes as a JSON array is a batch: each element is an item.
    Anything else (plain text, a JSON scalar/object, or input that
    ``json.loads`` rejects) is scored as a single item: the raw query itself.
    """
    try:
        parsed = json.loads(query)
    except (TypeError, ValueError, RecursionError) as exc:
        # Plain text is the normal case, so this is expected, not an error.
        # JSONDecodeError is a ValueError; TypeError covers non-str/bytes input;
        # RecursionError covers pathologically deep JSON nesting.
        logging.debug("query is not JSON, scoring it as plain text: %s", exc)
        # A Python list handed in directly (non-Gradio caller) cannot be
        # JSON-decoded; it is scored as one stringified item but keeps the
        # JSON-string return format.
        return [query], type(query) is list
    if isinstance(parsed, list):
        return parsed, True
    return [query], False


def _score(model, text):
    """Return Detoxify scores for ``str(text)`` plus the elapsed prediction time."""
    start_time = datetime.datetime.now()
    scores = model.predict(str(text))
    # Convert to a Python float *before* rounding: if the scores are numpy
    # float32 (assumed; depends on the Detoxify version), rounding in float32
    # and widening afterwards yields values like 0.12300000339746475.
    result = {label: round(float(score), 3) for label, score in scores.items()}
    elapsed_time = datetime.datetime.now() - start_time
    result["time"] = str(elapsed_time)
    logging.debug("elapsed predict time: %s", str(elapsed_time))
    print("elapsed predict time:", str(elapsed_time))
    return result


def classify(query):
    """Score *query* for toxicity with Detoxify's ``unbiased-small`` model.

    *query* is normally the text typed into the Gradio textbox. If it decodes
    as a JSON array, every element is scored separately (elements are passed
    through ``str`` first) and the list of results is returned as a JSON
    string. Otherwise the raw query is scored as one text and its result dict
    is returned directly.

    Each result maps every label returned by ``Detoxify.predict`` to its score
    rounded to 3 decimals, plus a ``"time"`` entry holding the prediction wall
    time formatted as ``str(datetime.timedelta)``.
    """
    model = _get_model()
    items, is_batch = _parse_query(query)
    all_result = [_score(model, item) for item in items]
    # A non-batch query always yields exactly one item, so [0] is safe; an
    # empty JSON array is a batch and serializes to "[]".
    return json.dumps(all_result) if is_batch else all_result[0]
# Gradio UI: a single text input box wired to classify, with a text output
# component. launch() starts the app as soon as this module runs.
demo = gr.Interface(
    fn=classify,
    inputs=["text"],
    outputs="text",
)
demo.launch()