|
import io
from typing import Any, Dict, List, Optional

import requests
import torch
import torchvision.transforms.functional as transform
from PIL import Image
from torch2trt import torch2trt
from torchvision.io import ImageReadMode, read_image
from torchvision.models import ResNet18_Weights, resnet18
from torchvision.models.alexnet import AlexNet_Weights, alexnet
|
|
|
|
|
# NOTE(review): this AlexNet -> TensorRT conversion runs at import time and
# ``model_trt`` is not referenced anywhere else in this file; it looks like a
# leftover smoke test. The names are kept so nothing importing them breaks, but
# the weight load and engine build add to cold-start time -- consider removing
# it or moving it under ``if __name__ == "__main__":``.

# Use the multi-weight API (as the handler below does for ResNet-18) instead of
# the deprecated ``pretrained=True`` flag; DEFAULT selects the same weights.
model = alexnet(weights=AlexNet_Weights.DEFAULT).eval().cuda()

# All-ones example input (batch of one 3x224x224 image on the GPU) that
# torch2trt runs through the model to build the engine.
x = torch.ones((1, 3, 224, 224)).cuda()

model_trt = torch2trt(model, [x])
|
|
|
class EndpointHandler:
    """Image-classification endpoint backed by a TensorRT-compiled ResNet-18.

    ``__init__`` builds the TensorRT module once; each ``__call__`` classifies a
    single image given as a URL or local file path and returns the softmax
    probabilities as a list of floats.
    """

    # Upper bound (seconds) for downloading a remote image, so a slow or
    # unresponsive host cannot hang the request indefinitely.
    _REQUEST_TIMEOUT_S = 10.0

    def __init__(self, path=""):
        """Load ResNet-18 with its default weights and compile it with torch2trt.

        Args:
            path: Accepted for the handler-construction interface; unused here.
        """
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights).eval().cuda()

        # torch2trt runs the model on this example input (batch of one
        # 3x224x224 image on the GPU) to build the engine.
        x = torch.ones((1, 3, 224, 224)).cuda()
        self.pipeline = torch2trt(model, [x])

        # Preprocessing transforms shipped with the chosen weights.
        self.preprocess = weights.transforms()

    def __call__(self, data: Dict[str, Any]) -> List[float]:
        """Classify one image.

        Args:
            data: Request payload. ``data["inputs"]`` is popped from the dict and
                must be a string: an ``http://``/``https://`` URL, an address
                starting with ``www.`` (fetched over HTTPS), or a local image
                path. If the key is absent, ``data`` itself is used as the input
                (and rejected, since it is not a string).

        Returns:
            Softmax probabilities over the model outputs, one float per class.

        Raises:
            TypeError: If the input is not a string.
            requests.RequestException: If downloading a remote image fails
                (including HTTP error statuses and timeouts).
        """
        inputs = data.pop("inputs", data)
        if not isinstance(inputs, str):
            raise TypeError(
                "'inputs' must be an image URL or file path string, "
                f"got {type(inputs).__name__}"
            )

        url = self._resolve_url(inputs)
        if url is not None:
            img = self._fetch_image(url)
        else:
            # Force 3 channels so grayscale / RGBA files match the 3-channel
            # input the engine was built for.
            img = read_image(inputs, mode=ImageReadMode.RGB)

        # The engine was built from a CUDA example input, so the batch must be
        # on the GPU too; passing a CPU tensor hands TensorRT a host pointer.
        batch = self.preprocess(img).unsqueeze(0).cuda()
        prediction = self.pipeline(batch).squeeze(0).softmax(0)
        return prediction.tolist()

    @staticmethod
    def _resolve_url(inputs: str) -> Optional[str]:
        """Return the URL to download for *inputs*, or ``None`` for a local path.

        Scheme-less ``www.`` addresses get ``https://`` prepended because
        ``requests`` refuses URLs without a scheme.
        """
        lowered = inputs.lower()
        if lowered.startswith(("http://", "https://")):
            return inputs
        if lowered.startswith("www."):
            return "https://" + inputs
        return None

    @classmethod
    def _fetch_image(cls, url: str):
        """Download *url* and decode it into a 3-channel RGB image tensor.

        Produces the same kind of tensor as the local-file branch
        (``read_image`` in RGB mode), so both go through identical preprocessing.
        """
        response = requests.get(url, timeout=cls._REQUEST_TIMEOUT_S)
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        with Image.open(io.BytesIO(response.content)) as pil_img:
            return transform.pil_to_tensor(pil_img.convert("RGB"))