import logging
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import pipeline


# Build the zero-shot pipeline once, at import time; every request handled by
# this module reuses the same instance. No ``model=`` argument is passed, so
# transformers falls back to its default checkpoint for this task.
# NOTE(review): consider pinning ``model=`` explicitly for reproducible results.
classifier = pipeline(task="zero-shot-classification")

# The FastAPI application that the route decorators below register against.
app = FastAPI()


class ClassificationRequest(BaseModel): |
|
text: str = Field(..., example="This is a course about the Transformers library") |
|
labels: List[str] = Field(..., example=["education", "politics", "technology"]) |
|
|
|
@app.get("/")
def greet_json():
    """Answer ``GET /`` with a constant ``{"Hello": "World!"}`` JSON object."""
    greeting = {"Hello": "World!"}
    return greeting


@app.post("/classify")
def zero_shot_classification(request: ClassificationRequest):
    """
    Classify ``request.text`` against ``request.labels`` with zero-shot inference.

    The pipeline's output is returned unchanged. For a single input string, the
    transformers docs describe it as a mapping with ``sequence``, ``labels`` and
    ``scores`` keys.

    Raises:
        HTTPException: 422 if ``labels`` is empty (a client error).
        HTTPException: 500 if the classifier itself fails; the detail carries
            the underlying error message, and the traceback is logged.
    """
    if not request.labels:
        # The zero-shot pipeline rejects an empty candidate list with a
        # ValueError. Without this check, the generic handler below would report
        # that as a 500 server error. It is the caller's mistake, so answer with
        # 422, the same status FastAPI uses for its own request-validation
        # failures.
        raise HTTPException(
            status_code=422,
            detail="'labels' must contain at least one candidate label.",
        )

    try:
        result = classifier(request.text, candidate_labels=request.labels)
    except Exception as e:
        # Top-level boundary around model inference. Log the full traceback
        # server-side, because it is otherwise lost once converted to an HTTP
        # error, then report it to the client as before.
        logging.getLogger(__name__).exception("Zero-shot classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return result