Does it support fast XLA text generation?
#5 by bakrianoo · opened
Hi,
I am wondering whether this model can support the new fast XLA text generation.
I followed this blog post: https://huggingface.co/blog/tf-xla-generate
and used the following code:
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import tensorflow as tf
from tensorflow.python.ops.numpy_ops import np_config

np_config.enable_numpy_behavior()

ph_model_name = "tuner007/pegasus_paraphrase"
# torch_device = "cuda:0"
ph_tokenizer = AutoTokenizer.from_pretrained(ph_model_name)
ph_model = AutoModelForSeq2SeqLM.from_pretrained(ph_model_name)

tokenization_kwargs = {"max_length": 512, "padding": "longest", "truncation": True, "return_tensors": "tf"}
generation_kwargs = {"num_beams": 7, "max_length": 512,
                     "num_return_sequences": 2, "temperature": 0.7,
                     "do_sample": True, "top_k": 90, "top_p": 0.95,
                     "no_repeat_ngram_size": 2, "early_stopping": True}

# generate a paraphrased text
xla_generate = tf.function(ph_model.generate, jit_compile=True)

input_prompt = "the world has been inching toward fully autonomous cars for years ."
tokenized_inputs = ph_tokenizer([input_prompt], **tokenization_kwargs)
generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
decoded_text = ph_tokenizer.decode(generated_text[0], skip_special_tokens=True)
print(decoded_text)
```
but got this error message:
```
TypeError                                 Traceback (most recent call last)
<ipython-input-8-a25ee9d320ec> in <module>
      1 input_prompt = 'the world has been inching toward fully autonomous cars for years .'
      2 tokenized_inputs = ph_tokenizer([input_prompt], **tokenization_kwargs)
----> 3 generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
      4 decoded_text = ph_tokenizer.decode(generated_text[0], skip_special_tokens=True)
      5 print(decoded_text)

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
   1145       except Exception as e:  # pylint:disable=broad-except
   1146         if hasattr(e, "ag_error_metadata"):
-> 1147           raise e.ag_error_metadata.to_exception(e)
   1148         else:
   1149           raise

TypeError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py", line 847, in decorate_context  *
        return func(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/transformers/generation_utils.py", line 1182, in generate  *
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
    File "/usr/local/lib/python3.7/dist-packages/transformers/generation_utils.py", line 525, in _prepare_encoder_decoder_kwargs_for_generation  *
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
    File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl  *
        return forward_call(*input, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/transformers/models/pegasus/modeling_pegasus.py", line 753, in forward  *
        input_shape = input_ids.size()

    TypeError: 'numpy.int64' object is not callable
```
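One thing I notice: even though `generate` is wrapped in `tf.function`, the traceback goes through `torch/...` paths, so `AutoModelForSeq2SeqLM` seems to have loaded the PyTorch weights. PyTorch tensors have a `.size()` method, while TF tensors with `enable_numpy_behavior()` (like NumPy arrays) expose `.size` as a plain integer attribute, which would produce exactly this kind of `TypeError`. A minimal reproduction with a NumPy array as a stand-in for the tokenized `input_ids`:

```python
import numpy as np

# NumPy arrays (and tf.Tensors once enable_numpy_behavior() is on)
# expose .size as an integer attribute, not a method like torch's .size().
input_ids = np.zeros((1, 12), dtype=np.int64)
print(input_ids.size)  # 12

try:
    input_ids.size()  # mimics input_ids.size() in modeling_pegasus.py
except TypeError as err:
    print(err)  # 'int' object is not callable
```

If that diagnosis is right, should I be loading `TFAutoModelForSeq2SeqLM` instead, so the TF graph gets a TF model?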
Any help?