|
from contextlib import asynccontextmanager |
|
from fastapi import FastAPI, HTTPException |
|
from pydantic import BaseModel, ValidationError |
|
from fastapi.encoders import jsonable_encoder |
|
|
|
|
|
|
|
import re

import string

import nltk

# Fetch the NLTK data needed by the preprocessing helpers below:
# 'punkt' backs nltk.word_tokenize, 'wordnet' + 'omw-1.4' back
# WordNetLemmatizer.  These run as import-time side effects on every
# startup; nltk.download is a no-op when the data is already cached.
# NOTE(review): consider nltk.download(..., quiet=True) to silence the
# console output, and moving these into the lifespan hook.
nltk.download('punkt')

nltk.download('wordnet')

nltk.download('omw-1.4')

from nltk.stem import WordNetLemmatizer
|
|
|
|
|
# Matches one http:// or https:// URL up to the next whitespace.
_URL_RE = re.compile(r'http[s]?://\S+')


def remove_urls(text):
    """Return *text* with every http(s) URL deleted.

    Surrounding whitespace is left untouched, so removing a URL that sat
    between two spaces leaves a double space behind.
    """
    return _URL_RE.sub('', text)
|
|
|
|
|
# Translation table deleting every ASCII punctuation character.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)


def remove_punctuation(text):
    """Return *text* (coerced to str) with all ASCII punctuation removed.

    The original built a regex character class by concatenating
    string.punctuation between '[' and ']'; that only worked by
    coincidence (the embedded backslash happened to escape ']', and
    ',-.' happened to form a valid range).  str.translate deletes the
    same characters, is correct by construction, and runs in one C-level
    pass.
    """
    return str(text).translate(_PUNCT_TABLE)
|
|
|
|
|
def lower_case(text):
    """Return a copy of *text* with all cased characters lower-cased."""
    lowered = str.lower(text)
    return lowered
|
|
|
|
|
def lemmatize(text):
    """Tokenize *text* and return the space-joined WordNet lemma of each token.

    Fixes over the original: the result is built with str.join instead
    of quadratic += concatenation, and the stray trailing space the old
    loop appended after the last token is no longer emitted.

    Note: WordNetLemmatizer is still constructed per call; it loads its
    corpus lazily on first use, so construction itself is cheap.
    """
    lemmatizer = WordNetLemmatizer()

    tokens = nltk.word_tokenize(text)
    return ' '.join(lemmatizer.lemmatize(w) for w in tokens)
|
|
|
def preprocess_text(text):
    """Run the full cleaning pipeline on *text*.

    Steps, in order: strip URLs, strip ASCII punctuation, lower-case,
    then lemmatize each token.  Order matters: punctuation must go after
    URLs (or URL delimiters would be mangled first).
    """
    pipeline = (remove_urls, remove_punctuation, lower_case, lemmatize)
    for step in pipeline:
        text = step(text)
    return text
|
|
|
|
|
@asynccontextmanager

async def lifespan(app: FastAPI):
    # Startup: load the multilingual sentiment model once, before the app
    # starts serving requests.  transformers is imported lazily here so
    # importing this module does not pull in the (heavy) library.

    from transformers import pipeline

    # NOTE(review): the pipeline is published to module scope via a
    # global so classify_text can reach it; app.state would be the more
    # idiomatic FastAPI home for shared resources — confirm before changing.
    global sentiment_task

    sentiment_task = pipeline("text-classification", model="lxyuan/distilbert-base-multilingual-cased-sentiments-student", tokenizer= "lxyuan/distilbert-base-multilingual-cased-sentiments-student")



    # App serves requests while suspended here.
    yield

    # Shutdown: drop the pipeline reference so the model can be freed.
    del sentiment_task
|
|
|
# User-facing markdown shown on the auto-generated docs page.
description = """

## Text Classification API

Upon input to this app, It will show the sentiment of the text (positive, negative, or neutral).

Check out the docs for the `/analyze/{text}` endpoint below to try it out!

"""


# NOTE(review): docs_url="/" mounts the Swagger UI at the root path, but a
# GET "/" route (welcome) is also registered below — the two claim the same
# path, so one of them is unreachable; confirm which is intended.
app = FastAPI(lifespan=lifespan, docs_url="/", description=description)
|
|
|
|
|
class TextInput(BaseModel):
    """Request body for the /analyze endpoint."""

    # Raw text to classify; length limits are enforced in classify_text,
    # not by the model itself.
    text: str
|
|
|
|
|
@app.get('/')

async def welcome():
    # NOTE(review): the app is created with docs_url="/", which also claims
    # the root path — this greeting route may never be reachable; confirm
    # which of the two should own "/".
    """Plain greeting for the root path."""
    return "Welcome to our First Emotion Classification API"
|
|
|
|
|
# Upper bound (in characters) on the text accepted by /analyze.
MAX_TEXT_LENGTH = 1000
|
|
|
|
|
@app.post('/analyze/{text}')

async def classify_text(text_input: TextInput):
    """Classify the sentiment of the posted text.

    The text analyzed comes from the JSON body (``TextInput``); the
    ``{text}`` path parameter is declared in the route but never read.
    It is kept so the URL shape stays backward-compatible — NOTE(review):
    confirm whether the path segment was ever meant to carry the input.

    Raises:
        HTTPException 400: text is empty, exceeds MAX_TEXT_LENGTH, or is
            rejected by the model (ValueError).
        HTTPException 500: any other model failure.
    """
    # FastAPI has already parsed and validated the body against TextInput
    # before this function runs, so the original jsonable_encoder
    # round-trip + TextInput(**...) re-parse (and its unreachable
    # ValidationError handler) were dead code and have been removed.
    if len(text_input.text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail="Text length exceeds maximum allowed length")
    if len(text_input.text) == 0:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    try:
        # sentiment_task is the transformers pipeline installed by lifespan().
        return sentiment_task(preprocess_text(text_input.text))
    except ValueError as ve:
        # Model rejected the (preprocessed) input — treat as a client error.
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        # Boundary handler: surface any unexpected failure as a 500.
        raise HTTPException(status_code=500, detail=str(e)) from e