Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
from fastapi import FastAPI
|
2 |
-
from pydantic import BaseModel
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
import torch
|
@@ -7,25 +6,8 @@ from torchvision import transforms
|
|
7 |
|
8 |
from typing import Any, Type
|
9 |
|
10 |
-
import pydantic
|
11 |
import torch
|
12 |
|
13 |
-
class TensorSchema(pydantic.BaseModel):
|
14 |
-
shape: list[int]
|
15 |
-
dtype: str
|
16 |
-
requires_grad: bool
|
17 |
-
|
18 |
-
def __get_pydantic_core_schema__(cls: Type[Any]) -> pydantic.schema.Schema:
|
19 |
-
return pydantic.schema.Schema(
|
20 |
-
type="object",
|
21 |
-
properties={
|
22 |
-
"shape": pydantic.schema.Schema(type="array", items=pydantic.schema.Schema(type="integer")),
|
23 |
-
"dtype": pydantic.schema.Schema(type="string"),
|
24 |
-
"requires_grad": pydantic.schema.Schema(type="boolean"),
|
25 |
-
},
|
26 |
-
required=["shape", "dtype", "requires_grad"],
|
27 |
-
)
|
28 |
-
|
29 |
class TorchTensor(torch.Tensor):
|
30 |
pass
|
31 |
|
@@ -87,7 +69,7 @@ async def predict_endpoint(input: fastapi.File):
|
|
87 |
predicted_class = prediction.argmax(1)
|
88 |
|
89 |
# Return the prediction.
|
90 |
-
return
|
91 |
|
92 |
|
93 |
if __name__ == "__main__":
|
|
|
1 |
from fastapi import FastAPI
|
|
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
import torch
|
|
|
6 |
|
7 |
from typing import Any, Type
|
8 |
|
|
|
9 |
import torch
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
class TorchTensor(torch.Tensor):
|
12 |
pass
|
13 |
|
|
|
69 |
predicted_class = prediction.argmax(1)
|
70 |
|
71 |
# Return the prediction.
|
72 |
+
return {"prediction": predicted_class.item()}
|
73 |
|
74 |
|
75 |
if __name__ == "__main__":
|