Fasika commited on
Commit
ede4edf
·
1 Parent(s): eebd998
Files changed (2) hide show
  1. app.py +32 -18
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,24 +1,38 @@
1
- from fastapi import FastAPI
2
- import torch
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
4
 
5
- checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
6
- tokenizer = AutoTokenizer.from_pretrained(checkpoint)
7
- model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
8
-
9
- sequences = ["I've been waiting for a HuggingFace course my whole life.", "So have I!"]
10
-
11
- tokens = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")
12
-
13
- # Perform inference without gradient tracking
14
- with torch.no_grad():
15
- output = model(**tokens)
16
-
17
- # Convert logits to a list for JSON serialization
18
- logits = output.logits.tolist()
19
 
 
20
  app = FastAPI()
21
 
 
 
 
 
 
22
  @app.get("/")
23
  def greet_json():
24
- return {"Hello": logits}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel, Field
3
+ from typing import List
4
+ from transformers import pipeline
5
 
6
+ # Initialize the zero-shot classification pipeline
7
+ classifier = pipeline("zero-shot-classification")
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Define the FastAPI application
10
  app = FastAPI()
11
 
12
+ # Pydantic model for input validation
13
+ class ClassificationRequest(BaseModel):
14
+ text: str = Field(..., example="This is a course about the Transformers library")
15
+ labels: List[str] = Field(..., example=["education", "politics", "technology"])
16
+
17
  @app.get("/")
18
  def greet_json():
19
+ """
20
+ A simple GET endpoint that returns a greeting message.
21
+ """
22
+ return {"Hello": "World!"}
23
+
24
+ @app.post("/classify")
25
+ def zero_shot_classification(request: ClassificationRequest):
26
+ """
27
+ A POST endpoint that performs zero-shot classification on the input text
28
+ using the provided candidate labels.
29
+ """
30
+ try:
31
+ # Perform zero-shot classification
32
+ result = classifier(
33
+ request.text,
34
+ candidate_labels=request.labels
35
+ )
36
+ return result
37
+ except Exception as e:
38
+ raise HTTPException(status_code=500, detail=str(e))
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  fastapi
2
  uvicorn[standard]
3
  torch
4
- transformers
 
 
1
  fastapi
2
  uvicorn[standard]
3
  torch
4
+ transformers
5
+ pydantic