|
import logging |
|
import os |
|
import pandas as pd |
|
import random |
|
import re |
|
import sys |
|
import time |
|
from dataclasses import dataclass, field |
|
from functools import partial |
|
from pathlib import Path |
|
from typing import Callable, Optional |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
|
|
from filelock import FileLock |
|
from flax import jax_utils, traverse_util |
|
from flax.jax_utils import unreplicate |
|
from flax.training import train_state |
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key |
|
|
|
from transformers import FlaxAutoModelForSeq2SeqLM |
|
from transformers import AutoTokenizer |
|
|
|
from datasets import Dataset, load_dataset, load_metric |
|
from tqdm import tqdm |
|
import pandas as pd |
|
|
|
|
|
# Log the devices JAX can see.
print(jax.devices())

# Checkpoint location handed to from_pretrained (a relative local path here).
MODEL_NAME_OR_PATH = "../"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Data / tokenization settings used by preprocess_function and the predict loop.
prefix = "items: "  # prepended to every source text before tokenizing
text_column, target_column = "inputs", "targets"  # CSV source / reference columns
max_source_length, max_target_length = 256, 1024  # pad/truncate lengths (tokens)
seed = 42
eval_batch_size = 64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Decoding settings forwarded to model.generate for every batch: beam search
# with 4 beams, a 3-gram repetition block and a length penalty of 1.5.
# max_length reuses max_target_length (same value, 1024) so generated and
# reference sequences share a single length budget instead of two literals
# that can silently drift apart.
generation_kwargs = {
    "max_length": max_target_length,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "num_beams": 4,
    "length_penalty": 1.5,
}

# All special-token strings of the tokenizer; deleted from decoded text by
# skip_special_tokens before the marker replacement below.
special_tokens = tokenizer.all_special_tokens

# Marker strings in decoded text and the plain-text separators that
# target_postprocessing substitutes for them.
tokens_map = {
    "<sep>": "--",
    "<section>": "\n",
}
|
def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each special token deleted.

    Tokens are removed one at a time in the order given (the same as chaining
    ``str.replace(token, "")`` calls), so a token that only appears after an
    earlier removal can still be removed by a later one.
    """
    result = text
    for tok in special_tokens:
        result = result.replace(tok, "")
    return result
|
|
|
def target_postprocessing(texts, special_tokens):
    """Turn raw decoded model text into readable recipe text.

    For each string: delete every special token (via ``skip_special_tokens``),
    then replace each marker key of the module-level ``tokens_map`` with its
    plain-text value.

    Args:
        texts: a single string, or an iterable (list, tuple, array, ...) of
            strings.
        special_tokens: token strings deleted before the marker replacement.

    Returns:
        A new list with one processed string per input string; a lone string
        input yields a one-element list.
    """
    # Wrap only a bare string. Previously any non-list (e.g. a tuple of
    # strings) was wrapped whole and then failed on ``.replace``; now every
    # iterable of strings is processed element-wise.
    if isinstance(texts, str):
        texts = [texts]

    processed = []
    for text in texts:
        text = skip_special_tokens(text, special_tokens)
        for marker, replacement in tokens_map.items():
            text = text.replace(marker, replacement)
        processed.append(text)
    return processed
|
|
|
|
|
# Tab-separated test split. The default is the original author's absolute
# path; set the PREDICT_FILE environment variable to point the script at the
# data on another machine without editing the source.
predict_file = os.environ.get("PREDICT_FILE", "/home/m3hrdadfi/code/data/test.csv")

predict_dataset = load_dataset("csv", data_files={"test": predict_file}, delimiter="\t")["test"]
print(predict_dataset)

# Raw column names; removed after tokenization in the map call below.
column_names = predict_dataset.column_names
print(column_names)
|
|
|
|
|
|
|
def preprocess_function(examples):
    """Tokenize a batch of raw rows into fixed-length inputs and labels.

    ``examples`` is the batched column mapping passed in by
    ``Dataset.map(batched=True)``. Each source text in ``text_column`` gets
    ``prefix`` prepended and is padded/truncated to ``max_source_length``;
    the texts in ``target_column`` are padded/truncated to
    ``max_target_length`` and their token ids are stored under ``"labels"``.
    """
    sources = examples[text_column]
    references = examples[target_column]

    encoded = tokenizer(
        [prefix + source for source in sources],
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # Targets are tokenized in the tokenizer's target mode.
    with tokenizer.as_target_tokenizer():
        encoded_refs = tokenizer(
            references,
            max_length=max_target_length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )

    encoded["labels"] = encoded_refs["input_ids"]
    return encoded
|
|
|
# Tokenize the whole split in batches. The raw text columns are removed, so
# only the fields produced by preprocess_function remain.
predict_dataset = predict_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=column_names,
    num_proc=None,
    desc="Running tokenizer on prediction dataset",
)
|
|
|
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """Yield batches of ``batch_size`` examples, split across local devices.

    Only ``len(dataset) // batch_size`` full batches are produced; trailing
    examples that do not fill a batch are dropped. With ``shuffle`` the
    example order is a permutation drawn from ``rng``; otherwise examples are
    taken in dataset order and ``rng`` is unused. Each yielded batch maps
    column name to an array that has been passed through flax's ``shard``.
    """
    num_batches = len(dataset) // batch_size
    order = jax.random.permutation(rng, len(dataset)) if shuffle else jnp.arange(len(dataset))
    order = order[: num_batches * batch_size].reshape((num_batches, batch_size))

    for batch_indices in order:
        rows = dataset[batch_indices]
        yield shard({name: jnp.array(values) for name, values in rows.items()})
|
|
|
# Root PRNG key for the run, split twice in sequence. dropout_rng is not used
# in the code visible here; input_rng is handed to data_loader, which only
# consumes it when shuffle=True.
rng, dropout_rng = jax.random.split(jax.random.PRNGKey(seed))
rng, input_rng = jax.random.split(rng)
|
|
|
def generate_step(batch):
    """Run ``model.generate`` on one device's slice of ``batch``.

    Returns only the ``sequences`` field of the generation output.
    """
    outputs = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        **generation_kwargs,
    )
    return outputs.sequences
|
|
|
# Data-parallel version of generate_step: mapped over the leading (device)
# axis of each sharded batch, with the collective axis named "batch".
p_generate_step = jax.pmap(generate_step, axis_name="batch")
|
|
|
pred_generations = []
pred_labels = []
pred_inputs = []


def _collect_predictions(batch, num_valid=None):
    """Generate for one sharded batch and append host-side rows to the pred_* lists.

    The generated ids, labels and input ids are flattened from the sharded
    layout back to one row per example. Each array uses its own trailing
    length rather than assuming every generated sequence is exactly
    generation_kwargs["max_length"] long. When ``num_valid`` is given, only the
    first ``num_valid`` rows are kept, discarding padding rows.
    """
    arrays = (
        jax.device_get(p_generate_step(batch)),
        jax.device_get(batch["labels"]),
        jax.device_get(batch["input_ids"]),
    )
    generated, labels, input_ids = (a.reshape(-1, a.shape[-1]) for a in arrays)
    if num_valid is not None:
        generated = generated[:num_valid]
        labels = labels[:num_valid]
        input_ids = input_ids[:num_valid]
    pred_generations.extend(generated)
    pred_labels.extend(labels)
    pred_inputs.extend(input_ids)


num_examples = len(predict_dataset)
pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
pred_steps = num_examples // eval_batch_size

for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
    _collect_predictions(next(pred_loader))

# data_loader yields only full batches, so without this the last
# `num_examples % eval_batch_size` rows would never be predicted (and a split
# smaller than one batch would produce no output at all). Run them as one
# extra batch, padded up to eval_batch_size by repeating the final row so it
# has the same shape as the full batches, then drop the padding rows.
num_leftover = num_examples % eval_batch_size
if num_leftover:
    leftover_idx = list(range(num_examples - num_leftover, num_examples))
    leftover_idx += [num_examples - 1] * (eval_batch_size - num_leftover)
    leftover_rows = predict_dataset[leftover_idx]
    leftover_batch = shard({k: jnp.array(v) for k, v in leftover_rows.items()})
    _collect_predictions(leftover_batch, num_valid=num_leftover)

# Inputs are decoded without special tokens. Labels and generations keep them
# so target_postprocessing can map the <sep>/<section> markers before
# stripping the rest.
inputs = tokenizer.batch_decode(pred_inputs, skip_special_tokens=True)
true_recipe = target_postprocessing(
    tokenizer.batch_decode(pred_labels, skip_special_tokens=False),
    special_tokens,
)
generated_recipe = target_postprocessing(
    tokenizer.batch_decode(pred_generations, skip_special_tokens=False),
    special_tokens,
)

test_output = pd.DataFrame.from_dict({
    "inputs": inputs,
    "true_recipe": true_recipe,
    "generated_recipe": generated_recipe,
})
test_output.to_csv("./generated_recipes_b.csv", sep="\t", index=False, encoding="utf-8")
|
|