Update script.py
Browse files
script.py
CHANGED
@@ -28,24 +28,29 @@ class PytorchWorker:
|
|
28 |
"cuda:0" if torch.cuda.is_available() else "cpu")
|
29 |
print(f"Using devide: {self.device}")
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
# if not torch.cuda.is_available():
|
36 |
# model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
|
37 |
# else:
|
38 |
# model_ckpt = torch.load(model_path)
|
39 |
|
40 |
-
model_ckpt = torch.load(model_path, map_location=self.device)
|
41 |
-
model.load_state_dict(model_ckpt)
|
42 |
|
43 |
return model.to(self.device).eval()
|
44 |
|
45 |
self.model = _load_model(model_name, model_path)
|
46 |
|
47 |
self.transforms = T.Compose([
|
48 |
-
T.Resize((
|
49 |
T.ToTensor(),
|
50 |
T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
51 |
])
|
@@ -79,7 +84,6 @@ def make_submission(test_metadata, model_path, model_name, output_csv_path="./su
|
|
79 |
|
80 |
predictions.append(np.argmax(logits))
|
81 |
|
82 |
-
|
83 |
test_metadata["class_id"] = predictions
|
84 |
|
85 |
user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
|
@@ -92,7 +96,7 @@ if __name__ == "__main__":
|
|
92 |
with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
|
93 |
zip_ref.extractall("/tmp/data")
|
94 |
|
95 |
-
MODEL_PATH = "
|
96 |
MODEL_NAME = "hf-hub:timm/eva02_large_patch14_clip_336.merged2b_ft_inat21"
|
97 |
|
98 |
metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
|
|
|
28 |
"cuda:0" if torch.cuda.is_available() else "cpu")
|
29 |
print(f"Using devide: {self.device}")
|
30 |
|
31 |
+
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
32 |
+
|
33 |
+
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)
|
34 |
+
|
35 |
+
model.heads.head = torch.nn.Linear(model.heads.head.in_features,
|
36 |
+
number_of_categories)
|
37 |
+
|
38 |
+
model.load_state_dict(torch.load(model_path))
|
39 |
|
40 |
# if not torch.cuda.is_available():
|
41 |
# model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
|
42 |
# else:
|
43 |
# model_ckpt = torch.load(model_path)
|
44 |
|
45 |
+
# model_ckpt = torch.load(model_path, map_location=self.device)
|
46 |
+
# model.load_state_dict(model_ckpt)
|
47 |
|
48 |
return model.to(self.device).eval()
|
49 |
|
50 |
self.model = _load_model(model_name, model_path)
|
51 |
|
52 |
self.transforms = T.Compose([
|
53 |
+
T.Resize((384, 384)),
|
54 |
T.ToTensor(),
|
55 |
T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
56 |
])
|
|
|
84 |
|
85 |
predictions.append(np.argmax(logits))
|
86 |
|
|
|
87 |
test_metadata["class_id"] = predictions
|
88 |
|
89 |
user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
|
|
|
96 |
with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
|
97 |
zip_ref.extractall("/tmp/data")
|
98 |
|
99 |
+
MODEL_PATH = "model.pth"
|
100 |
MODEL_NAME = "hf-hub:timm/eva02_large_patch14_clip_336.merged2b_ft_inat21"
|
101 |
|
102 |
metadata_file_path = "./SnakeCLEF2024_TestMetadata.csv"
|