kaggle_importer / app.py
Wauplin's picture
Wauplin HF staff
more accesss
cf500c1 verified
raw
history blame
No virus
4.4 kB
import os
import shutil
from pathlib import Path
from typing import Iterable, List
import gradio as gr
import kagglehub
from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
from huggingface_hub import HfApi
# Raw contents of a kaggle.json credentials file, when supplied through the
# KAGGLE_JSON environment variable (None if the variable is not set).
KAGGLE_JSON = os.environ.get("KAGGLE_JSON")
# Absolute location of the current user's ~/.kaggle/kaggle.json file.
KAGGLE_JSON_PATH = Path("~/.kaggle/kaggle.json").expanduser().resolve()

if KAGGLE_JSON_PATH.exists():
    # An existing credentials file takes precedence and is never overwritten.
    print(f"Found existing kaggle.json file at {KAGGLE_JSON_PATH}")
elif KAGGLE_JSON is not None:
    print(
        "KAGGLE_JSON is set as secret. Will be able to be authenticated when downloading files from Kaggle."
    )
    KAGGLE_JSON_PATH.parent.mkdir(parents=True, exist_ok=True)
    # The file holds an API key: create it readable/writable by the owner only
    # *before* the secret is written, instead of relying on the default mode.
    KAGGLE_JSON_PATH.touch(mode=0o600)
    KAGGLE_JSON_PATH.write_text(KAGGLE_JSON, encoding="utf-8")
else:
    print(
        f"No kaggle.json file found at {KAGGLE_JSON_PATH}. You will not be able to download private/gated files from Kaggle."
    )
# Markdown rendered at the top of the UI. The text lines stay at column 0 so
# that Markdown does not treat them as an indented code block.
MARKDOWN_DESCRIPTION = """
# Kaggle-importer GUI
The fastest way to import a model from KaggleHub to the Hugging Face Hub 🔥
Specify a Kaggle handle and a Hugging Face Write Token to import a model from KaggleHub to the Hugging Face Hub.
To find the Kaggle handle from a web UI, click on the "download dropdown" and copy the handle from the code snippet.
Example: `"keras/gemma/keras/gemma_instruct_2b_en"`.
"""

# When a credentials file is present, tell users what this Space can access
# and how to run a private copy with their own credentials.
if KAGGLE_JSON_PATH.exists():
    MARKDOWN_DESCRIPTION += """
**Note**: a `kaggle.json` file exists in the home directory. This means the Space will be able to download **SOME** private/gated files from Kaggle.
To access other models, please duplicate this Space to a private Space and set the `KAGGLE_JSON` environment variable with the content of the `kaggle.json`
you've downloaded from your Kaggle user account.
"""
def _default_repo_name(kaggle_model: str) -> str:
    """Derive a Hugging Face repo name from a Kaggle model handle.

    The last path segment is used, except when the handle carries a trailing
    numeric version (``owner/model/framework/variation/<version>``): in that
    case the variation segment is used so the repo is not named e.g. ``"3"``.
    Returns an empty string if the handle has no non-empty segment.
    """
    segments = [part for part in kaggle_model.split("/") if part]
    if not segments:
        return ""
    if len(segments) > 4 and segments[-1].isdigit():
        return segments[-2]
    return segments[-1]


def import_model(
    kaggle_model: str, repo_name: str, token: gr.OAuthToken | None
) -> Iterable[List[Log]]:
    """Copy a model from Kaggle to a Hugging Face model repo, streaming logs.

    Steps: validate the inputs, create the target repo, refuse to continue if
    the repo already holds more than one file, download the model with
    ``kagglehub``, upload the downloaded folder, then delete the local copy.

    Args:
        kaggle_model: Kaggle model handle. Surrounding whitespace, quotes and
            slashes are stripped, so a handle pasted from a code snippet works.
        repo_name: Target repo name. When empty, it is inferred from the
            Kaggle handle.
        token: OAuth token of the signed-in Hugging Face user, or ``None`` when
            nobody is signed in.

    Yields:
        The value returned by ``runner.log`` / produced by
        ``runner.run_python`` after each step, for display in the logs view.

    Raises:
        gr.Error: If no Kaggle model is given or if the user is not signed in.
    """
    runner = LogsViewRunner()

    # Normalise user input: handles are often pasted with the quotes of the
    # code snippet, stray spaces or a trailing slash.
    kaggle_model = (kaggle_model or "").strip().strip("\"'").strip().strip("/")
    repo_name = (repo_name or "").strip()

    if not kaggle_model:
        yield runner.log("Kaggle model is required.", level="ERROR")
        raise gr.Error("Kaggle model is required.")
    if not repo_name:
        repo_name = _default_repo_name(kaggle_model)
    if not token:
        yield runner.log("You must sign in with HF before proceeding.", level="ERROR")
        raise gr.Error("Authentication is required.")

    api = HfApi(token=token.token)

    yield runner.log(f"Creating HF repo {repo_name}")
    repo_url = api.create_repo(repo_name, exist_ok=True)
    yield runner.log(f"Created HF repo: {repo_url}")
    repo_id = repo_url.repo_id

    # Refuse to import into a repo that already has content. One file is
    # tolerated (presumably the auto-generated .gitattributes -- confirm).
    # `siblings` is Optional in huggingface_hub, hence the `or []`.
    model_info = api.model_info(repo_id)
    if len(model_info.siblings or []) > 1:
        yield runner.log(
            f"Model repo {repo_id} is not empty. Please delete it or set a different repo name.",
            level="ERROR",
        )
        return

    yield runner.log(f"Downloading model {kaggle_model} from Kaggle.")
    yield from runner.run_python(kagglehub.model_download, handle=kaggle_model)
    if runner.exit_code != 0:
        yield runner.log("Failed to download model from Kaggle.", level="ERROR")
        api.delete_repo(repo_id=repo_id)
        return

    # Second call only retrieves the local path of the model downloaded above.
    cache_path = kagglehub.model_download(kaggle_model)  # should be instant
    yield runner.log(f"Model successfully downloaded from Kaggle to {cache_path}.")

    yield runner.log(f"Uploading model to HF repo {repo_id}.")
    yield from runner.run_python(
        api.upload_folder, repo_id=repo_id, folder_path=cache_path
    )
    if runner.exit_code != 0:
        yield runner.log("Failed to upload model to HF repo.", level="ERROR")
        api.delete_repo(repo_id=repo_id)
        return
    yield runner.log(f"Model successfully uploaded to HF: {repo_url}.")

    yield runner.log(f"Deleting local cache from {cache_path}.")
    shutil.rmtree(cache_path)
    yield runner.log("Done!")
# Build the UI: description, input row, import button and streamed log output.
# NOTE(review): the nesting below is reconstructed (the source lost its
# indentation) -- confirm that the login button belongs inside the row.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN_DESCRIPTION)

    with gr.Row():
        kaggle_model = gr.Textbox(
            label="Kaggle Model*",
            placeholder="keras/codegemma/keras/code_gemma_7b_en",
            lines=1,
        )
        repo_name = gr.Textbox(
            label="Repo name",
            placeholder="Optional. Will infer from Kaggle Model if empty.",
            lines=1,
        )
        gr.LoginButton(min_width=250)

    button = gr.Button("Import", variant="primary")
    logs = LogsView(label="Terminal output")

    # Only the two textboxes are wired as inputs; `import_model`'s third
    # parameter (`token`) is not listed here.
    button.click(
        fn=import_model,
        inputs=[kaggle_model, repo_name],
        outputs=[logs],
    )

# Enable the queue with a default concurrency limit of 1, then start the app.
demo.queue(default_concurrency_limit=1)
demo.launch()