Henry65 commited on
Commit
2955c59
1 Parent(s): 134c3f4

Update RepoPipeline.py

Browse files
Files changed (1) hide show
  1. RepoPipeline.py +2 -2
RepoPipeline.py CHANGED
@@ -157,9 +157,9 @@ class RepoPipeline(Pipeline):
157
 
158
  def generate_embeddings(self, text_sets, max_length):
159
  assert max_length < 1024
160
- return torch.concat([self.encode(text, max_length) for text in text_sets], dim=0) \
161
  if text_sets is None or len(text_sets) == 0 \
162
- else torch.zeros((1, 768), device=self.device)
163
 
164
  def _forward(self, extracted_infos: List, max_length=512) -> List:
165
  model_outputs = []
 
157
 
158
  def generate_embeddings(self, text_sets, max_length):
159
  assert max_length < 1024
160
+ return torch.zeros((1, 768), device=self.device) \
161
  if text_sets is None or len(text_sets) == 0 \
162
+ else torch.cat([self.encode(text, max_length) for text in text_sets], dim=0)
163
 
164
  def _forward(self, extracted_infos: List, max_length=512) -> List:
165
  model_outputs = []