Spaces:
Runtime error
Runtime error
File size: 1,032 Bytes
5e95a58 16b0970 5e95a58 68da745 5e95a58 16b0970 5e95a58 cde7ed6 5e95a58 4820fa1 cde7ed6 4820fa1 cde7ed6 5e95a58 cde7ed6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel
class Inference:
    """Generate text from pixel values with a vision encoder-decoder model.

    Loads a tokenizer for the decoder and a ``VisionEncoderDecoderModel``
    checkpoint, places the model on CUDA when available (CPU otherwise), and
    exposes ``generate_texts`` for beam-search decoding.
    """

    def __init__(
        self,
        decoder_model_name,
        max_length=32,
        encoder_decoder_model_name="armgabrielyan/video-summarization",
    ):
        """Load the tokenizer and model and move the model to the chosen device.

        Args:
            decoder_model_name: Name/path passed to
                ``AutoTokenizer.from_pretrained``; presumably the decoder the
                checkpoint was built on — confirm against the checkpoint.
            max_length: Forwarded as ``max_length`` to ``generate``.
            encoder_decoder_model_name: Checkpoint passed to
                ``VisionEncoderDecoderModel.from_pretrained``. Defaults to the
                checkpoint that used to be hard-coded here.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
        self.encoder_decoder_model = VisionEncoderDecoderModel.from_pretrained(
            encoder_decoder_model_name
        )
        # One-time vocabulary setup happens at construction (not on every
        # generate_texts call) so inference never mutates the tokenizer/model,
        # and any resized embedding rows are moved to the device below.
        self._ensure_pad_token()
        self.encoder_decoder_model.to(self.device)
        # Inference only: make sure training-time behaviour (e.g. dropout) is off.
        self.encoder_decoder_model.eval()
        self.max_length = max_length

    def _ensure_pad_token(self):
        """Add a '[PAD]' token if the tokenizer has none, resizing the decoder to match."""
        if not self.tokenizer.pad_token:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            # Keep the decoder's token-embedding table sized to the tokenizer's
            # (possibly enlarged) vocabulary.
            self.encoder_decoder_model.decoder.resize_token_embeddings(
                len(self.tokenizer)
            )

    def generate_texts(self, pixel_values, num_beams=4, no_repeat_ngram_size=2):
        """Decode text for a batch of pixel values with beam search.

        Args:
            pixel_values: Tensor fed to the model; it is moved to
                ``self.device`` first. Its expected shape depends on the
                checkpoint's encoder — TODO confirm with the image processor.
            num_beams: Forwarded to ``generate`` (previously fixed at 4).
            no_repeat_ngram_size: Forwarded to ``generate`` (previously
                fixed at 2).

        Returns:
            list[str]: Decoded sequences with special tokens skipped.
        """
        generated_ids = self.encoder_decoder_model.generate(
            pixel_values.to(self.device),
            max_length=self.max_length,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|