AnishKumbhar commited on
Commit
8da9e7a
1 Parent(s): ddb1cdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -0
app.py CHANGED
@@ -10,6 +10,17 @@ class Prediction:
10
  prediction: TorchTensor
11
 
12
  app = fastapi.FastAPI(docs_url="/")
 
 
 
 
 
 
 
 
 
 
 
13
  model = Model()
14
  model.load_state_dict(torch.load('model_state.pth', map_location='cpu'))
15
  # Define a function to preprocess the input image
 
10
  prediction: TorchTensor
11
 
12
  app = fastapi.FastAPI(docs_url="/")
13
+ from transformers import ViTForImageClassification
14
+
15
+ # Define the number of classes in your custom dataset
16
+ num_classes = 20
17
+
18
+ # Initialize the ViTForImageClassification model
19
+ model = ViTForImageClassification.from_pretrained(
20
+ 'google/vit-base-patch16-224-in21k',
21
+ num_labels=num_classes # Specify the number of classes
22
+ )
23
+
24
  model = Model()
25
  model.load_state_dict(torch.load('model_state.pth', map_location='cpu'))
26
  # Define a function to preprocess the input image