fix training loop for push to hub
Browse files- training_loop.py +7 -5
training_loop.py
CHANGED
@@ -74,12 +74,14 @@ def main(
|
|
74 |
|
75 |
if push_to_hub:
|
76 |
config = AutoConfig.from_pretrained(f"nvidia/mit-b{model_flavor}")
|
77 |
-
config.
|
|
|
|
|
|
|
|
|
78 |
checkpoint_path = checkpoint_callback.best_model_filepath
|
79 |
-
model = SegformerForSemanticSegmentation.from_pretrained(
|
80 |
-
|
81 |
-
)
|
82 |
-
model.push_to_hub("segformer-sidewalk", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")
|
83 |
|
84 |
|
85 |
if __name__ == "__main__":
|
|
|
74 |
|
75 |
if push_to_hub:
|
76 |
config = AutoConfig.from_pretrained(f"nvidia/mit-b{model_flavor}")
|
77 |
+
config.num_labels = num_labels
|
78 |
+
config.id2label = id2label
|
79 |
+
config.label2id = {v: k for k, v in id2label_file.items()}
|
80 |
+
config.push_to_hub(".", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")
|
81 |
+
|
82 |
checkpoint_path = checkpoint_callback.best_model_filepath
|
83 |
+
model = SegformerForSemanticSegmentation.from_pretrained(checkpoint_path, config=config,)
|
84 |
+
model.push_to_hub(".", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")
|
|
|
|
|
85 |
|
86 |
|
87 |
if __name__ == "__main__":
|