File size: 1,032 Bytes
5e95a58
16b0970
5e95a58
 
 
68da745
5e95a58
 
 
16b0970
5e95a58
 
 
 
cde7ed6
5e95a58
 
 
 
4820fa1
cde7ed6
4820fa1
 
 
 
cde7ed6
5e95a58
cde7ed6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel


class Inference:
    """Generate texts from pixel values with a ``VisionEncoderDecoderModel``.

    Pairs an encoder-decoder checkpoint with the tokenizer of its text decoder
    and runs generation on CUDA when available, otherwise on CPU.
    """

    def __init__(
        self,
        decoder_model_name,
        max_length=32,
        encoder_decoder_model_name='armgabrielyan/video-summarization',
    ):
        """Load the tokenizer and the encoder-decoder model onto the device.

        Args:
            decoder_model_name: Name or path given to
                ``AutoTokenizer.from_pretrained``. It should be the tokenizer of
                the decoder the encoder-decoder checkpoint was built with.
            max_length: Default ``max_length`` passed to ``generate``.
            encoder_decoder_model_name: Name or path given to
                ``VisionEncoderDecoderModel.from_pretrained``. Defaults to the
                checkpoint this class has always loaded.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
        self.encoder_decoder_model = VisionEncoderDecoderModel.from_pretrained(
            encoder_decoder_model_name
        )
        self.encoder_decoder_model.to(self.device)

        self.max_length = max_length

    def _ensure_pad_token(self):
        """Give the tokenizer a '[PAD]' token if it has none (idempotent)."""
        if self.tokenizer.pad_token is not None:
            return
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # Adding a token grows the vocabulary, so the decoder's token
        # embedding matrix is resized to the new tokenizer length.
        self.encoder_decoder_model.decoder.resize_token_embeddings(len(self.tokenizer))

    def generate_texts(self, pixel_values, **generate_kwargs):
        """Run generation on ``pixel_values`` and decode the results to text.

        On first use, a '[PAD]' token is added to the tokenizer (and the
        decoder embeddings resized) if the tokenizer has no pad token.

        Args:
            pixel_values: Encoder input tensor; moved to ``self.device`` before
                generation. NOTE(review): expected shape depends on the
                checkpoint's encoder — confirm against the model config.
            **generate_kwargs: Extra keyword arguments forwarded to
                ``generate``. They override the defaults
                ``max_length=self.max_length``, ``num_beams=4`` and
                ``no_repeat_ngram_size=2``.

        Returns:
            The list of decoded texts from ``tokenizer.batch_decode`` with
            special tokens removed.
        """
        self._ensure_pad_token()

        options = {
            'max_length': self.max_length,
            'num_beams': 4,
            'no_repeat_ngram_size': 2,
        }
        # Caller-supplied settings take precedence over the defaults above.
        options.update(generate_kwargs)

        generated_ids = self.encoder_decoder_model.generate(
            pixel_values.to(self.device),
            **options,
        )
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)