farrosalferro24 commited on
Commit
79ff326
·
verified ·
1 Parent(s): 0b62128

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +12 -8
script.py CHANGED
@@ -28,24 +28,29 @@ class PytorchWorker:
28
  "cuda:0" if torch.cuda.is_available() else "cpu")
29
  print(f"Using devide: {self.device}")
30
 
31
- model = timm.create_model(model_name,
32
- num_classes=number_of_categories,
33
- pretrained=False)
 
 
 
 
 
34
 
35
  # if not torch.cuda.is_available():
36
  # model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
37
  # else:
38
  # model_ckpt = torch.load(model_path)
39
 
40
- model_ckpt = torch.load(model_path, map_location=self.device)
41
- model.load_state_dict(model_ckpt)
42
 
43
  return model.to(self.device).eval()
44
 
45
  self.model = _load_model(model_name, model_path)
46
 
47
  self.transforms = T.Compose([
48
- T.Resize((336, 336)),
49
  T.ToTensor(),
50
  T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
51
  ])
@@ -79,7 +84,6 @@ def make_submission(test_metadata, model_path, model_name, output_csv_path="./su
79
 
80
  predictions.append(np.argmax(logits))
81
 
82
-
83
  test_metadata["class_id"] = predictions
84
 
85
  user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
@@ -92,7 +96,7 @@ if __name__ == "__main__":
92
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
93
  zip_ref.extractall("/tmp/data")
94
 
95
- MODEL_PATH = "pytorch_model.bin"
96
  MODEL_NAME = "hf-hub:timm/eva02_large_patch14_clip_336.merged2b_ft_inat21"
97
 
98
  metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
 
28
  "cuda:0" if torch.cuda.is_available() else "cpu")
29
  print(f"Using devide: {self.device}")
30
 
31
+ from torchvision.models import vit_b_16, ViT_B_16_Weights
32
+
33
+ model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)
34
+
35
+ model.heads.head = torch.nn.Linear(model.heads.head.in_features,
36
+ number_of_categories)
37
+
38
+ model.load_state_dict(torch.load(model_path))
39
 
40
  # if not torch.cuda.is_available():
41
  # model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
42
  # else:
43
  # model_ckpt = torch.load(model_path)
44
 
45
+ # model_ckpt = torch.load(model_path, map_location=self.device)
46
+ # model.load_state_dict(model_ckpt)
47
 
48
  return model.to(self.device).eval()
49
 
50
  self.model = _load_model(model_name, model_path)
51
 
52
  self.transforms = T.Compose([
53
+ T.Resize((384, 384)),
54
  T.ToTensor(),
55
  T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
56
  ])
 
84
 
85
  predictions.append(np.argmax(logits))
86
 
 
87
  test_metadata["class_id"] = predictions
88
 
89
  user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
 
96
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
97
  zip_ref.extractall("/tmp/data")
98
 
99
+ MODEL_PATH = "model.pth"
100
  MODEL_NAME = "hf-hub:timm/eva02_large_patch14_clip_336.merged2b_ft_inat21"
101
 
102
  metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"