Spaces:
Runtime error
Runtime error
Armen Gabrielyan
commited on
Commit
•
68da745
1
Parent(s):
5e95a58
add video summarization pre-trained model
Browse files- app.py +0 -2
- inference.py +3 -3
app.py
CHANGED
@@ -8,14 +8,12 @@ import numpy as np
|
|
8 |
from inference import Inference
|
9 |
import utils
|
10 |
|
11 |
-
model_checkpoint = 'saved_model'
|
12 |
encoder_model_name = 'google/vit-large-patch32-224-in21k'
|
13 |
decoder_model_name = 'gpt2'
|
14 |
frame_step = 300
|
15 |
|
16 |
inference = Inference(
|
17 |
decoder_model_name=decoder_model_name,
|
18 |
-
model_checkpoint=model_checkpoint,
|
19 |
)
|
20 |
|
21 |
model = SentenceTransformer('all-mpnet-base-v2')
|
|
|
8 |
from inference import Inference
|
9 |
import utils
|
10 |
|
|
|
11 |
encoder_model_name = 'google/vit-large-patch32-224-in21k'
|
12 |
decoder_model_name = 'gpt2'
|
13 |
frame_step = 300
|
14 |
|
15 |
inference = Inference(
|
16 |
decoder_model_name=decoder_model_name,
|
|
|
17 |
)
|
18 |
|
19 |
model = SentenceTransformer('all-mpnet-base-v2')
|
inference.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
import torch
|
2 |
-
from transformers import AutoTokenizer,
|
3 |
|
4 |
import utils
|
5 |
|
6 |
class Inference:
|
7 |
-
def __init__(self, decoder_model_name,
|
8 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
9 |
|
10 |
self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
|
11 |
-
self.encoder_decoder_model =
|
12 |
self.encoder_decoder_model.to(self.device)
|
13 |
|
14 |
self.max_length = max_length
|
|
|
1 |
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModel
|
3 |
|
4 |
import utils
|
5 |
|
6 |
class Inference:
|
7 |
+
def __init__(self, decoder_model_name, max_length=32):
|
8 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
9 |
|
10 |
self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
|
11 |
+
self.encoder_decoder_model = AutoModel.from_pretrained('armgabrielyan/video-summarization')
|
12 |
self.encoder_decoder_model.to(self.device)
|
13 |
|
14 |
self.max_length = max_length
|