hysts (HF staff) committed
Commit f57e36f
1 Parent(s): 9bdd97c
Files changed (1):
  1. app.py +33 -38
app.py CHANGED
@@ -8,6 +8,7 @@ import os
 import pathlib
 import subprocess
 import sys
+import tarfile
 
 # workaround for https://github.com/gradio-app/gradio/issues/483
 command = 'pip install -U gradio==2.7.0'
@@ -26,8 +27,8 @@ from _util.twodee_v0 import I as ImageWrapper
 TOKEN = os.environ['TOKEN']
 
 MODEL_REPO = 'hysts/bizarre-pose-estimator-models'
-MODEL_PATH = 'tagger.pth'
-LABEL_PATH = 'tags.txt'
+MODEL_FILENAME = 'tagger.pth'
+LABEL_FILENAME = 'tags.txt'
 
 
 def parse_args() -> argparse.Namespace:
@@ -47,21 +48,38 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
-def download_sample_images() -> list[pathlib.Path]:
-    image_dir = pathlib.Path('samples')
-    image_dir.mkdir(exist_ok=True)
-
-    dataset_repo = 'hysts/sample-images-TADNE'
-    n_images = 36
-    paths = []
-    for index in range(n_images):
+def load_sample_image_paths() -> list[pathlib.Path]:
+    image_dir = pathlib.Path('images')
+    if not image_dir.exists():
+        dataset_repo = 'hysts/sample-images-TADNE'
         path = huggingface_hub.hf_hub_download(dataset_repo,
-                                               f'{index:02d}.jpg',
+                                               'images.tar.gz',
                                                repo_type='dataset',
-                                               cache_dir=image_dir.as_posix(),
                                                use_auth_token=TOKEN)
-        paths.append(pathlib.Path(path))
-    return paths
+        with tarfile.open(path) as f:
+            f.extractall()
+    return sorted(image_dir.glob('*'))
+
+
+def load_model(device: torch.device) -> torch.nn.Module:
+    path = huggingface_hub.hf_hub_download(MODEL_REPO,
+                                           MODEL_FILENAME,
+                                           use_auth_token=TOKEN)
+    state_dict = torch.load(path)
+    model = torchvision.models.resnet50(num_classes=1062)
+    model.load_state_dict(state_dict)
+    model.to(device)
+    model.eval()
+    return model
+
+
+def load_labels() -> list[str]:
+    label_path = huggingface_hub.hf_hub_download(MODEL_REPO,
+                                                 LABEL_FILENAME,
+                                                 use_auth_token=TOKEN)
+    with open(label_path) as f:
+        labels = [line.strip() for line in f.readlines()]
+    return labels
 
 
 @torch.inference_mode()
@@ -84,36 +102,13 @@ def predict(image: PIL.Image.Image, score_threshold: float,
     return res
 
 
-def load_model(device: torch.device) -> torch.nn.Module:
-    model_path = huggingface_hub.hf_hub_download(MODEL_REPO,
-                                                 MODEL_PATH,
-                                                 use_auth_token=TOKEN)
-    state_dict = torch.load(model_path)
-
-    model = torchvision.models.resnet50(num_classes=1062)
-    model.load_state_dict(state_dict)
-    model.to(device)
-    model.eval()
-
-    return model
-
-
-def load_labels() -> list[str]:
-    label_path = huggingface_hub.hf_hub_download(MODEL_REPO,
-                                                 LABEL_PATH,
-                                                 use_auth_token=TOKEN)
-    with open(label_path) as f:
-        labels = [line.strip() for line in f.readlines()]
-    return labels
-
-
 def main():
     gr.close_all()
 
     args = parse_args()
     device = torch.device(args.device)
 
-    image_paths = download_sample_images()
+    image_paths = load_sample_image_paths()
     examples = [[path.as_posix(), args.score_threshold]
                 for path in image_paths]
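For reference, below is a minimal smoke-test sketch for the helpers this commit introduces; it is not part of the commit. It assumes app.py from this Space is importable from the working directory, that main() is behind the usual __main__ guard, and that a valid TOKEN environment variable is exported, since app.py reads os.environ['TOKEN'] at import time and the model and dataset downloads need it (importing app.py also runs its module-level gradio-pin workaround). predict() is not exercised because its full signature is truncated in the hunk header above.

# Minimal sketch, not from the commit: exercise the refactored helpers locally.
# Assumes a Hugging Face token with access to the repos is exported as TOKEN.
import torch

from app import load_labels, load_model, load_sample_image_paths

# First call downloads images.tar.gz from hysts/sample-images-TADNE and
# extracts it; later calls only glob the existing 'images' directory.
paths = load_sample_image_paths()
print(f'found {len(paths)} sample images')

device = torch.device('cpu')
model = load_model(device)   # torchvision ResNet-50 with a 1062-way head
labels = load_labels()       # one tag name per classifier output

# The tag list is expected to line up with the classifier head (1062 classes).
assert len(labels) == model.fc.out_features

The design choice behind the diff: replacing 36 individual hf_hub_download calls with a single images.tar.gz download, guarded by if not image_dir.exists(), turns sample-image fetching into one round trip that is paid only on the first startup.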