julien-c HF staff commited on
Commit
9c99c64
1 Parent(s): 66662af

Tweaks! working

Browse files
Files changed (2) hide show
  1. app.py +4 -2
  2. convert.py +11 -13
app.py CHANGED
@@ -11,13 +11,15 @@ def run(token: str, model_id: str) -> str:
11
  Please fill a token and model_id.
12
  """
13
  try:
14
- pr_url = convert(token=token, model_id=model_id)
 
 
15
  return f"""
16
  ### Success 🔥
17
 
18
  Yay! This model was successfully converted and a PR was open using your token, here:
19
 
20
- {pr_url}
21
  """
22
  except Exception as e:
23
  return f"""
 
11
  Please fill a token and model_id.
12
  """
13
  try:
14
+ # TODO(Run this in a separate directory otherwise max_concurrency = 1...)
15
+ # as filename of conversion is fixed.
16
+ commit_info = convert(token=token, model_id=model_id)
17
  return f"""
18
  ### Success 🔥
19
 
20
  Yay! This model was successfully converted and a PR was open using your token, here:
21
 
22
+ [{commit_info.pr_url}]({commit_info.pr_url})
23
  """
24
  except Exception as e:
25
  return f"""
convert.py CHANGED
@@ -4,7 +4,7 @@ import os
4
 
5
  import torch
6
 
7
- from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download
8
  from safetensors.torch import save_file
9
 
10
 
@@ -14,7 +14,7 @@ def rename(pt_filename) -> str:
14
  return local
15
 
16
 
17
- def convert_multi(model_id) -> str:
18
  local_filenames = []
19
  try:
20
  filename = hf_hub_download(
@@ -39,7 +39,6 @@ def convert_multi(model_id) -> str:
39
  json.dump(newdata, f)
40
  local_filenames.append(index)
41
 
42
- api = HfApi()
43
  operations = [
44
  CommitOperationAdd(path_in_repo=local, path_or_fileobj=local)
45
  for local in local_filenames
@@ -55,26 +54,25 @@ def convert_multi(model_id) -> str:
55
  os.remove(local)
56
 
57
 
58
- def convert_single(model_id) -> str:
59
  local = "model.safetensors"
60
  try:
61
  filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
62
  loaded = torch.load(filename)
63
  save_file(loaded, local, metadata={"format": "pt"})
64
 
65
- api = HfApi()
66
-
67
- return api.upload_file(
68
- path_or_fileobj=local,
69
- create_pr=True,
70
- path_in_repo=local,
71
  repo_id=model_id,
 
 
 
72
  )
73
  finally:
74
  os.remove(local)
75
 
76
 
77
- def convert(token: str, model_id: str) -> str:
78
  """
79
  returns url to the PR
80
  """
@@ -82,9 +80,9 @@ def convert(token: str, model_id: str) -> str:
82
  info = api.model_info(model_id)
83
  filenames = set(s.rfilename for s in info.siblings)
84
  if "pytorch_model.bin" in filenames:
85
- return convert_single(model_id)
86
  elif "pytorch_model.bin.index.json" in filenames:
87
- return convert_multi(model_id)
88
  raise ValueError("repo does not seem to have a pytorch_model in it")
89
 
90
 
 
4
 
5
  import torch
6
 
7
+ from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, CommitInfo
8
  from safetensors.torch import save_file
9
 
10
 
 
14
  return local
15
 
16
 
17
+ def convert_multi(api: HfApi, model_id) -> CommitInfo:
18
  local_filenames = []
19
  try:
20
  filename = hf_hub_download(
 
39
  json.dump(newdata, f)
40
  local_filenames.append(index)
41
 
 
42
  operations = [
43
  CommitOperationAdd(path_in_repo=local, path_or_fileobj=local)
44
  for local in local_filenames
 
54
  os.remove(local)
55
 
56
 
57
+ def convert_single(api: HfApi, model_id) -> CommitInfo:
58
  local = "model.safetensors"
59
  try:
60
  filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
61
  loaded = torch.load(filename)
62
  save_file(loaded, local, metadata={"format": "pt"})
63
 
64
+ operations = [CommitOperationAdd(path_in_repo=local, path_or_fileobj=local)]
65
+ return api.create_commit(
 
 
 
 
66
  repo_id=model_id,
67
+ operations=operations,
68
+ commit_message="Adding `safetensors` variant of this model",
69
+ create_pr=True,
70
  )
71
  finally:
72
  os.remove(local)
73
 
74
 
75
+ def convert(token: str, model_id: str) -> CommitInfo:
76
  """
77
  returns url to the PR
78
  """
 
80
  info = api.model_info(model_id)
81
  filenames = set(s.rfilename for s in info.siblings)
82
  if "pytorch_model.bin" in filenames:
83
+ return convert_single(api, model_id)
84
  elif "pytorch_model.bin.index.json" in filenames:
85
+ return convert_multi(api, model_id)
86
  raise ValueError("repo does not seem to have a pytorch_model in it")
87
 
88