Spaces:
Sleeping
Sleeping
File size: 4,403 Bytes
2fc8bbc 7de5c67 2fc8bbc e7e6ea6 6cda51e 2fc8bbc 6cda51e 2fc8bbc e7e6ea6 2fc8bbc e7e6ea6 cf500c1 2fc8bbc e7e6ea6 2fc8bbc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import os
import shutil
from pathlib import Path
from typing import Iterable, List
import gradio as gr
import kagglehub
from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
from huggingface_hub import HfApi
# --- Kaggle credentials bootstrap -------------------------------------------
# Content of a `kaggle.json` credentials file, supplied through the
# `KAGGLE_JSON` environment variable (None when the variable is not set).
KAGGLE_JSON = os.environ.get("KAGGLE_JSON")
# Absolute location of the credentials file inside the current user's home.
KAGGLE_JSON_PATH = Path("~/.kaggle/kaggle.json").expanduser().resolve()

if KAGGLE_JSON_PATH.exists():
    # A file already on disk always wins: it is never overwritten by the secret.
    print(f"Found existing kaggle.json file at {KAGGLE_JSON_PATH}")
elif KAGGLE_JSON is not None:
    print(
        "KAGGLE_JSON is set as secret. Will be able to be authenticated when downloading files from Kaggle."
    )
    KAGGLE_JSON_PATH.parent.mkdir(parents=True, exist_ok=True)
    # The file holds an API key. Create it owner-only *before* the secret is
    # written so it is never readable by other users, not even briefly.
    # `touch(mode=...)` is filtered by the umask, hence the explicit chmod.
    KAGGLE_JSON_PATH.touch(mode=0o600)
    KAGGLE_JSON_PATH.chmod(0o600)
    KAGGLE_JSON_PATH.write_text(KAGGLE_JSON, encoding="utf-8")
else:
    print(
        f"No kaggle.json file found at {KAGGLE_JSON_PATH}. You will not be able to download private/gated files from Kaggle."
    )
# Markdown shown at the top of the app. The trailing note is only included
# when a kaggle.json file is present on disk at import time.
MARKDOWN_DESCRIPTION = """
# Keggla-importer GUI
The fastest way to import a model from KaggleHub to the Hugging Face Hub 🔥
Specify a Kaggle handle and a Hugging Face Write Token to import a model from KaggleHub to the Hugging Face Hub.
To find the Kaggle handle from a web UI, click on the "download dropdown" and copy the handle from the code snippet.
Example: `"keras/gemma/keras/gemma_instruct_2b_en"`.
""" + (
    """
**Note**: a `kaggle.json` file exists in the home directory. This means the Space will be able to download **SOME** private/gated files from Kaggle.
To access other models, please duplicate this Space to a private Space and set the `KAGGLE_JSON` environment variable with the content of the `kaggle.json`
you've downloaded from your Kaggle user account.
"""
    if KAGGLE_JSON_PATH.exists()
    else ""
)
def _clean_kaggle_handle(raw: str | None) -> str:
    """Return *raw* without surrounding whitespace, quotes or slashes.

    The UI example shows the handle wrapped in double quotes, so a value
    pasted from a code snippet may still carry them.
    """
    return (raw or "").strip(" \t\r\n\"'/")


def _default_repo_name(kaggle_handle: str) -> str:
    """Derive a repo name from a Kaggle handle.

    Handles look like ``owner/model/framework/variation`` with an optional
    trailing ``/version`` number. The variation is used as the name, so a
    numeric version segment is skipped rather than becoming the repo name.
    """
    segments = [part for part in kaggle_handle.split("/") if part]
    if len(segments) == 5 and segments[-1].isdigit():
        segments.pop()
    return segments[-1] if segments else ""


def import_model(
    kaggle_model: str, repo_name: str, token: gr.OAuthToken | None
) -> Iterable[List[Log]]:
    """Copy a model from KaggleHub to a Hugging Face model repo, streaming logs.

    Steps: validate the inputs, create (or reuse) the target repo, refuse to
    continue if that repo already holds more than one file, download the
    model with kagglehub, upload the downloaded folder, then remove the
    local copy.

    Args:
        kaggle_model: Kaggle model handle, e.g.
            ``keras/gemma/keras/gemma_instruct_2b_en``. Surrounding
            whitespace, quotes and slashes are ignored.
        repo_name: Target repo name. When empty, it is inferred from the
            handle (its last segment, ignoring a trailing version number).
        token: OAuth token of the signed-in user, or None when the user is
            not signed in.

    Yields:
        Whatever ``LogsViewRunner.log`` / ``run_python`` produce, to be
        rendered by the ``LogsView`` output component.

    Raises:
        gr.Error: If the Kaggle handle is missing or the user is not signed in.
    """
    runner = LogsViewRunner()

    kaggle_model = _clean_kaggle_handle(kaggle_model)
    if not kaggle_model:
        yield runner.log("Kaggle model is required.", level="ERROR")
        raise gr.Error("Kaggle model is required.")

    repo_name = (repo_name or "").strip() or _default_repo_name(kaggle_model)

    if not token:
        yield runner.log("You must sign in with HF before proceeding.", level="ERROR")
        raise gr.Error("Authentication is required.")

    api = HfApi(token=token.token)

    yield runner.log(f"Creating HF repo {repo_name}")
    repo_url = api.create_repo(repo_name, exist_ok=True)
    yield runner.log(f"Created HF repo: {repo_url}")

    repo_id = repo_url.repo_id
    model_info = api.model_info(repo_id)
    # `siblings` is optional on the returned object; treat "missing" as empty.
    # One file is tolerated (presumably an auto-generated file such as
    # .gitattributes — TODO confirm); anything more means the repo is in use.
    if len(model_info.siblings or []) > 1:
        yield runner.log(
            f"Model repo {repo_id} is not empty. Please delete it or set a different repo name.",
            level="ERROR",
        )
        return

    yield runner.log(f"Downloading model {kaggle_model} from Kaggle.")
    # Run through the runner so the download's output reaches the log view.
    yield from runner.run_python(kagglehub.model_download, handle=kaggle_model)
    if runner.exit_code != 0:
        yield runner.log("Failed to download model from Kaggle.", level="ERROR")
        # NOTE(review): `create_repo(exist_ok=True)` does not tell whether the
        # repo pre-existed, so this may delete a (near-empty) repo the user
        # already had — confirm this is acceptable.
        api.delete_repo(repo_id=repo_id)
        return

    cache_path = kagglehub.model_download(kaggle_model)  # should be instant
    yield runner.log(f"Model successfully downloaded from Kaggle to {cache_path}.")

    yield runner.log(f"Uploading model to HF repo {repo_id}.")
    yield from runner.run_python(
        api.upload_folder, repo_id=repo_id, folder_path=cache_path
    )
    upload_failed = runner.exit_code != 0
    if upload_failed:
        yield runner.log("Failed to upload model to HF repo.", level="ERROR")
    else:
        yield runner.log(f"Model successfully uploaded to HF: {repo_url}.")

    # Remove the downloaded files whether or not the upload succeeded, so a
    # failed import does not leave the model on disk.
    yield runner.log(f"Deleting local cache from {cache_path}.")
    shutil.rmtree(cache_path)

    if upload_failed:
        api.delete_repo(repo_id=repo_id)
        return

    yield runner.log("Done!")
# UI definition: description, input row, import button and a log panel.
# NOTE(review): the source indentation was not available; the two textboxes
# and the login button are assumed to share one row — confirm the layout.
with gr.Blocks() as demo:
    gr.Markdown(value=MARKDOWN_DESCRIPTION)

    with gr.Row():
        kaggle_model = gr.Textbox(
            label="Kaggle Model*",
            placeholder="keras/codegemma/keras/code_gemma_7b_en",
            lines=1,
        )
        repo_name = gr.Textbox(
            label="Repo name",
            placeholder="Optional. Will infer from Kaggle Model if empty.",
            lines=1,
        )
        gr.LoginButton(min_width=250)

    button = gr.Button(value="Import", variant="primary")
    logs = LogsView(label="Terminal output")

    # The OAuth token is not listed as an input; `import_model` declares it
    # as a typed parameter instead.
    button.click(import_model, [kaggle_model, repo_name], [logs])

# Enable the queue with a default concurrency limit of 1, then start the app.
demo.queue(default_concurrency_limit=1)
demo.launch()
|