TFBartForConditionalGeneration does not work with XLA compiler
Hi,
I followed this Hugging Face blog post to accelerate text generation using TF with XLA. On wrapping `model.generate` with `xla_generate = tf.function(model.generate, jit_compile=True)`, I get the following error:
```
NotImplementedError Traceback (most recent call last)
in ()
26 tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)
27 start = time.time_ns()
---> 28 generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
29 end = time.time_ns()
30 decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
1145 except Exception as e: # pylint:disable=broad-except
1146 if hasattr(e, "ag_error_metadata"):
-> 1147 raise e.ag_error_metadata.to_exception(e)
1148 else:
1149 raise
NotImplementedError: in user code:
File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_utils.py", line 605, in generate *
seed=model_kwargs.pop("seed", None),
File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_utils.py", line 1687, in _generate *
input_ids,
File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_utils.py", line 2854, in beam_search_body_fn *
log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_logits_process.py", line 94, in __call__ *
scores = processor(input_ids, scores, cur_len)
File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_logits_process.py", line 427, in __call__ *
raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")
NotImplementedError: TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.
```
Example code for reproducing this error:
```
# Stand-alone TF XLA generate example for Decoder-Only Models.
# Note: execution times are deeply dependent on hardware.
# If you have a machine with a powerful GPU, I highly recommend you try this example there!
import time

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# 1. Load model and tokenizer
model_name = "facebook/bart-large-cnn"
# remember: decoder-only models need left-padding
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", pad_token="")
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

# 2. Prepare tokenization and generation arguments -- don't forget padding to avoid retracing!
tokenization_kwargs = {"pad_to_multiple_of": 32, "padding": True, "return_tensors": "tf"}
generation_kwargs = {"num_beams": 4, "max_new_tokens": 64}

# 3. Create your XLA generate function
# This is the only change with respect to the original generate workflow!
xla_generate = tf.function(model.generate, jit_compile=True)

# 4. Generate! Remember -- the first call will be slow, but all subsequent calls
# will be fast if you've done things right.
input_prompts = [f"The best thing about {country} is" for country in ["Spain", "Japan", "Angola"]]

for input_prompt in input_prompts:
    tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)
    start = time.time_ns()
    generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
    end = time.time_ns()
    decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
    print(f"Original prompt -- {input_prompt}")
    print(f"Generated -- {decoded_text}")
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")
```
Other models work fine, so the issue seems specific to BART. Is BART not supported with XLA generation?
Tagging @patrickvonplaten to get your opinion on this, as suggested in the docs.