reach-vb (HF staff) committed on
Commit 3e8df62
1 Parent(s): 04543f3

Upload prompt_creation_expresso.py (#3)


- Upload prompt_creation_expresso.py (6276b6b38fde54a3635a255238fcf010141e58ba)

Files changed (1)
  1. prompt_creation_expresso.py +521 -0
prompt_creation_expresso.py ADDED
@@ -0,0 +1,521 @@
import json
import logging
import os
import re
import shutil
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from accelerate import Accelerator, skip_first_batches
from accelerate.logging import get_logger
from datasets import DatasetDict, load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
)


logger = get_logger(__name__, log_level="INFO")


@dataclass
class ModelArguments:
    """
    Arguments pertaining to the model we are going to use for the prompt annotation.
    """

    model_name_or_path: str = field(
        metadata={"help": "The name of the model to use (via the transformers library) for the prompt annotation."},
    )
    per_device_eval_batch_size: int = field(
        metadata={"help": "The per-device batch size to use for inference."},
    )
    model_variant: str = field(
        default=None,
        metadata={"help": "If specified, load weights from the `variant` filename, *e.g.* pytorch_model.<variant>.bin."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    torch_dtype: Optional[str] = field(
        default="float16",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized"
                " and the computations run. Choose one of `[float32, float16, bfloat16]`."
            )
        },
    )
    attn_implementation: Optional[str] = field(
        default="sdpa",
        metadata={"help": "Which attention implementation to use: ['eager', 'sdpa', 'flash_attention_2']"},
    )
    load_in_8bit: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use 8-bit precision for inference."}
    )
    load_in_4bit: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use 4-bit precision for inference."}
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "Quantization type to use (fp4 or nf4)."}
    )
    use_bnb_nested_quant: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use nested (double) quantization."}
    )
    trust_remote_code: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=True, metadata={"help": "Use fast tokenizer for encoding/decoding input ids"}
    )
    token: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether or not to use an authentication token when loading/uploading from the Hugging Face Hub"
        },
    )
    do_sample: Optional[bool] = field(default=True, metadata={"help": "Whether to use sampling mode for generation"})
    temperature: Optional[float] = field(default=0.6, metadata={"help": "Temperature for sampling-based generation"})
    max_new_tokens: Optional[int] = field(
        default=256, metadata={"help": "Maximum number of new tokens during generation"}
    )
    torch_compile: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to compile the forward pass (not sampling) in generate. Only compatible with Gemma and Llama."
        },
    )


@dataclass
class DataArguments:
    """
    Arguments pertaining to the data we are going to use for the prompt annotation.
    """

    output_dir: str = field(
        metadata={
            "help": "Where to save the processed dataset to disk. If unspecified, uses a 'pretty' version of the "
            "original dataset name. E.g. 'facebook/voxpopuli' will be saved under 'voxpopuli'."
        },
    )
    dataset_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)"},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    dataset_split_name: Optional[str] = field(
        default=None,
        metadata={"help": "The split name of the dataset to use (via the datasets library)."},
    )
    dataset_cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to cache directory for saving and loading datasets"},
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Maximum number of samples for generation - use for debugging purposes."},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    dataloader_num_workers: Optional[int] = field(
        default=0,
        metadata={"help": "The number of processes to use for the dataloader."},
    )
    push_to_hub: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether or not to push the processed dataset to the Hub."},
    )
    hub_dataset_id: Optional[str] = field(
        default=None,
        metadata={"help": "Repository namespace if pushing to the Hugging Face Hub."},
    )
    overwrite_output_dir: Optional[bool] = field(
        default=False,
        metadata={"help": "Overwrite the content of the output directory each time the script is run."},
    )
    save_steps: Optional[int] = field(
        default=500,
        metadata={"help": "Save the generated prompts every save_steps."},
    )
    save_total_limit: Optional[int] = field(
        default=1, metadata={"help": ("If a value is passed, will limit the total number of saved checkpoints")}
    )

    def __post_init__(self):
        if self.push_to_hub and self.hub_dataset_id is None:
            raise ValueError("You must specify the `hub_dataset_id` when setting `--push_to_hub=True`")


def get_quantization_config(model_args: ModelArguments) -> Union[BitsAndBytesConfig, None]:
    if model_args.load_in_4bit:
        compute_dtype = torch.float16
        if model_args.torch_dtype not in {"auto", None}:
            compute_dtype = getattr(torch, model_args.torch_dtype)

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )
    elif model_args.load_in_8bit:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
    else:
        quantization_config = None

    return quantization_config


def get_current_device() -> Union[int, str]:
    """Get the current device. For GPU, we return the local process index to enable running on multiple GPUs."""
    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"


def get_kbit_device_map() -> Union[Dict[str, int], None]:
    """Useful for running inference with quantized models by setting `device_map=get_kbit_device_map()`"""
    return {"": get_current_device()} if torch.cuda.is_available() else None


CHECKPOINT_PREFIX = "checkpoint"
_RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+).json$")


def save_checkpoint(output_dir, all_generated_ids, step):
    checkpoint_path = f"{CHECKPOINT_PREFIX}-{step}.json"
    output_path = os.path.join(output_dir, checkpoint_path)
    all_generated_ids = [ids.tolist() for ids in all_generated_ids]
    with open(output_path, "w") as file:
        json.dump(all_generated_ids, file)


def load_checkpoint(checkpoint_path):
    with open(checkpoint_path, "r") as file:
        all_generated_ids = json.load(file)
    all_generated_ids = [np.array(lst) for lst in all_generated_ids]
    return all_generated_ids


def sorted_checkpoints(output_dir=None) -> List[str]:
    """Helper function to sort saved checkpoints from oldest to newest."""
    ordering_and_checkpoint_path = []

    glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{CHECKPOINT_PREFIX}-*")]

    for path in glob_checkpoints:
        regex_match = re.match(f".*{CHECKPOINT_PREFIX}-([0-9]+)", path)
        if regex_match is not None and regex_match.groups() is not None:
            ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    return checkpoints_sorted


def rotate_checkpoints(save_total_limit=None, output_dir=None) -> None:
    """Helper function to delete old checkpoints."""
    if save_total_limit is None or save_total_limit <= 0:
        return
    # Check if we should delete older checkpoint(s)
    checkpoints_sorted = sorted_checkpoints(output_dir=output_dir)
    if len(checkpoints_sorted) <= save_total_limit:
        return

    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
    for checkpoint in checkpoints_to_be_deleted:
        logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
        os.remove(checkpoint)


def get_last_checkpoint(folder) -> Tuple[List, int]:
    if not os.path.exists(folder) or not os.path.isdir(folder):
        os.makedirs(folder, exist_ok=True)
        return [], 0
    content = os.listdir(folder)
    checkpoints = [path for path in content if _RE_CHECKPOINT.search(path) is not None]
    if len(checkpoints) == 0:
        return [], 0
    last_checkpoint = os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
    # Find num steps saved state string pattern
    pattern = r"checkpoint-(\d+).json"
    match = re.search(pattern, last_checkpoint)
    cur_step = int(match.group(1))
    # load corresponding generated ids
    all_generated_ids = load_checkpoint(last_checkpoint)
    return all_generated_ids, cur_step


@dataclass
class DataCollatorWithPadding:
    """
    Data collator that will dynamically pad the inputs received to the longest sequence in the batch.
    """

    tokenizer: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_ids = {"input_ids": [feature["input_ids"] for feature in features]}
        batch = self.tokenizer.pad(input_ids, return_tensors="pt", padding="longest", return_attention_mask=True)
        return batch


# Map Expresso speaker ids to the names used in the prompt.
id_to_name = {
    "ex01": "Jerry",
    "ex02": "Elisabeth",
    "ex03": "Thomas",
    "ex04": "Talia",
}

# Prompt template: [speaker_id] and [style] are replaced per sample in prepare_dataset.
PROMPT = """You will be given a name and an enunciation style related to an audio sample of a person's speech.
1. The name will be one of those: Jerry, Elisabeth, Thomas, Talia.
2. The enunciation style will be one of those: 'enunciated', 'happy', 'confused', 'default', 'laughing', 'sad', 'whisper', 'emphasis'.
The enunciation style 'default' can be associated with 'with no particular emotion conveyed'.

Your task is to create a text description using this information that accurately describes the speech sample. Ensure that the generated description is grammatically correct, easy to understand, and most importantly, concise.

For example, given the following keywords: 'Talia', 'happy', a valid description would be: 'In an excellent recording, Talia speaks happily.'
Another valid description would be: 'Talia delivers her words happily.'
Another example, given the following keywords: 'Jerry', 'emphasis': 'Jerry speaks with emphasis on certain words.'

You are free to change the order of the information, and replace synonymous terms.
You must give one and only one description and nothing else. Remember, I only want one description and nothing else.

For the information: '[speaker_id]', '[style]', the corresponding description is:"""


def main():
    # 1. Parse input arguments
    parser = HfArgumentParser((ModelArguments, DataArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args = parser.parse_args_into_dataclasses()

    # 2. Setup logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    accelerator = Accelerator()

    if data_args.overwrite_output_dir and os.path.exists(data_args.output_dir) and os.path.isdir(data_args.output_dir):
        logger.info("Cleaning output dir from previous run...")
        shutil.rmtree(data_args.output_dir)

    # 3. Load annotated dataset
    logger.info("*** Load annotated dataset ***")
    if data_args.dataset_split_name is not None:
        raw_datasets = DatasetDict()
        data_splits = data_args.dataset_split_name.split("+")
        # load on a split-wise basis
        for split in data_splits:
            with accelerator.local_main_process_first():
                raw_datasets[split] = load_dataset(
                    data_args.dataset_name,
                    data_args.dataset_config_name,
                    split=split,
                    cache_dir=model_args.cache_dir,
                    token=model_args.token,
                    num_proc=data_args.preprocessing_num_workers,
                )
    else:
        with accelerator.local_main_process_first():
            # load all splits for annotation
            raw_datasets = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                num_proc=data_args.preprocessing_num_workers,
            )

    raw_datasets_features = set(raw_datasets[next(iter(raw_datasets))].features.keys())

    if data_args.max_eval_samples is not None:
        for split in raw_datasets:
            raw_datasets[split] = raw_datasets[split].select(range(data_args.max_eval_samples))

    # TODO(SG): add accent
    EXPECTED_COLUMNS = {"speaker_id", "style"}
    if not EXPECTED_COLUMNS.issubset(raw_datasets_features):
        missing_columns = EXPECTED_COLUMNS - raw_datasets_features
        raise ValueError(
            f"Missing columns {missing_columns} from the dataset features. Got dataset features {raw_datasets_features}"
        )

    # 4. Load pre-trained model
    logger.info("*** Load pretrained model ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        variant=model_args.model_variant,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        low_cpu_mem_usage=True,
        token=model_args.token,
    ).eval()

    if model_args.torch_compile:
        # torch compile only compatible with gemma and llama
        if not callable(getattr(model, "_setup_cache", None)):
            raise ValueError(
                f"Static k/v cache is not compatible with the model {model.__class__.__name__}. "
                "Set `--torch_compile=False` for dynamic k/v cache."
            )
        model.generation_config.cache_implementation = "static"
        # compile the forward pass (but not the top-{p,k} sampling)
        model = torch.compile(model, mode="reduce-overhead", fullgraph=True)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.bos_token_id
        model.generation_config.pad_token_id = model.generation_config.eos_token_id

    def prepare_dataset(sample):
        sample_prompt = PROMPT
        sample["speaker_id"] = id_to_name[sample["speaker_id"]]
        for key in EXPECTED_COLUMNS:
            sample_prompt = sample_prompt.replace(f"[{key}]", sample[key])
        sample_prompt = [{"role": "user", "content": sample_prompt}]
        token_ids = tokenizer.apply_chat_template(sample_prompt)
        sample["input_ids"] = token_ids
        return sample

    with accelerator.local_main_process_first():
        vectorized_datasets = raw_datasets.map(
            prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preparing prompts"
        )

    # Prepare everything with our `accelerator`
    model = accelerator.prepare(model)
    data_collator = DataCollatorWithPadding(tokenizer)

    def generate_step(batch):
        output_ids = accelerator.unwrap_model(model).generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            do_sample=model_args.do_sample,
            temperature=model_args.temperature,
            max_new_tokens=model_args.max_new_tokens,
        )
        output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
        return output_ids

    def postprocess_dataset(sample):
        prompt_text = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
        generated_text = tokenizer.decode(sample["generated_ids"], skip_special_tokens=True)
        sample["text_description"] = generated_text[len(prompt_text) :]
        return sample

    # 5. Generate a description for every sample, split by split, checkpointing every save_steps
    for split in vectorized_datasets:
        data_loader = DataLoader(
            vectorized_datasets[split],
            batch_size=model_args.per_device_eval_batch_size,
            collate_fn=data_collator,
            num_workers=data_args.dataloader_num_workers,
            pin_memory=True,
        )
        data_loader = accelerator.prepare(data_loader)
        total_inference_steps = len(data_loader)
        progress_bar = tqdm(
            range(total_inference_steps), desc=" ... ", position=0, disable=not accelerator.is_local_main_process
        )

        split_output_dir = os.path.join(data_args.output_dir, split)
        all_generated_ids, cur_step = get_last_checkpoint(split_output_dir)

        if cur_step > 0:
            logger.info(f"Resuming {split} from step {cur_step}")
            # efficiently skip the first n batches
            data_loader = skip_first_batches(data_loader, cur_step)
            progress_bar.update(cur_step)

        while cur_step < total_inference_steps:
            for batch in data_loader:
                generated_ids = generate_step(batch)
                generated_ids = accelerator.gather_for_metrics(generated_ids)
                all_generated_ids.extend(generated_ids.cpu().numpy())

                cur_step += 1
                progress_bar.update(1)

                if (cur_step % data_args.save_steps == 0) or (cur_step == total_inference_steps):
                    save_checkpoint(split_output_dir, all_generated_ids, cur_step)
                    rotate_checkpoints(data_args.save_total_limit, output_dir=split_output_dir)

        vectorized_datasets[split] = vectorized_datasets[split].add_column("generated_ids", all_generated_ids)

        if accelerator.is_main_process:
            vectorized_datasets[split] = vectorized_datasets[split].map(
                postprocess_dataset,
                num_proc=data_args.preprocessing_num_workers,
                desc="Postprocessing dataset",
                remove_columns=["input_ids", "generated_ids"],
            )
        accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        vectorized_datasets.save_to_disk(data_args.output_dir)
        if data_args.push_to_hub:
            vectorized_datasets.push_to_hub(
                data_args.hub_dataset_id,
                config_name=data_args.dataset_config_name if data_args.dataset_config_name is not None else "default",
                token=model_args.token,
            )
    accelerator.wait_for_everyone()
    accelerator.end_training()


if __name__ == "__main__":
    main()
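
For reference, the prompt-construction step of the script (fill the PROMPT placeholders, then tokenize with the model's chat template) can be exercised in isolation. The snippet below is a minimal sketch, not part of the uploaded file: the sample row is made up, and the checkpoint name is an arbitrary chat-tuned model chosen only for illustration; in the script itself the model comes from `--model_name_or_path`.

# Minimal, standalone sketch of the prompt-construction step (illustration only).
from transformers import AutoTokenizer

id_to_name = {"ex01": "Jerry", "ex02": "Elisabeth", "ex03": "Thomas", "ex04": "Talia"}
sample = {"speaker_id": "ex04", "style": "happy"}  # hypothetical Expresso-style row

# Fill the same placeholders the script replaces in PROMPT (shortened here to the last line).
prompt = "For the information: '[speaker_id]', '[style]', the corresponding description is:"
prompt = prompt.replace("[speaker_id]", id_to_name[sample["speaker_id"]]).replace("[style]", sample["style"])

# Assumed checkpoint for illustration; any chat-tuned model with a chat template works.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
input_ids = tokenizer.apply_chat_template([{"role": "user", "content": prompt}])
print(tokenizer.decode(input_ids))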