|
|
|
|
|
|
|
|
|
import logging |
|
from typing import Any, Dict, List, Optional |
|
|
|
import numpy as np |
|
import transformers |
|
|
|
|
|
|
|
from .shuka_model import ShukaModel |
|
from .shuka_processing import ShukaProcessor |
|
|
|
|
|
class ShukaPipeline(transformers.Pipeline):
    """Audio + text chat pipeline built around a ``ShukaModel``.

    A call takes a dict that may contain:

    * ``"turns"``: list of chat messages (dicts with ``"role"`` and
      ``"content"`` keys) describing the conversation so far.
    * ``"audio"``: the audio to attend to, optionally a ``numpy`` array.
    * ``"prompt"``: text for the user turn that is created when audio is
      given and the conversation does not already end with a user turn.
    * ``"sampling_rate"``: sampling rate of ``"audio"`` (16000 if absent).

    The result of a call is the decoded text generated by the model.
    """

    def __init__(
        self,
        model: ShukaModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        """Build the pipeline, loading any component that was not supplied.

        Args:
            model: The model used for generation; its ``config`` is also
                consulted to locate a default tokenizer / audio processor.
            tokenizer: Text tokenizer. When ``None`` it is loaded from the
                model's own path, falling back to the text backbone's id.
            audio_processor: Audio feature processor. When ``None`` it is
                loaded from the audio backbone's id.
            **kwargs: Forwarded unchanged to ``transformers.Pipeline``.
        """
        if tokenizer is None:
            try:
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config._name_or_path
                )
            except Exception:
                # Best-effort fallback: the model path may not ship tokenizer
                # files, so try the text backbone instead. ``Exception`` (not
                # a bare ``except``) keeps Ctrl-C / SystemExit propagating.
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config.text_model_id or model.config.text_config._name_or_path
                )

        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_model_id or model.config.audio_config._name_or_path
            )

        # Must be set before ``super().__init__`` in case the base class
        # touches pipeline state during construction.
        self.processor = ShukaProcessor(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.stack_factor,
        )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into the three per-stage kwarg dicts.

        Only ``temperature``, ``max_new_tokens`` and ``repetition_penalty``
        are recognised; they are routed to the second (forward) dict. The
        first and third dicts are always empty, and any other keyword is
        silently dropped.
        """
        generation_keys = ["temperature", "max_new_tokens", "repetition_penalty"]
        generation_kwargs = {k: kwargs[k] for k in kwargs if k in generation_keys}
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        """Turn a raw input dict into processor outputs ready for the model.

        Args:
            inputs: Dict with optional ``"turns"``, ``"audio"``, ``"prompt"``
                and ``"sampling_rate"`` entries (see the class docstring).

        Returns:
            The output of ``self.processor``; if it contains
            ``"audio_values"`` that entry is cast to the model's dtype.
        """
        # Work on a copy so the caller's conversation list is never mutated
        # by the user turn that may be appended below.
        turns: list = list(inputs.get("turns", []))

        audio = inputs.get("audio", None)

        # Normalise common numpy sample formats to float32. Integer PCM is
        # scaled by its full-scale magnitude; other dtypes pass through as-is.
        if isinstance(audio, np.ndarray):
            if audio.dtype == np.float64:
                audio = audio.astype(np.float32)
            elif audio.dtype == np.int16:
                audio = audio.astype(np.float32) / np.float32(32768.0)
            elif audio.dtype == np.int32:
                audio = audio.astype(np.float32) / np.float32(2147483648.0)

        # Audio needs a user turn carrying the '<|audio|>' placeholder; add
        # one unless the conversation already ends with a user turn.
        if audio is not None and (not turns or turns[-1]["role"] != "user"):
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )

                prompt += " <|audio|>"
            turns.append({"role": "user", "content": prompt})

        text = self.processor.tokenizer.apply_chat_template(
            turns, add_generation_prompt=True, tokenize=False
        )

        if "sampling_rate" not in inputs and audio is not None:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        output = self.processor(
            text=text,
            audio=audio,
            sampling_rate=inputs.get("sampling_rate", 16000),
        )
        if "audio_values" in output:
            # Match the model's parameter dtype (e.g. half precision).
            output["audio_values"] = output["audio_values"].to(self.model.dtype)

        return output

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
        """Run generation and return only the newly generated tokens.

        Args:
            model_inputs: Output of :meth:`preprocess`; must contain
                ``"input_ids"``.
            temperature: Sampling temperature. ``None`` or any falsy value
                (e.g. ``0``) disables sampling.
            max_new_tokens: Passed through to ``generate``.
            repetition_penalty: Passed through to ``generate``.

        Returns:
            The first generated sequence with the prompt tokens removed.
        """
        # Collapse 0 / 0.0 to None so "temperature=0" means "do not sample".
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on the tokenizer's EOS and, when the vocabulary defines it,
        # on the '<|eot_id|>' end-of-turn token as well.
        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators,
        )
        # Only the first sequence is used; drop the echoed prompt prefix.
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode generated token ids to text, skipping special tokens."""
        output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
        return output_text
|
|
|
|
|
# Import-time side effect: register ShukaPipeline in the transformers
# pipeline registry under the task name "shuka-pipeline".
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "shuka-pipeline",
    type="multimodal",
    pt_model=transformers.AutoModel,
    pipeline_class=ShukaPipeline,
)
|
|