Sakalti commited on
Commit
5c2dd2f
·
verified ·
1 Parent(s): 2e17103

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -32
app.py CHANGED
@@ -13,24 +13,20 @@ from datetime import datetime
13
  # グローバル変数で検出された列を保存
14
  columns = []
15
 
16
- # ファイル読み込み関数
17
- def read_file(data_file):
18
  global columns
19
  try:
20
- # ファイルをロード
21
- file_extension = os.path.splitext(data_file.name)[1]
22
- if file_extension == '.csv':
23
- df = pd.read_csv(data_file.name)
24
- elif file_extension == '.json':
25
- df = pd.read_json(data_file.name)
26
- elif file_extension == '.xlsx':
27
- df = pd.read_excel(data_file.name)
28
- else:
29
- return "無効なファイル形式です。CSV, JSON, Excelファイルをアップロードしてください。"
30
-
31
- # 列を検出
32
  columns = df.columns.tolist()
33
- return columns
 
34
  except Exception as e:
35
  return f"エラーが発生しました: {str(e)}"
36
 
@@ -41,27 +37,24 @@ def validate_columns(prompt_col, description_col):
41
  return True
42
 
43
  # モデル訓練関数
44
- def train_model(data_file, model_name, epochs, batch_size, learning_rate, output_dir, prompt_col, description_col, hf_token):
45
  try:
46
  # 列の検証
47
  if not validate_columns(prompt_col, description_col):
48
  return "無効な列選択です。データセット内の列を確認してください。"
49
 
50
- # ファイルのロード
51
- file_extension = os.path.splitext(data_file.name)[1]
52
- if file_extension == '.csv':
53
- df = pd.read_csv(data_file.name)
54
- elif file_extension == '.json':
55
- df = pd.read_json(data_file.name)
56
- elif file_extension == '.xlsx':
57
- df = pd.read_excel(data_file.name)
58
 
59
  # データのプレビュー
60
  preview = df.head().to_string(index=False)
61
 
62
  # 訓練用テキストの準備
63
  df['text'] = df[prompt_col] + ': ' + df[description_col]
64
- dataset = Dataset.from_pandas(df[['text']])
65
 
66
  # GPT-2のトークナイザーとモデルを初期化
67
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
@@ -78,7 +71,7 @@ def train_model(data_file, model_name, epochs, batch_size, learning_rate, output
78
  tokens['labels'] = tokens['input_ids'].copy()
79
  return tokens
80
 
81
- tokenized_datasets = dataset.map(tokenize_function, batched=True)
82
 
83
  # 訓練のための設定
84
  training_args = TrainingArguments(
@@ -171,14 +164,14 @@ def generate_text(prompt, temperature, top_k, top_p, max_length, repetition_pena
171
  # UI設定
172
  with gr.Blocks() as ui:
173
  with gr.Row():
174
- data_file = gr.File(label="データファイル", file_types=[".csv", ".json", ".xlsx"])
175
  model_name = gr.Textbox(label="モデル名", value="gpt2")
176
  epochs = gr.Number(label="エポック数", value=3, minimum=1)
177
  batch_size = gr.Number(label="バッチサイズ", value=4, minimum=1)
178
  learning_rate = gr.Number(label="学習率", value=5e-5, minimum=1e-7, maximum=1e-2, step=1e-7)
179
  output_dir = gr.Textbox(label="出力ディレクトリ", value="./output")
180
- prompt_col = gr.Textbox(label="プロンプト列名", value="prompt")
181
- description_col = gr.Textbox(label="説明列名", value="description")
182
  hf_token = gr.Textbox(label="Hugging Face アクセストークン")
183
 
184
  with gr.Row():
@@ -186,8 +179,8 @@ with gr.Blocks() as ui:
186
  output = gr.Textbox(label="出力")
187
 
188
  validate_button.click(
189
- read_file,
190
- inputs=[data_file],
191
  outputs=[output]
192
  )
193
 
@@ -197,7 +190,7 @@ with gr.Blocks() as ui:
197
 
198
  train_button.click(
199
  train_model,
200
- inputs=[data_file, model_name, epochs, batch_size, learning_rate, output_dir, prompt_col, description_col, hf_token],
201
  outputs=[result_output]
202
  )
203
 
 
13
  # グローバル変数で検出された列を保存
14
  columns = []
15
 
16
+ # データセットをロードする関数
17
+ def load_data(dataset_name):
18
  global columns
19
  try:
20
+ # Hugging Faceのデータセットをロード
21
+ dataset = load_dataset(dataset_name)
22
+
23
+ # 最初のデータをプレビューとして表示
24
+ df = pd.DataFrame(dataset['train'])
25
+
26
+ # 列名を検出
 
 
 
 
 
27
  columns = df.columns.tolist()
28
+
29
+ return columns, df.head().to_string(index=False)
30
  except Exception as e:
31
  return f"エラーが発生しました: {str(e)}"
32
 
 
37
  return True
38
 
39
  # モデル訓練関数
40
+ def train_model(dataset_name, model_name, epochs, batch_size, learning_rate, output_dir, prompt_col, description_col, hf_token):
41
  try:
42
  # 列の検証
43
  if not validate_columns(prompt_col, description_col):
44
  return "無効な列選択です。データセット内の列を確認してください。"
45
 
46
+ # Hugging Faceのデータセットをロード
47
+ dataset = load_dataset(dataset_name)
48
+
49
+ # 訓練データを取得
50
+ df = pd.DataFrame(dataset['train'])
 
 
 
51
 
52
  # データのプレビュー
53
  preview = df.head().to_string(index=False)
54
 
55
  # 訓練用テキストの準備
56
  df['text'] = df[prompt_col] + ': ' + df[description_col]
57
+ train_dataset = Dataset.from_pandas(df[['text']])
58
 
59
  # GPT-2のトークナイザーとモデルを初期化
60
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 
71
  tokens['labels'] = tokens['input_ids'].copy()
72
  return tokens
73
 
74
+ tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
75
 
76
  # 訓練のための設定
77
  training_args = TrainingArguments(
 
164
  # UI設定
165
  with gr.Blocks() as ui:
166
  with gr.Row():
167
+ dataset_name = gr.Textbox(label="データセット名", value="imdb") # ここにデータセット名を入力
168
  model_name = gr.Textbox(label="モデル名", value="gpt2")
169
  epochs = gr.Number(label="エポック数", value=3, minimum=1)
170
  batch_size = gr.Number(label="バッチサイズ", value=4, minimum=1)
171
  learning_rate = gr.Number(label="学習率", value=5e-5, minimum=1e-7, maximum=1e-2, step=1e-7)
172
  output_dir = gr.Textbox(label="出力ディレクトリ", value="./output")
173
+ prompt_col = gr.Textbox(label="プロンプト列名", value="text") # 例:IMDBのレビュー列名
174
+ description_col = gr.Textbox(label="説明列名", value="label") # 例:IMDBのラベル列名
175
  hf_token = gr.Textbox(label="Hugging Face アクセストークン")
176
 
177
  with gr.Row():
 
179
  output = gr.Textbox(label="出力")
180
 
181
  validate_button.click(
182
+ load_data,
183
+ inputs=[dataset_name],
184
  outputs=[output]
185
  )
186
 
 
190
 
191
  train_button.click(
192
  train_model,
193
+ inputs=[dataset_name, model_name, epochs, batch_size, learning_rate, output_dir, prompt_col, description_col, hf_token],
194
  outputs=[result_output]
195
  )
196