File size: 4,403 Bytes
2fc8bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7de5c67
2fc8bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7e6ea6
 
 
6cda51e
 
2fc8bbc
6cda51e
 
2fc8bbc
 
 
e7e6ea6
 
 
2fc8bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7e6ea6
 
 
 
 
 
 
 
 
 
cf500c1
2fc8bbc
 
 
 
e7e6ea6
2fc8bbc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import shutil
from pathlib import Path
from typing import Iterable, List

import gradio as gr
import kagglehub
from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
from huggingface_hub import HfApi

# Raw content of a kaggle.json credentials file, optionally supplied through
# the `KAGGLE_JSON` environment variable (None when the variable is unset).
KAGGLE_JSON = os.environ.get("KAGGLE_JSON")
# Absolute path of the credentials file inside the current user's home.
KAGGLE_JSON_PATH = Path("~/.kaggle/kaggle.json").expanduser().resolve()
if KAGGLE_JSON_PATH.exists():
    # An existing file always wins over the environment variable.
    print(f"Found existing kaggle.json file at {KAGGLE_JSON_PATH}")
elif KAGGLE_JSON is not None:
    print(
        "KAGGLE_JSON is set as secret. Will be able to be authenticated when downloading files from Kaggle."
    )
    KAGGLE_JSON_PATH.parent.mkdir(parents=True, exist_ok=True)
    # The file holds an API key: create it readable/writable by the owner
    # only *before* any secret content is written into it.
    KAGGLE_JSON_PATH.touch(mode=0o600)
    KAGGLE_JSON_PATH.write_text(KAGGLE_JSON, encoding="utf-8")
else:
    print(
        f"No kaggle.json file found at {KAGGLE_JSON_PATH}. You will not be able to download private/gated files from Kaggle."
    )


# Markdown rendered at the top of the UI.
MARKDOWN_DESCRIPTION = """
# Keggla-importer GUI

The fastest way to import a model from KaggleHub to the Hugging Face Hub 🔥

Specify a Kaggle handle and a Hugging Face Write Token to import a model from KaggleHub to the Hugging Face Hub.

To find the Kaggle handle from a web UI, click on the "download dropdown" and copy the handle from the code snippet.
Example: `"keras/gemma/keras/gemma_instruct_2b_en"`.
"""

# Extra paragraph shown only when a kaggle.json file is present on disk.
_KAGGLE_JSON_NOTE = """

**Note**: a `kaggle.json` file exists in the home directory. This means the Space will be able to download **SOME** private/gated files from Kaggle.
To access other models, please duplicate this Space to a private Space and set the `KAGGLE_JSON` environment variable with the content of the `kaggle.json`
you've downloaded from your Kaggle user account.

"""

if KAGGLE_JSON_PATH.exists():
    MARKDOWN_DESCRIPTION = MARKDOWN_DESCRIPTION + _KAGGLE_JSON_NOTE


def import_model(
    kaggle_model: str, repo_name: str, token: gr.OAuthToken | None
) -> Iterable[List[Log]]:
    """Copy a model from KaggleHub to a Hugging Face repo, streaming logs.

    The function is a generator: every step yields the log entries produced
    by a ``LogsViewRunner`` so the UI can display progress as it happens.

    Args:
        kaggle_model: Kaggle model handle, e.g.
            ``keras/gemma/keras/gemma_instruct_2b_en``. Required.
        repo_name: Name of the target HF repo. When empty, the last
            ``/``-separated segment of ``kaggle_model`` is used.
        token: OAuth token whose ``.token`` is passed to ``HfApi``; falsy
            when the user is not signed in.

    Yields:
        The log lists returned by ``runner.log`` / ``runner.run_python``.

    Raises:
        gr.Error: If ``kaggle_model`` is empty or ``token`` is missing.
    """
    runner = LogsViewRunner()

    # Values come from free-text boxes; drop stray whitespace from copy/paste.
    kaggle_model = (kaggle_model or "").strip()
    repo_name = (repo_name or "").strip()

    if not kaggle_model:
        yield runner.log("Kaggle model is required.", level="ERROR")
        raise gr.Error("Kaggle model is required.")
    if not repo_name:
        repo_name = kaggle_model.split("/")[-1]
    if not token:
        yield runner.log("You must sign in with HF before proceeding.", level="ERROR")
        raise gr.Error("Authentication is required.")
    api = HfApi(token=token.token)

    yield runner.log(f"Creating HF repo {repo_name}")
    repo_url = api.create_repo(repo_name, exist_ok=True)
    yield runner.log(f"Created HF repo: {repo_url}")
    repo_id = repo_url.repo_id

    # `exist_ok=True` means the repo may predate this run. Only the
    # auto-generated `.gitattributes` is tolerated: any other file means the
    # repo holds user content that must not be overwritten or deleted below.
    model_info = api.model_info(repo_id)
    existing_files = [
        sibling.rfilename
        for sibling in (model_info.siblings or [])
        if sibling.rfilename != ".gitattributes"
    ]
    if existing_files:
        yield runner.log(
            f"Model repo {repo_id} is not empty. Please delete it or set a different repo name.",
            level="ERROR",
        )
        return

    yield runner.log(f"Downloading model {kaggle_model} from Kaggle.")
    yield from runner.run_python(kagglehub.model_download, handle=kaggle_model)
    if runner.exit_code != 0:
        yield runner.log("Failed to download model from Kaggle.", level="ERROR")
        api.delete_repo(repo_id=repo_id)
        return

    # The download above ran through the runner, so call again to get the
    # local path of the downloaded files.
    cache_path = kagglehub.model_download(kaggle_model)  # should be instant
    yield runner.log(f"Model successfully downloaded from Kaggle to {cache_path}.")

    yield runner.log(f"Uploading model to HF repo {repo_id}.")
    yield from runner.run_python(
        api.upload_folder, repo_id=repo_id, folder_path=cache_path
    )
    if runner.exit_code != 0:
        yield runner.log("Failed to upload model to HF repo.", level="ERROR")
        # Free the disk space even on failure; best-effort so a cleanup
        # problem does not prevent the repo rollback that follows.
        shutil.rmtree(cache_path, ignore_errors=True)
        api.delete_repo(repo_id=repo_id)
        return

    yield runner.log(f"Model successfully uploaded to HF: {repo_url}.")
    yield runner.log(f"Deleting local cache from {cache_path}.")
    shutil.rmtree(cache_path)

    yield runner.log("Done!")


# Assemble the UI: description, a row with the two text inputs and the login
# button, the trigger button and the log viewer that `import_model` feeds.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN_DESCRIPTION)

    with gr.Row():
        kaggle_model = gr.Textbox(
            label="Kaggle Model*",
            placeholder="keras/codegemma/keras/code_gemma_7b_en",
            lines=1,
        )
        repo_name = gr.Textbox(
            label="Repo name",
            placeholder="Optional. Will infer from Kaggle Model if empty.",
            lines=1,
        )
        gr.LoginButton(min_width=250)

    button = gr.Button(value="Import", variant="primary")
    logs = LogsView(label="Terminal output")

    # Only the two textboxes are wired as inputs; `import_model`'s `token`
    # parameter is not listed here (presumably supplied by Gradio from the
    # login session based on its `gr.OAuthToken` annotation — confirm).
    button.click(
        fn=import_model,
        inputs=[kaggle_model, repo_name],
        outputs=[logs],
    )


# Enable the queue with a default concurrency limit of 1, then start the app.
queued_demo = demo.queue(default_concurrency_limit=1)
queued_demo.launch()