import jax
import jax.numpy as jnp
import pandas as pd
from datasets import Dataset, load_dataset
from flax.training.common_utils import shard
from tqdm import tqdm
from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM
print(jax.devices())
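# Load the tokenizer and Flax seq2seq model from a local checkpoint directory.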
MODEL_NAME_OR_PATH = "../"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
prefix = "items: "
text_column = "inputs"
target_column = "targets"
max_source_length = 256
max_target_length = 1024
seed = 42
eval_batch_size = 64
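# eval_batch_size is the global batch size; `shard` splits each batch across all
# local devices, so it must be divisible by jax.local_device_count().

# Alternative sampling-based decoding config, kept for reference: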
# generation_kwargs = {
#     "max_length": 1024,
#     "min_length": 128,
#     "no_repeat_ngram_size": 3,
#     "do_sample": True,
#     "top_k": 60,
#     "top_p": 0.95,
# }
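# Active decoding config: beam search with trigram blocking and a length penalty.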
generation_kwargs = {
    "max_length": 1024,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "num_beams": 4,
    "length_penalty": 1.5,
}
special_tokens = tokenizer.all_special_tokens
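# Map the model's structural tokens back to human-readable separators.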
tokens_map = {
    "<sep>": "--",
    "<section>": "\n",
}
def skip_special_tokens(text, special_tokens):
    """Remove every special token from `text`."""
    for token in special_tokens:
        text = text.replace(token, "")
    return text
def target_postprocessing(texts, special_tokens):
    """Strip special tokens and map structural tokens to readable separators."""
    if not isinstance(texts, list):
        texts = [texts]
    new_texts = []
    for text in texts:
        text = skip_special_tokens(text, special_tokens)
        for k, v in tokens_map.items():
            text = text.replace(k, v)
        new_texts.append(text)
    return new_texts
predict_dataset = load_dataset(
    "csv",
    data_files={"test": "/home/m3hrdadfi/code/data/test.csv"},
    delimiter="\t",
)["test"]
print(predict_dataset)
# predict_dataset = predict_dataset.select(range(10))
# print(predict_dataset)
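# Record the raw column names so they can be dropped after tokenization.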
column_names = predict_dataset.column_names
print(column_names)
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(examples):
    inputs = examples[text_column]
    targets = examples[target_column]
    inputs = [prefix + inp for inp in inputs]
    model_inputs = tokenizer(
        inputs,
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    # Set up the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            targets,
            max_length=max_target_length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
predict_dataset = predict_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=None,
    remove_columns=column_names,
    desc="Running tokenizer on prediction dataset",
)
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Returns batches of size `batch_size` from `dataset`, sharded over all local devices.
    Shuffles batches if `shuffle` is `True`; the incomplete final batch is dropped.
    """
    steps_per_epoch = len(dataset) // batch_size
    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
    else:
        batch_idx = jnp.arange(len(dataset))
    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)
        yield batch
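# Seed JAX's PRNG and derive a dedicated stream for (optional) shuffling.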
rng = jax.random.PRNGKey(seed)
rng, input_rng = jax.random.split(rng)
def generate_step(batch):
    # model.params is closed over here; pmap broadcasts it to each device as a constant.
    output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **generation_kwargs)
    return output_ids.sequences
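# Compile generate_step once and run it on all local devices in parallel.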
p_generate_step = jax.pmap(generate_step, "batch")
pred_generations = []
pred_labels = []
pred_inputs = []
pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
pred_steps = len(predict_dataset) // eval_batch_size
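# Iterate over full batches only; data_loader drops the incomplete final batch.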
for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
    # Generate predictions for one sharded batch.
    batch = next(pred_loader)
    inputs = batch["input_ids"]
    labels = batch["labels"]
    generated_ids = p_generate_step(batch)
    # Un-shard: collapse the (devices, per_device_batch, seq_len) arrays back to 2D.
    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, generation_kwargs["max_length"])))
    pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
    pred_inputs.extend(jax.device_get(inputs.reshape(-1, inputs.shape[-1])))
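# Decode token ids back to text and restore readable separators.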
inputs = tokenizer.batch_decode(pred_inputs, skip_special_tokens=True)
true_recipe = target_postprocessing(
    tokenizer.batch_decode(pred_labels, skip_special_tokens=False),
    special_tokens,
)
generated_recipe = target_postprocessing(
    tokenizer.batch_decode(pred_generations, skip_special_tokens=False),
    special_tokens,
)
test_output = {
    "inputs": inputs,
    "true_recipe": true_recipe,
    "generated_recipe": generated_recipe,
}
test_output = pd.DataFrame.from_dict(test_output)
test_output.to_csv("./generated_recipes_b.csv", sep="\t", index=False, encoding="utf-8")
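# Note: the output file is tab-separated despite the .csv extension.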