SANJAYV10 commited on
Commit
9258786
1 Parent(s): dc4b50d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -19
app.py CHANGED
@@ -1,5 +1,4 @@
1
  from fastapi import FastAPI
2
- from pydantic import BaseModel
3
  import torch
4
  import torch.nn as nn
5
  import torch
@@ -7,25 +6,8 @@ from torchvision import transforms
7
 
8
  from typing import Any, Type
9
 
10
- import pydantic
11
  import torch
12
 
13
- class TensorSchema(pydantic.BaseModel):
14
- shape: list[int]
15
- dtype: str
16
- requires_grad: bool
17
-
18
- def __get_pydantic_core_schema__(cls: Type[Any]) -> pydantic.schema.Schema:
19
- return pydantic.schema.Schema(
20
- type="object",
21
- properties={
22
- "shape": pydantic.schema.Schema(type="array", items=pydantic.schema.Schema(type="integer")),
23
- "dtype": pydantic.schema.Schema(type="string"),
24
- "requires_grad": pydantic.schema.Schema(type="boolean"),
25
- },
26
- required=["shape", "dtype", "requires_grad"],
27
- )
28
-
29
  class TorchTensor(torch.Tensor):
30
  pass
31
 
@@ -87,7 +69,7 @@ async def predict_endpoint(input: fastapi.File):
87
  predicted_class = prediction.argmax(1)
88
 
89
  # Return the prediction.
90
- return Prediction(prediction=predicted_class)
91
 
92
 
93
  if __name__ == "__main__":
 
1
  from fastapi import FastAPI
 
2
  import torch
3
  import torch.nn as nn
4
  import torch
 
6
 
7
  from typing import Any, Type
8
 
 
9
  import torch
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class TorchTensor(torch.Tensor):
12
  pass
13
 
 
69
  predicted_class = prediction.argmax(1)
70
 
71
  # Return the prediction.
72
+ return {"prediction": predicted_class.item()}
73
 
74
 
75
  if __name__ == "__main__":