Spaces:
Runtime error
Runtime error
AnishKumbhar
commited on
Commit
•
8da9e7a
1
Parent(s):
ddb1cdd
Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,17 @@ class Prediction:
|
|
10 |
prediction: TorchTensor
|
11 |
|
12 |
app = fastapi.FastAPI(docs_url="/")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
model = Model()
|
14 |
model.load_state_dict(torch.load('model_state.pth', map_location='cpu'))
|
15 |
# Define a function to preprocess the input image
|
|
|
10 |
prediction: TorchTensor
|
11 |
|
12 |
app = fastapi.FastAPI(docs_url="/")
|
13 |
+
from transformers import ViTForImageClassification
|
14 |
+
|
15 |
+
# Define the number of classes in your custom dataset
|
16 |
+
num_classes = 20
|
17 |
+
|
18 |
+
# Initialize the ViTForImageClassification model
|
19 |
+
model = ViTForImageClassification.from_pretrained(
|
20 |
+
'google/vit-base-patch16-224-in21k',
|
21 |
+
num_labels=num_classes # Specify the number of classes
|
22 |
+
)
|
23 |
+
|
24 |
model = Model()
|
25 |
model.load_state_dict(torch.load('model_state.pth', map_location='cpu'))
|
26 |
# Define a function to preprocess the input image
|