Image Classification
Transformers
Safetensors
cetaceanet
biology
biodiversity
custom_code
MalloryWittwerEPFL commited on
Commit
b3201aa
1 Parent(s): ad1895a

Fix broken model

Browse files
Files changed (5) hide show
  1. .gitignore +135 -0
  2. requirements.txt +8 -0
  3. test_model.py +12 -0
  4. test_model_before_push.py +30 -0
  5. train.py +3 -3
.gitignore ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ .pdm.toml
86
+
87
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
88
+ __pypackages__/
89
+
90
+ # Celery stuff
91
+ celerybeat-schedule
92
+ celerybeat.pid
93
+
94
+ # SageMath parsed files
95
+ *.sage.py
96
+
97
+ # Environments
98
+ .env
99
+ .venv
100
+ env/
101
+ venv/
102
+ ENV/
103
+ env.bak/
104
+ venv.bak/
105
+
106
+ # Spyder project settings
107
+ .spyderproject
108
+ .spyproject
109
+
110
+ # Rope project settings
111
+ .ropeproject
112
+
113
+ # mkdocs documentation
114
+ /site
115
+
116
+ # mypy
117
+ .mypy_cache/
118
+ .dmypy.json
119
+ dmypy.json
120
+
121
+ # Pyre type checker
122
+ .pyre/
123
+
124
+ # pytype static type analyzer
125
+ .pytype/
126
+
127
+ # Cython debug symbols
128
+ cython_debug/
129
+
130
+ notebooks/
131
+ old/
132
+ .vscode/
133
+ *.ckpt
134
+
135
+ *.jpg
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ albumentations==1.1.0
2
+ opencv-python-headless==4.5.5.64
3
+ optuna==3.0.4
4
+ pandas==1.3.5
5
+ pytorch-lightning<=1.5.10
6
+ timm==0.5.4
7
+ wandb==0.12.16
8
+ opencv-python
test_model.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+
4
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
5
+
6
+ from transformers import AutoModelForImageClassification
7
+
8
+ if __name__ == "__main__":
9
+ cetacean_classifier = AutoModelForImageClassification.from_pretrained("Saving-Willy/cetacean-classifier", trust_remote_code=True)
10
+ img = cv2.imread("tail.jpg")
11
+ out = cetacean_classifier(img)
12
+ print(out)
test_model_before_push.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script used to load a custom model and push it to HuggingFace.
3
+ Doc: https://huggingface.co/docs/transformers/custom_models#writing-a-custom-model
4
+ """
5
+ import os
6
+ import json
7
+ import cv2
8
+
9
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
10
+
11
+ from configuration_cetacean_classifier import (
12
+ CetaceanClassifierConfig,
13
+ )
14
+ from modeling_cetacean_classifier import (
15
+ CetaceanClassifierModelForImageClassification,
16
+ )
17
+
18
+ with open("original_model_config.json", "r") as file:
19
+ config= json.load(file)
20
+
21
+ cetacean_config = CetaceanClassifierConfig(**config)
22
+ cetacean_classifier = CetaceanClassifierModelForImageClassification(cetacean_config)
23
+
24
+ cetacean_classifier.model.load_from_checkpoint("last.ckpt")
25
+
26
+ img = cv2.imread("tail.jpg")
27
+
28
+ out = cetacean_classifier(img)
29
+ print(out)
30
+
train.py CHANGED
@@ -5,10 +5,10 @@ import timm
5
  import torch
6
  from pytorch_lightning import LightningDataModule, LightningModule, Trainer
7
 
8
- from .config import Config, load_config
9
  # from .dataset import WhaleDataset, load_df
10
- from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM
11
- from .utils import WarmupCosineLambda, map_dict, topk_average_precision
12
 
13
 
14
  class SphereClassifier(LightningModule):
 
5
  import torch
6
  from pytorch_lightning import LightningDataModule, LightningModule, Trainer
7
 
8
+ from config import Config, load_config
9
  # from .dataset import WhaleDataset, load_df
10
+ from metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM
11
+ from utils import WarmupCosineLambda, map_dict, topk_average_precision
12
 
13
 
14
  class SphereClassifier(LightningModule):