Spaces:
Runtime error
Runtime error
import torch | |
from transformers import AutoTokenizer, VisionEncoderDecoderModel | |
class Inference:
    """Generate texts from image pixel values with a vision encoder-decoder model.

    Bundles a ``VisionEncoderDecoderModel`` checkpoint with the tokenizer of its
    text decoder. The model is placed on CUDA when available, otherwise CPU.
    """

    def __init__(
        self,
        decoder_model_name: str,
        max_length: int = 32,
        *,
        encoder_decoder_model_name: str = 'armgabrielyan/video-summarization',
    ):
        """Load the tokenizer and model and prepare them for generation.

        Args:
            decoder_model_name: Hub id or local path passed to
                ``AutoTokenizer.from_pretrained``; it should match the model's
                text decoder.
            max_length: Default maximum length passed to ``generate``.
            encoder_decoder_model_name: Hub id or local path of the
                ``VisionEncoderDecoderModel`` checkpoint to load.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
        self.encoder_decoder_model = VisionEncoderDecoderModel.from_pretrained(
            encoder_decoder_model_name
        )
        self.encoder_decoder_model.to(self.device)
        self.max_length = max_length
        # Done once at construction rather than lazily inside generate_texts,
        # so the first inference call does not mutate the model.
        self._ensure_pad_token()

    def _ensure_pad_token(self) -> None:
        """Add a ``[PAD]`` token if the tokenizer has none and resize the decoder to fit."""
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            # The tokenizer may now be larger than the decoder's embedding
            # table, so grow the embeddings to the new tokenizer length.
            self.encoder_decoder_model.decoder.resize_token_embeddings(len(self.tokenizer))

    def generate_texts(
        self,
        pixel_values: torch.Tensor,
        *,
        max_length: "int | None" = None,
        num_beams: int = 4,
        no_repeat_ngram_size: int = 2,
    ) -> "list[str]":
        """Run beam-search generation and decode the results to strings.

        Args:
            pixel_values: Tensor accepted by the model's ``generate`` (moved to
                ``self.device`` here); presumably the image processor's output
                batch — TODO confirm expected shape against the caller.
            max_length: Maximum generated length; ``None`` uses the value given
                at construction.
            num_beams: Beam width for beam search.
            no_repeat_ngram_size: Forbid repeating n-grams of this size.

        Returns:
            One decoded string per generated sequence, with special tokens removed.
        """
        if max_length is None:
            max_length = self.max_length
        generated_ids = self.encoder_decoder_model.generate(
            pixel_values.to(self.device),
            max_length=max_length,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)