import argparse
import json
import os
import time

import nltk
import torch
import tqdm
from nltk.tokenize import sent_tokenize
from transformers import T5Tokenizer, T5ForConditionalGeneration

# sent_tokenize needs NLTK's punkt models; fetch them once at import time.
nltk.download("punkt")

def generate_dipper_paraphrases(
data,
model_name="kalpeshk2011/dipper-paraphraser-xxl",
no_ctx=True,
sent_interval=3,
start_idx=None,
end_idx=None,
paraphrase_file=".output/dipper_attacks.jsonl",
lex=20,
order=0,
args=None,
):
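    """Paraphrase model generations with the DIPPER paraphraser.

    Expects a Hugging Face ``datasets.Dataset`` whose rows provide
    ``truncated_input`` and ``w_wm_output`` (plus ``no_wm_output`` when
    attacking unwatermarked text). ``lex`` and ``order`` are lexical/order
    *diversity* settings in [0, 100]; they are converted to the model's
    similarity control codes below. Returns the dataset with the attacked
    texts and the DIPPER prompts added as new columns.
    """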
    # Encode the attack configuration into the output filename.
    if no_ctx:
        paraphrase_file = paraphrase_file.split(".jsonl")[0] + "_no_ctx.jsonl"
    if sent_interval == 1:
        paraphrase_file = paraphrase_file.split(".jsonl")[0] + "_sent.jsonl"
    output_file = paraphrase_file.split(".jsonl")[0] + f"_L_{lex}_O_{order}_pp.jsonl"
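    # With the defaults this yields, e.g.,
    # ".output/dipper_attacks_no_ctx_L_20_O_0_pp.jsonl".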
    # Resume support: count how many rows were already attacked and written.
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            lines = [line for line in f.read().strip().split("\n") if line]
        num_output_points = len([json.loads(x) for x in lines])
    else:
        num_output_points = 0
    print(f"Skipping {num_output_points} points")
    time1 = time.time()
    # DIPPER is a fine-tuned T5-XXL, so it reuses the stock T5 v1.1 tokenizer.
    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    print("Model loaded in", time.time() - time1)
    # model.half()
    model.cuda()
    model.eval()
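    # Note on the disabled model.half() above: at ~11B parameters, fp32 weights
    # alone take roughly 44 GB of GPU memory, so enabling fp16 halves that
    # footprint if quality permits.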
    # Restrict to the requested slice of the dataset, if one was given.
    data = (
        data.select(range(0, len(data)))
        if start_idx is None or end_idx is None
        else data.select(range(start_idx, end_idx))
    )

    # Iterate over the dataset and paraphrase each instance.
    w_wm_output_attacked = []
    dipper_inputs = []
    for idx, dd in tqdm.tqdm(enumerate(data), total=len(data)):
        # Skip rows already counted in the output file. Note that resuming this
        # way leaves the result lists shorter than the dataset, so the
        # add_column calls below only succeed on a fresh run.
        if idx < num_output_points:
            continue
        if "w_wm_output_attacked" not in dd:
            # Pick the text to attack: the unwatermarked or the watermarked output.
            if args is not None and args.no_wm_attack:
                if isinstance(dd["no_wm_output"], str):
                    input_gen = dd["no_wm_output"].strip()
                else:
                    input_gen = dd["no_wm_output"][0].strip()
            else:
                if isinstance(dd["w_wm_output"], str):
                    input_gen = dd["w_wm_output"].strip()
                else:
                    input_gen = dd["w_wm_output"][0].strip()
            # The lexical and order control codes consumed by the actual model
            # measure "similarity" rather than "diversity", so a diversity
            # setting of X maps to a control code of 100 - X.
            lex_code = int(100 - lex)
            order_code = int(100 - order)
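            # For example, the defaults lex=20, order=0 become control codes
            # lexical = 80, order = 100: a mildly lexical paraphrase with no
            # sentence reordering.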
            # Collapse spurious newlines/whitespace, then sentence-split the text.
            input_gen = " ".join(input_gen.split())
            sentences = sent_tokenize(input_gen)
            prefix = " ".join(dd["truncated_input"].replace("\n", " ").split())
output_text = ""
final_input_text = ""
for sent_idx in range(0, len(sentences), sent_interval):
curr_sent_window = " ".join(sentences[sent_idx : sent_idx + sent_interval])
if no_ctx:
final_input_text = f"lexical = {lex_code}, order = {order_code} {curr_sent_window} "
else:
final_input_text = f"lexical = {lex_code}, order = {order_code} {prefix} {curr_sent_window} "
if idx == 0 and lex_code == 60 and order_code == 60:
print(final_input_text)
final_input = tokenizer([final_input_text], return_tensors="pt")
final_input = {k: v.cuda() for k, v in final_input.items()}
with torch.inference_mode():
outputs = model.generate(
**final_input, do_sample=True, top_p=0.75, top_k=None, max_length=512
)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
prefix += " " + outputs[0]
output_text += " " + outputs[0]
# paraphrase_outputs[f"lex_{lex_code}_order_{order_code}"] = {
# "final_input": final_input_text,
# "output": [output_text.strip()],
# "lex": lex_code,
# "order": order_code
# }
# dd["w_wm_output_attacked"] = paraphrase_outputs
w_wm_output_attacked.append(output_text.strip())
dipper_inputs.append(final_input_text)
# with open(output_file, "a") as f:
# f.write(json.dumps(dd) + "\n")
    # Attach the attacked outputs to the HF dataset as new columns.
    data = data.add_column("w_wm_output_attacked", w_wm_output_attacked)
    data = data.add_column(f"dipper_inputs_Lex{lex}_Order{order}", dipper_inputs)
    return data
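

if __name__ == "__main__":
    # Minimal driver sketch (an assumption, not part of the original script):
    # every flag name except --no_wm_attack, which the function reads from
    # `args`, is illustrative. Expects a JSONL file whose rows contain the
    # columns used above (truncated_input, w_wm_output, no_wm_output).
    from datasets import load_dataset

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_file", type=str, required=True)  # assumed flag
    parser.add_argument("--no_wm_attack", action="store_true")
    parser.add_argument("--lex", type=int, default=20)  # assumed flag
    parser.add_argument("--order", type=int, default=0)  # assumed flag
    cli_args = parser.parse_args()

    dataset = load_dataset("json", data_files=cli_args.data_file, split="train")
    attacked = generate_dipper_paraphrases(
        dataset, lex=cli_args.lex, order=cli_args.order, args=cli_args
    )
    print(attacked)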