osanseviero
commited on
Commit
•
b00ee09
1
Parent(s):
37a3159
Update pipeline.py
Browse files- pipeline.py +2 -1
pipeline.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Dict, List, Tuple
|
|
3 |
import numpy as np
|
4 |
from asteroid import separate
|
5 |
from asteroid.models import BaseModel
|
|
|
6 |
|
7 |
|
8 |
class PreTrainedPipeline():
|
@@ -11,7 +12,7 @@ class PreTrainedPipeline():
|
|
11 |
# Preload all the elements you are going to need at inference.
|
12 |
# For instance your model, processors, tokenizer that might be needed.
|
13 |
# This function is only called once, so do all the heavy processing I/O here"""
|
14 |
-
self.model = BaseModel.from_pretrained(path)
|
15 |
self.sampling_rate = self.model.sample_rate
|
16 |
|
17 |
def __call__(self, inputs: np.array) -> Tuple[np.array, int, List[str]]:
|
|
|
3 |
import numpy as np
|
4 |
from asteroid import separate
|
5 |
from asteroid.models import BaseModel
|
6 |
+
import os
|
7 |
|
8 |
|
9 |
class PreTrainedPipeline():
|
|
|
12 |
# Preload all the elements you are going to need at inference.
|
13 |
# For instance your model, processors, tokenizer that might be needed.
|
14 |
# This function is only called once, so do all the heavy processing I/O here"""
|
15 |
+
self.model = BaseModel.from_pretrained(os.path.join(path, "pytorch_model.bin"))
|
16 |
self.sampling_rate = self.model.sample_rate
|
17 |
|
18 |
def __call__(self, inputs: np.array) -> Tuple[np.array, int, List[str]]:
|