reach-vb HF staff commited on
Commit
ac97e5b
1 Parent(s): aa85862
Files changed (1) hide show
  1. app.py +56 -3
app.py CHANGED
@@ -11,6 +11,7 @@ from huggingface_hub import whoami
11
  from huggingface_hub import ModelCard
12
  from huggingface_hub import login
13
  from huggingface_hub import scan_cache_dir
 
14
 
15
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
16
 
@@ -18,7 +19,9 @@ from apscheduler.schedulers.background import BackgroundScheduler
18
 
19
  from textwrap import dedent
20
 
21
- from mlx_lm import convert
 
 
22
 
23
  HF_TOKEN = os.environ.get("HF_TOKEN")
24
 
@@ -33,6 +36,55 @@ def clear_cache():
33
 
34
  print("Cache has been cleared")
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def process_model(model_id, q_method,oauth_token: gr.OAuthToken | None):
37
 
38
  if oauth_token.token is None:
@@ -41,11 +93,12 @@ def process_model(model_id, q_method,oauth_token: gr.OAuthToken | None):
41
  model_name = model_id.split('/')[-1]
42
  username = whoami(oauth_token.token)["name"]
43
 
44
- login(token=oauth_token.token, add_to_git_credential=True)
45
 
46
  try:
47
  upload_repo = username + "/" + model_name + "-mlx"
48
- convert(model_id, quantize=True, upload_repo=upload_repo)
 
49
  clear_cache()
50
  return (
51
  f'Find your repo <a href=\'{new_repo_url}\' target="_blank" style="text-decoration:underline">here</a>',
 
11
  from huggingface_hub import ModelCard
12
  from huggingface_hub import login
13
  from huggingface_hub import scan_cache_dir
14
+ from huggingface_hub import logging
15
 
16
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
17
 
 
19
 
20
  from textwrap import dedent
21
 
22
+ from mlx_lm import convert, __version__
23
+
24
+ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
25
 
26
  HF_TOKEN = os.environ.get("HF_TOKEN")
27
 
 
36
 
37
  print("Cache has been cleared")
38
 
39
+ def upload_to_hub(path, upload_repo, hf_path, token):
40
+
41
+ card = ModelCard.load(hf_path)
42
+ card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
43
+ card.data.base_model = hf_path
44
+ card.text = dedent(
45
+ f"""
46
+ # {upload_repo}
47
+
48
+ The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
49
+
50
+ ## Use with mlx
51
+
52
+ ```bash
53
+ pip install mlx-lm
54
+ ```
55
+
56
+ ```python
57
+ from mlx_lm import load, generate
58
+
59
+ model, tokenizer = load("{upload_repo}")
60
+
61
+ prompt="hello"
62
+
63
+ if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
64
+ messages = [{{"role": "user", "content": prompt}}]
65
+ prompt = tokenizer.apply_chat_template(
66
+ messages, tokenize=False, add_generation_prompt=True
67
+ )
68
+
69
+ response = generate(model, tokenizer, prompt=prompt, verbose=True)
70
+ ```
71
+ """
72
+ )
73
+ card.save(os.path.join(path, "README.md"))
74
+
75
+ logging.set_verbosity_info()
76
+
77
+ api = HfApi()
78
+ api.create_repo(repo_id=upload_repo, exist_ok=True)
79
+ api.upload_folder(
80
+ folder_path=path,
81
+ repo_id=upload_repo,
82
+ repo_type="model",
83
+ multi_commits=True,
84
+ multi_commits_verbose=True,
85
+ )
86
+ print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
87
+
88
  def process_model(model_id, q_method,oauth_token: gr.OAuthToken | None):
89
 
90
  if oauth_token.token is None:
 
93
  model_name = model_id.split('/')[-1]
94
  username = whoami(oauth_token.token)["name"]
95
 
96
+ # login(token=oauth_token.token, add_to_git_credential=True)
97
 
98
  try:
99
  upload_repo = username + "/" + model_name + "-mlx"
100
+ convert(model_id, quantize=True)
101
+ upload_repo(path="mlx_model", upload_repo=upload_repo, hf_path=repo_id, token=oauth_token.token)
102
  clear_cache()
103
  return (
104
  f'Find your repo <a href=\'{new_repo_url}\' target="_blank" style="text-decoration:underline">here</a>',