File size: 991 Bytes
c32023c
 
 
1ab4c61
 
 
c32023c
 
338bbe8
bfc39e7
 
 
1ab4c61
 
 
 
 
c32023c
 
 
 
 
 
 
 
 
 
 
fccfd79
c32023c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import io
from pathlib import Path

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image, UnidentifiedImageError

import mnist_classifier

# The ASGI application object for this service.
app = FastAPI()

# Serve the contents of the "static" directory (resolved relative to the
# process's working directory) under the /static URL prefix.
app.mount(
    path="/static",
    app=StaticFiles(directory=Path("static")),
    name="static",
)
@app.get("/")
async def root():
    return FileResponse("static/index.html")

# Preprocessing pipeline shared by every request. The transforms carry no
# per-image state, so building the pipeline once at import time replaces the
# per-call reconstruction without changing the output.
_MNIST_TRANSFORM = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])


def process_image(file: UploadFile):
    """Decode an uploaded image into a 28x28 single-channel tensor.

    The upload is read into memory and decoded with Pillow. It is then
    resized to exactly 28x28 pixels (the aspect ratio is not preserved),
    converted to one grayscale channel, and turned into a tensor by
    ``ToTensor``. The result is a float tensor of shape ``(1, 28, 28)``
    with values scaled to ``[0, 1]``.

    Args:
        file: The uploaded image. Its underlying file object is read from
            its current position to the end.

    Returns:
        The preprocessed image tensor.

    Raises:
        PIL.UnidentifiedImageError: If the bytes are not an image format
            Pillow can decode.
    """
    image_bytes = file.file.read()
    # The context manager closes the decoded image as soon as the tensor has
    # been produced (ToTensor copies the pixel data), instead of leaving it
    # for the garbage collector.
    with Image.open(io.BytesIO(image_bytes)) as pil_image:
        # NOTE(review): no mean/std normalization or color inversion is
        # applied. This assumes the model expects raw [0, 1] pixels with the
        # same digit/background polarity as the uploads; confirm against
        # mnist_classifier. The grayscale conversion also discards any alpha
        # channel, so a transparent background becomes whatever RGB value it
        # stores.
        return _MNIST_TRANSFORM(pil_image)

@app.post("/predict")
async def predict(image: UploadFile):
    """Classify an uploaded image and return the model's prediction.

    Decoding the upload and running the classifier are ordinary blocking
    calls. Awaiting them through ``run_in_threadpool`` runs them in a worker
    thread, so this ``async`` endpoint no longer stalls the event loop (and
    every other in-flight request) while they execute.

    Args:
        image: The uploaded image file.

    Returns:
        ``{"prediction": <value returned by mnist_classifier.predict>}``.

    Raises:
        HTTPException: With status 400 if the upload cannot be decoded as an
            image.
    """
    try:
        tensor_image = await run_in_threadpool(process_image, image)
    except UnidentifiedImageError as err:
        # An undecodable upload is a client error. Left unhandled, it would
        # surface as a 500 Internal Server Error.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err
    # NOTE(review): concurrent requests can now run predictions in parallel
    # worker threads; confirm mnist_classifier.predict is safe for that.
    # It is also presumed to return a JSON-serializable value (e.g. an int);
    # verify.
    prediction = await run_in_threadpool(mnist_classifier.predict, tensor_image)
    return {"prediction": prediction}