Spaces:
Sleeping
Sleeping
AISimplyExplained
commited on
Commit
•
f43f094
1
Parent(s):
c311b0d
Update main.py
Browse files
main.py
CHANGED
@@ -53,6 +53,10 @@ class ToxicityResult(BaseModel):
|
|
53 |
insult: float
|
54 |
identity_attack: float
|
55 |
|
|
|
|
|
|
|
|
|
56 |
class TopicBannerClassifier:
|
57 |
def __init__(self):
|
58 |
self.classifier = pipeline(
|
@@ -99,14 +103,7 @@ topic_banner_classifier = TopicBannerClassifier()
|
|
99 |
async def classify_toxicity(text_prompt: TextPrompt):
|
100 |
try:
|
101 |
result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt)
|
102 |
-
return
|
103 |
-
"toxicity": result['toxicity'],
|
104 |
-
"severe_toxicity": result['severe_toxicity'],
|
105 |
-
"obscene": result['obscene'],
|
106 |
-
"threat": result['threat'],
|
107 |
-
"insult": result['insult'],
|
108 |
-
"identity_attack": result['identity_attack']
|
109 |
-
}
|
110 |
except Exception as e:
|
111 |
raise HTTPException(status_code=500, detail=str(e))
|
112 |
|
|
|
53 |
insult: float
|
54 |
identity_attack: float
|
55 |
|
56 |
+
@classmethod
|
57 |
+
def from_dict(cls, data: dict):
|
58 |
+
return cls(**{k: float(v) for k, v in data.items()})
|
59 |
+
|
60 |
class TopicBannerClassifier:
|
61 |
def __init__(self):
|
62 |
self.classifier = pipeline(
|
|
|
103 |
async def classify_toxicity(text_prompt: TextPrompt):
|
104 |
try:
|
105 |
result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt)
|
106 |
+
return ToxicityResult.from_dict(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
except Exception as e:
|
108 |
raise HTTPException(status_code=500, detail=str(e))
|
109 |
|