Global attention mask for model inference
Hey, thanks for the great upload! My question is related to model inference.
I was able to train using the Trainer API on Colab, but when I try to run inference in a notebook on non-dataset text, I get the following error even with LongT5ForConditionalGeneration:
TypeError: forward() got an unexpected keyword argument 'global_attention_mask'
(Perhaps unrelated, but that same notebook works fine with this checkpoint, for example.) I then turned to the example in your model card to see if I could replicate it and figure out what I'm doing wrong, but I can't find where you define global_attention_mask in the example:
import torch
from transformers import AutoTokenizer, LongT5ForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
input_ids = tokenizer(LONG_ARTICLE, return_tensors="pt").input_ids.to("cuda")
model = LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps", return_dict_in_generate=True).to("cuda")
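# NOTE: global_attention_mask is never defined anywhere above — this is the line in question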
sequences = model.generate(input_ids, global_attention_mask=global_attention_mask).sequences
summary = tokenizer.batch_decode(sequences)
any help would be appreciated :)
Hi, thank you very much for pointing out this issue. It's a mistake on my side. The LongT5 model accepts attention_mask, not global_attention_mask. Sorry for the confusion, I'll fix that :]
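For reference, here is a minimal corrected sketch of the model-card example. LONG_ARTICLE still stands in for your own input text, and skip_special_tokens is an assumed decoding preference rather than part of the original snippet:
from transformers import AutoTokenizer, LongT5ForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
model = LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps").to("cuda")
# tokenize once and keep both input_ids and attention_mask on the GPU
inputs = tokenizer(LONG_ARTICLE, return_tensors="pt").to("cuda")
# pass the standard attention_mask; LongT5's transient-global attention is
# handled internally, so generate() takes no global_attention_mask argument
sequences = model.generate(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
summary = tokenizer.batch_decode(sequences, skip_special_tokens=True)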
Thanks! Works as expected now 👍
I’ll post my checkpoints once I’m happy with the performance, but DAMN this thing takes forever to train (even compared to LED at 16384).