Update RepoPipeline.py
Browse files- RepoPipeline.py +2 -2
RepoPipeline.py
CHANGED
@@ -157,9 +157,9 @@ class RepoPipeline(Pipeline):
|
|
157 |
|
158 |
def generate_embeddings(self, text_sets, max_length):
|
159 |
assert max_length < 1024
|
160 |
-
return torch.
|
161 |
if text_sets is None or len(text_sets) == 0 \
|
162 |
-
else torch.
|
163 |
|
164 |
def _forward(self, extracted_infos: List, max_length=512) -> List:
|
165 |
model_outputs = []
|
|
|
157 |
|
158 |
def generate_embeddings(self, text_sets, max_length):
|
159 |
assert max_length < 1024
|
160 |
+
return torch.zeros((1, 768), device=self.device) \
|
161 |
if text_sets is None or len(text_sets) == 0 \
|
162 |
+
else torch.cat([self.encode(text, max_length) for text in text_sets], dim=0)
|
163 |
|
164 |
def _forward(self, extracted_infos: List, max_length=512) -> List:
|
165 |
model_outputs = []
|