Update README.md
README.md
CHANGED
@@ -76,10 +76,7 @@ quantization_config = BitsAndBytesConfig(
 )
 
 def load_model_and_tokenizer():
-    """
-    Download the model and tokenizer in parallel, then load the checkpoint
-    """
-    model_id = "Chrom256/gemma-2-9b-it-lora_20241216_033631"  # path to your model
+    model_id = "Chrom256/gemma-2-9b-it-lora_20241216_033631"
     base_model_id = "google/gemma-2-9b"
     downloaded_components = {"model": None, "tokenizer": None}
     download_lock = threading.Lock()
@@ -100,53 +97,47 @@ def load_model_and_tokenizer():
             torch_dtype=torch.bfloat16,
             attn_implementation="eager",
             low_cpu_mem_usage=True,
-            token=HF_TOKEN
+            token=HF_TOKEN
         )
         with download_lock:
             downloaded_components["model"] = model
 
     def download_tokenizer():
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
+            model_id,
             trust_remote_code=True,
-            token=HF_TOKEN
+            token=HF_TOKEN
         )
         with download_lock:
             downloaded_components["tokenizer"] = tokenizer
 
-    # Clear the GPU cache
     torch.cuda.empty_cache()
 
-    # ThreadPoolExecutor
+    # Download in parallel using ThreadPoolExecutor
     with ThreadPoolExecutor(max_workers=2) as executor:
         model_future = executor.submit(download_base_model)
         tokenizer_future = executor.submit(download_tokenizer)
 
-        # Wait until both downloads have finished
         model_future.result()
         tokenizer_future.result()
 
     model = downloaded_components["model"]
     tokenizer = downloaded_components["tokenizer"]
 
-    # Clear the GPU cache (before loading the checkpoint)
    torch.cuda.empty_cache()
 
-    # Load the checkpoint
     try:
         adapter_path = model_id
         print(f"Loading adapter from {adapter_path}")
-        model.load_adapter(adapter_path, "default", token=HF_TOKEN)
+        model.load_adapter(adapter_path, "default", token=HF_TOKEN)
         print("Adapter loaded successfully")
     except Exception as e:
         print(f"Error loading adapter: {e}")
         raise
 
-    # Final settings
     model.config.use_cache = True
     model.eval()
 
-    # Final GPU cache clear
     torch.cuda.empty_cache()
 
     return model, tokenizer
@@ -168,10 +159,9 @@ def run_inference(model, tokenizer, tokenized_inputs, generation_config, batch_s
             """ for item in batch
         ]
 
-        # Use dynamic padding
         inputs = tokenizer(
             prompts,
-            padding=True,
+            padding=True,
             truncation=True,
             return_tensors="pt"
         ).to(model.device)
@@ -192,7 +182,6 @@ def run_inference(model, tokenizer, tokenized_inputs, generation_config, batch_s
             elif 'model' in response:
                 response = response.split('model')[-1].strip()
 
-            # Apply post-processing
             response = post_process_output(response)
 
             results.append({
@@ -201,7 +190,6 @@ def run_inference(model, tokenizer, tokenized_inputs, generation_config, batch_s
                 "output": response
             })
 
-        # Free memory after each batch
         del outputs, inputs
         torch.cuda.empty_cache()
 
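These hunks show only part of the README code: the first hunk header references an earlier `quantization_config = BitsAndBytesConfig(...)` block, and the functions rely on `HF_TOKEN`, `threading`, `ThreadPoolExecutor`, and the `transformers` classes imported further up the file. The sketch below fills in that surrounding setup so the snippet can be read in context; the imports are standard, but the quantization values and the use of `AutoModelForCausalLM` for the base-model load are assumptions, since neither is visible in this diff.

```python
# Setup assumed by the README snippet above. The exact BitsAndBytesConfig
# values are NOT shown in this diff; the 4-bit NF4 settings below are
# illustrative assumptions only.
import os
import threading
from concurrent.futures import ThreadPoolExecutor

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# The snippet passes token=HF_TOKEN, so an access token is expected
# (Gemma repos on the Hub are gated).
HF_TOKEN = os.environ.get("HF_TOKEN")

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # assumed, not visible in the diff
    bnb_4bit_quant_type="nf4",              # assumed
    bnb_4bit_compute_dtype=torch.bfloat16,  # matches the torch_dtype used above
)
```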
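The hunks also do not show how the two helpers are called, and the `run_inference` parameter list is truncated in the hunk header (`..., generation_config, batch_s`). A hypothetical usage sketch follows; the final positional argument is assumed to be a batch size, the `GenerationConfig` values are illustrative, and the structure of the input records is not shown in the diff.

```python
# Hypothetical wiring of the two README helpers; see the caveats above.
from transformers import GenerationConfig

model, tokenizer = load_model_and_tokenizer()

generation_config = GenerationConfig(
    max_new_tokens=512,  # illustrative values, not taken from the README
    do_sample=False,
)

# Records consumed by run_inference's prompt template; their structure
# (the fields of each `item`) is not visible in this diff.
tokenized_inputs = []

# The fifth argument is truncated as `batch_s` in the hunk header and is
# assumed here to be a batch size.
results = run_inference(model, tokenizer, tokenized_inputs, generation_config, 4)
for r in results:
    print(r["output"])
```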