AlexHung29629
committed on
Update ultravox_processing.py
ultravox_processing.py  +21 -20
CHANGED
@@ -163,28 +163,29 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         processed_text = []
         data["audio_token_start_idx"] = []
         for t in text:
-
-
-
-
-                    )
-
-                start_idx = len(
-                    self.tokenizer.encode(
-                        t.split(self.audio_placeholder)[0],
-                        add_special_tokens=False,
-                    )
-                )
-
-
-
-
-
-                t = t.replace(
-                    self.audio_placeholder,
-                    self.audio_token_replacement * audio_embed_frames,
-                )
-
+            assert self.audio_placeholder in t
+            if "audio_token_len" not in data:
+                raise ValueError(
+                    f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
+                )
+
+            start_idx = len(
+                self.tokenizer.encode(
+                    t.split(self.audio_placeholder)[0],
+                    add_special_tokens=False,
+                )
+            )
+            data["audio_token_start_idx"].append(start_idx)
+
+            # Replace the audio placeholder with the audio token.
+            # e.g. "Transcribe <|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
+            # where the number of </s> is the number of audio frames.
+            t = t.replace(
+                self.audio_placeholder,
+                self.audio_token_replacement * audio_embed_frames,
+            )
+            processed_text.append(t)
+
 
         # Special tokens like BOS should already have been added by the caller.
         data.update(self.tokenizer(processed_text, add_special_tokens=False, padding='longest', **kwargs))
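
For reference, here is a minimal standalone sketch of what the updated loop computes for a single prompt: the index of the first audio token (the number of text tokens before the placeholder) and the prompt with the placeholder expanded into one replacement token per audio frame. The tokenizer, placeholder string, replacement token, and frame count below are hypothetical stand-ins for self.tokenizer, self.audio_placeholder, self.audio_token_replacement, and audio_embed_frames, not the repository's actual configuration.

# Sketch only; concrete values are assumptions for illustration.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")      # stand-in for self.tokenizer
audio_placeholder = "<|audio|>"                        # stand-in for self.audio_placeholder
audio_token_replacement = tokenizer.eos_token          # stand-in for self.audio_token_replacement
audio_embed_frames = 8                                 # stand-in for the per-sample frame count

t = "Transcribe <|audio|>"
assert audio_placeholder in t

# Index of the first audio token = number of text tokens before the placeholder.
start_idx = len(
    tokenizer.encode(t.split(audio_placeholder)[0], add_special_tokens=False)
)

# Expand the placeholder into audio_embed_frames copies of the replacement token,
# so the tokenized prompt reserves one position per audio frame.
t = t.replace(audio_placeholder, audio_token_replacement * audio_embed_frames)

input_ids = tokenizer(t, add_special_tokens=False)["input_ids"]
# input_ids[start_idx : start_idx + audio_embed_frames] are the positions the
# model is expected to overwrite with projected audio embeddings.

Note that this bookkeeping only holds if the replacement string encodes to exactly one token; a multi-token replacement would reserve more positions than audio_embed_frames.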