asoria committed
Commit e77bff1
Parent: 6785a2b

Adding basic SFT template
app.py CHANGED
@@ -2,20 +2,22 @@ import gradio as gr
 from gradio_huggingfacehub_search import HuggingfaceHubSearch
 import nbformat as nbf
 from huggingface_hub import HfApi
-from httpx import Client
 import logging
-import pandas as pd
 from utils.notebook_utils import (
     replace_wildcards,
     load_json_files_from_folder,
 )
+from utils.api_utils import get_compatible_libraries, get_first_rows, get_splits
 from dotenv import load_dotenv
 import os
 from nbconvert import HTMLExporter
 import uuid
+import pandas as pd
 
 load_dotenv()
 
+URL = "https://huggingface.co/spaces/asoria/auto-notebook-creator"
+
 HF_TOKEN = os.getenv("HF_TOKEN")
 assert HF_TOKEN is not None, "You need to set HF_TOKEN in your environment variables"
@@ -25,12 +27,6 @@ assert (
 ), "You need to set NOTEBOOKS_REPOSITORY in your environment variables"
 
 
-URL = "https://huggingface.co/spaces/asoria/auto-notebook-creator"
-BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
-HEADERS = {"Accept": "application/json", "Content-Type": "application/json"}
-
-client = Client(headers=HEADERS)
-
 logging.basicConfig(level=logging.INFO)
 
 # TODO: Validate notebook templates format
@@ -39,18 +35,6 @@ notebook_templates = load_json_files_from_folder(folder_path)
 logging.info(f"Available notebooks {notebook_templates.keys()}")
 
 
-def get_compatible_libraries(dataset: str):
-    try:
-        response = client.get(
-            f"{BASE_DATASETS_SERVER_URL}/compatible-libraries?dataset={dataset}"
-        )
-        response.raise_for_status()
-        return response.json()
-    except Exception as e:
-        logging.error(f"Error fetching compatible libraries: {e}")
-        raise
-
-
 def create_notebook_file(cells, notebook_name):
     nb = nbf.v4.new_notebook()
     nb["cells"] = [
@@ -72,22 +56,6 @@ def create_notebook_file(cells, notebook_name):
     return html_data
 
 
-def get_first_rows_as_df(dataset: str, config: str, split: str, limit: int):
-    try:
-        resp = client.get(
-            f"{BASE_DATASETS_SERVER_URL}/first-rows?dataset={dataset}&config={config}&split={split}"
-        )
-        resp.raise_for_status()
-        content = resp.json()
-        rows = content["rows"]
-        rows = [row["row"] for row in rows]
-        first_rows_df = pd.DataFrame.from_dict(rows).sample(frac=1).head(limit)
-        return first_rows_df
-    except Exception as e:
-        logging.error(f"Error fetching first rows: {e}")
-        raise
-
-
 def longest_string_column(df):
     longest_col = None
     max_length = 0
@@ -127,34 +95,62 @@ def generate_cells(dataset_id, notebook_title):
     cells = notebook_templates[notebook_title]["notebook_template"]
     notebook_type = notebook_templates[notebook_title]["notebook_type"]
     dataset_types = notebook_templates[notebook_title]["dataset_types"]
-
+    compatible_library = notebook_templates[notebook_title]["compatible_library"]
     try:
         libraries = get_compatible_libraries(dataset_id)
+        if not libraries:
+            logging.error(
+                f"Dataset not compatible with any loading library (pandas/datasets)"
+            )
+            return (
+                "",
+                "## ❌ This dataset is not compatible with pandas or datasets libraries ❌",
+            )
+
+        library_code = next(
+            (
+                lib
+                for lib in libraries.get("libraries", [])
+                if lib["library"] == compatible_library
+            ),
+            None,
+        )
+        if not library_code:
+            logging.error(f"Dataset not compatible with {compatible_library} library")
+            return (
+                "",
+                f"## ❌ This dataset is not compatible with '{compatible_library}' library ❌",
+            )
+        first_config_loading_code = library_code["loading_codes"][0]
+        first_code = first_config_loading_code["code"]
+        first_config = first_config_loading_code["config_name"]
+        first_split = get_splits(dataset_id, first_config)[0]["split"]
+        first_rows = get_first_rows(dataset_id, first_config, first_split)
     except Exception as err:
         gr.Error("Unable to retrieve dataset info from HF Hub.")
         logging.error(f"Failed to fetch compatible libraries: {err}")
-        return "", "## ❌ This dataset is not accessible from the Hub ❌"
-
-    if not libraries:
-        logging.error(f"Dataset not compatible with pandas library - not libraries")
-        return "", "## ❌ This dataset is not compatible with pandas library ❌"
-    pandas_library = next(
-        (lib for lib in libraries.get("libraries", []) if lib["library"] == "pandas"),
-        None,
-    )
-    if not pandas_library:
-        logging.error("Dataset not compatible with pandas library - not pandas library")
-        return "", "## ❌ This dataset is not compatible with pandas library ❌"
-    first_config_loading_code = pandas_library["loading_codes"][0]
-    first_code = first_config_loading_code["code"]
-    first_config = first_config_loading_code["config_name"]
-    first_split = list(first_config_loading_code["arguments"]["splits"].keys())[0]
-    df = get_first_rows_as_df(dataset_id, first_config, first_split, 3)
+        return "", f"## ❌ This dataset is not accessible from the Hub {err}❌"
+
+    df = pd.DataFrame.from_dict(first_rows).sample(frac=1).head(3)
 
     longest_col = longest_string_column(df)
     html_code = f"<iframe src='https://huggingface.co/datasets/{dataset_id}/embed/viewer' width='80%' height='560px'></iframe>"
-    wildcards = ["{dataset_name}", "{first_code}", "{html_code}", "{longest_col}"]
-    replacements = [dataset_id, first_code, html_code, longest_col]
+    wildcards = [
+        "{dataset_name}",
+        "{first_code}",
+        "{html_code}",
+        "{longest_col}",
+        "{first_config}",
+        "{first_split}",
+    ]
+    replacements = [
+        dataset_id,
+        first_code,
+        html_code,
+        longest_col,
+        first_config,
+        first_split,
+    ]
     has_numeric_columns = len(df.select_dtypes(include=["number"]).columns) > 0
     has_categoric_columns = len(df.select_dtypes(include=["object"]).columns) > 0
@@ -196,8 +192,12 @@ css = """
 
 with gr.Blocks(css=css) as demo:
     gr.Markdown("# 🤖 Dataset notebook creator 🕵️")
-    gr.Markdown(f"[![Notebooks: {len(notebook_templates)}](https://img.shields.io/badge/Notebooks-{len(notebook_templates)}-blue.svg)]({URL}/tree/main/notebooks)")
-    gr.Markdown(f"[![Contribute a Notebook](https://img.shields.io/badge/Contribute%20a%20Notebook-8A2BE2)]({URL}/blob/main/CONTRIBUTING.md)")
+    gr.Markdown(
+        f"[![Notebooks: {len(notebook_templates)}](https://img.shields.io/badge/Notebooks-{len(notebook_templates)}-blue.svg)]({URL}/tree/main/notebooks)"
+    )
+    gr.Markdown(
+        f"[![Contribute a Notebook](https://img.shields.io/badge/Contribute%20a%20Notebook-8A2BE2)]({URL}/blob/main/CONTRIBUTING.md)"
+    )
     text_input = gr.Textbox(label="Suggested notebook type", visible=False)
 
     gr.Markdown("## 1. Select and preview a dataset from Huggingface Hub")
@@ -259,6 +259,8 @@ with gr.Blocks(css=css) as demo:
         outputs=[code_component, go_to_notebook],
     )
 
-    gr.Markdown("🚧 Note: Some code may not be compatible with datasets that contain binary data or complex structures. 🚧")
+    gr.Markdown(
+        "🚧 Note: Some code may not be compatible with datasets that contain binary data or complex structures. 🚧"
+    )
 
 demo.launch()
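The refactor above moves all datasets-server HTTP calls into utils/api_utils.py, leaving generate_cells responsible only for picking the loading library a template declares and filling its wildcards. A minimal, self-contained sketch of that substitution step — the cell source and resolved values are illustrative, and the real app delegates the string replacement to replace_wildcards from utils.notebook_utils:

cells = [
    {
        "cell_type": "code",
        "source": "dataset = load_dataset('{dataset_name}', name='{first_config}', split='{first_split}')",
    },
]
replacements = {
    "{dataset_name}": "stanfordnlp/imdb",  # assumed example dataset id
    "{first_config}": "plain_text",
    "{first_split}": "train",
}

# Substitute every wildcard in every cell source
for cell in cells:
    for wildcard, value in replacements.items():
        cell["source"] = cell["source"].replace(wildcard, value)

print(cells[0]["source"])
# dataset = load_dataset('stanfordnlp/imdb', name='plain_text', split='train')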
notebooks/eda.json CHANGED
@@ -2,6 +2,7 @@
   "notebook_title": "Exploratory data analysis (EDA)",
   "notebook_type": "eda",
   "dataset_types": ["numeric", "text"],
+  "compatible_library": "pandas",
   "notebook_template": [
     {
       "cell_type": "markdown",
notebooks/embeddings.json CHANGED
@@ -2,6 +2,7 @@
   "notebook_title": "Text Embeddings",
   "notebook_type": "embeddings",
   "dataset_types": ["text"],
+  "compatible_library": "pandas",
   "notebook_template": [
     {
       "cell_type": "markdown",
notebooks/rag.json CHANGED
@@ -2,6 +2,7 @@
   "notebook_title": "Retrieval-augmented generation (RAG)",
   "notebook_type": "rag",
   "dataset_types": ["text"],
+  "compatible_library": "pandas",
   "notebook_template": [
     {
       "cell_type": "markdown",
notebooks/sft.json ADDED
@@ -0,0 +1,56 @@
+{
+  "notebook_title": "Supervised fine-tuning (SFT)",
+  "notebook_type": "sft",
+  "dataset_types": ["text"],
+  "compatible_library": "datasets",
+  "notebook_template": [
+    {
+      "cell_type": "markdown",
+      "source": "---\n# **Supervised fine-tuning Notebook for {dataset_name} dataset**\n---"
+    },
+    {
+      "cell_type": "markdown",
+      "source": "## 1. Setup necessary libraries and load the dataset"
+    },
+    {
+      "cell_type": "code",
+      "source": "# Install and import necessary libraries.\n!pip install trl datasets transformers bitsandbytes"
+    },
+    {
+      "cell_type": "code",
+      "source": "from datasets import load_dataset\nfrom trl import SFTTrainer\nfrom transformers import TrainingArguments"
+    },
+    {
+      "cell_type": "code",
+      "source": "# Load the dataset\ndataset = load_dataset('{dataset_name}', name='{first_config}', split='{first_split}')\ndataset"
+    },
+    {
+      "cell_type": "code",
+      "source": "# Specify the column name that will be used for training\ndataset_text_field = '{longest_col}'"
+    },
+    {
+      "cell_type": "markdown",
+      "source": "## 2. Configure SFT trainer"
+    },
+    {
+      "cell_type": "code",
+      "source": "model_name = 'facebook/opt-350m'\noutput_model_name = f'{model_name}-{dataset_name}'.replace('/', '-')\n\ntrainer = SFTTrainer(\n    model=model_name,\n    train_dataset=dataset,\n    dataset_text_field=dataset_text_field,\n    max_seq_length=512,\n    args=TrainingArguments(\n        per_device_train_batch_size=1,  # Batch size per GPU for training\n        gradient_accumulation_steps=4,\n        max_steps=100,  # Total number of training steps (overrides epochs)\n        learning_rate=2e-4,\n        fp16=True,\n        logging_steps=20,\n        output_dir=output_model_name,\n        optim='paged_adamw_8bit'  # Optimizer to use\n    )\n)"
+    },
+    {
+      "cell_type": "code",
+      "source": "# Start training\ntrainer.train()"
+    },
+    {
+      "cell_type": "markdown",
+      "source": "## 3. Push model to hub"
+    },
+    {
+      "cell_type": "code",
+      "source": "# Authenticate to the Hugging Face Hub\nfrom huggingface_hub import notebook_login\nnotebook_login()"
+    },
+    {
+      "cell_type": "code",
+      "source": "# Push the model to Hugging Face Hub\ntrainer.push_to_hub()"
+    }
+  ]
+}
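For a sense of the rendered output, here is roughly what the template's loading cells become once app.py substitutes the wildcards; the dataset id, config, split, and column name are assumptions for illustration, not values from this commit:

from datasets import load_dataset

# {dataset_name} -> 'stanfordnlp/imdb', {first_config} -> 'plain_text', {first_split} -> 'train'
dataset = load_dataset('stanfordnlp/imdb', name='plain_text', split='train')

# {longest_col} -> the longest string column found in the preview rows
dataset_text_field = 'text'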
utils/api_utils.py ADDED
@@ -0,0 +1,33 @@
+from httpx import Client
+
+BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
+HEADERS = {"Accept": "application/json", "Content-Type": "application/json"}
+
+client = Client(headers=HEADERS)
+
+
+def get_compatible_libraries(dataset: str):
+    response = client.get(
+        f"{BASE_DATASETS_SERVER_URL}/compatible-libraries?dataset={dataset}"
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def get_first_rows(dataset: str, config: str, split: str):
+    resp = client.get(
+        f"{BASE_DATASETS_SERVER_URL}/first-rows?dataset={dataset}&config={config}&split={split}"
+    )
+    resp.raise_for_status()
+    content = resp.json()
+    rows = content["rows"]
+    return [row["row"] for row in rows]
+
+
+def get_splits(dataset: str, config: str):
+    resp = client.get(
+        f"{BASE_DATASETS_SERVER_URL}/splits?dataset={dataset}&config={config}"
+    )
+    resp.raise_for_status()
+    content = resp.json()
+    return content["splits"]
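A minimal usage sketch chaining the three helpers in the same order generate_cells calls them (the dataset id is an illustrative assumption):

if __name__ == "__main__":
    dataset = "stanfordnlp/imdb"  # assumed example; any public Hub dataset id works
    libs = get_compatible_libraries(dataset)
    config = libs["libraries"][0]["loading_codes"][0]["config_name"]
    split = get_splits(dataset, config)[0]["split"]
    rows = get_first_rows(dataset, config, split)
    print(f"{dataset} ({config}/{split}): fetched {len(rows)} preview rows")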