|
## MERaLiON-AudioLLM vLLM Serving |
|
|
|
> [!IMPORTANT]
> MERaLiON-AudioLLM is trained on 30-second audio clips. This vLLM integration supports audio inputs of at most 4 minutes.
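
Because of this limit, it can help to check the clip length before sending it to the model. Below is a minimal sketch using `soundfile`, which typically ships with vLLM's audio dependencies; the file path and the error handling are illustrative only.

```python
import soundfile as sf

MAX_SECONDS = 240  # the 4-minute cap of this vLLM integration

# The path is a placeholder; check the clip length before sending it to the model.
info = sf.info("/path/to/audio.wav")
if info.duration > MAX_SECONDS:
    raise ValueError(f"Clip is {info.duration:.1f}s; at most {MAX_SECONDS}s is supported.")
```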
|
|
|
### Set up Environment |
|
|
|
MERaLiON-AudioLLM requires vLLM version `0.6.4.post1` and `transformers` version `4.46.3`:
|
|
|
```bash
pip install vllm==0.6.4.post1
pip install transformers==4.46.3
```
|
|
|
As the [vLLM documentation](https://docs.vllm.ai/en/stable/models/adding_model.html#out-of-tree-model-integration) recommends,
we provide a way to register our model via [vLLM plugins](https://docs.vllm.ai/en/stable/design/plugin_system.html#plugin-system).
Install the plugin package from the root of this repository:
|
|
|
|
|
```bash
pip install .
```
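
Under the hood, a vLLM plugin is a Python package that exposes a function through the `vllm.general_plugins` entry-point group; vLLM calls every such function at startup, and the plugin registers the custom model class with `ModelRegistry`. The sketch below only illustrates the pattern; the package name `meralion_vllm`, the module path, and the class import are hypothetical, not the actual layout of this repository.

```python
# setup.py (illustrative sketch, not this repository's actual packaging)
from setuptools import setup

setup(
    name="meralion_vllm",  # hypothetical package name
    version="0.1.0",
    packages=["meralion_vllm"],
    entry_points={
        # vLLM invokes every function registered under this group at startup.
        "vllm.general_plugins": [
            "register_meralion = meralion_vllm:register"
        ]
    },
)
```

```python
# meralion_vllm/__init__.py (illustrative sketch)
def register():
    from vllm import ModelRegistry

    # Hypothetical import of the custom model class shipped with this repo.
    from .modeling_meralion import MERaLiONForConditionalGeneration

    ModelRegistry.register_model(
        "MERaLiONForConditionalGeneration",
        MERaLiONForConditionalGeneration,
    )
```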
|
|
|
|
|
### Offline Inference |
|
|
|
Here is an example of offline inference using our custom vLLM class. |
|
|
|
```python
import torch
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset


def run_meralion(question: str):
    model_name = "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION"

    llm = LLM(model=model_name,
              tokenizer=model_name,
              max_num_seqs=8,
              limit_mm_per_prompt={"audio": 1},
              trust_remote_code=True,
              dtype=torch.bfloat16)

    audio_in_prompt = "Given the following audio context: <SpeechHere>\n\n"

    prompt = ("<start_of_turn>user\n"
              f"{audio_in_prompt}Text instruction: {question}<end_of_turn>\n"
              "<start_of_turn>model\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


audio_asset = AudioAsset("mary_had_lamb")
question = "Please transcribe this speech."

llm, prompt, stop_token_ids = run_meralion(question)

# We set the temperature to 0.1 so that outputs can differ
# across identical prompts when running batch inference.
sampling_params = SamplingParams(
    temperature=0.1,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.1,
    seed=42,
    max_tokens=1024,
    stop_token_ids=stop_token_ids
)

mm_data = {"audio": [audio_asset.audio_and_sample_rate]}
inputs = {"prompt": prompt, "multi_modal_data": mm_data}

# Batch inference: duplicate the same request twice.
inputs = [inputs] * 2

outputs = llm.generate(inputs, sampling_params=sampling_params)

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)
```
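
To run the same example on your own recording instead of the bundled asset, the multi-modal data only needs an `(audio, sample_rate)` tuple, matching what `audio_asset.audio_and_sample_rate` returns. A minimal sketch using `librosa`; the file path is a placeholder:

```python
import librosa

# Load a local file; sr=None keeps the file's native sampling rate.
audio, sample_rate = librosa.load("/path/to/your_audio.wav", sr=None)

mm_data = {"audio": [(audio, sample_rate)]}
inputs = {"prompt": prompt, "multi_modal_data": mm_data}

outputs = llm.generate(inputs, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
```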
|
|
|
### Serving |
|
|
|
Here is an example of starting the server via the `vllm serve` command.
|
|
|
```bash
export HF_TOKEN=<your-hf-token>

vllm serve MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION \
    --tokenizer MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION \
    --max-num-seqs 8 \
    --trust-remote-code \
    --dtype bfloat16 \
    --port 8000
```
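
Once the server is up, you can confirm it is serving the model by listing the available models through the OpenAI-compatible endpoint:

```bash
curl http://localhost:8000/v1/models
```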
|
|
|
To call the server, you can use the [official OpenAI client](https://github.com/openai/openai-python): |
|
|
|
```python
import base64

from openai import OpenAI


def get_client(api_key="EMPTY", base_url="http://localhost:8000/v1"):
    client = OpenAI(
        api_key=api_key,
        base_url=base_url,
    )

    models = client.models.list()
    model_name = models.data[0].id
    return client, model_name


def get_response(text_input, base64_audio_input, **params):
    response_obj = client.chat.completions.create(
        messages=[{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Text instruction: {text_input}"
                },
                {
                    "type": "audio_url",
                    "audio_url": {
                        "url": f"data:audio/ogg;base64,{base64_audio_input}"
                    },
                },
            ],
        }],
        **params
    )
    return response_obj


# Specify the input and generation parameters.
possible_text_inputs = [
    "Please transcribe this speech.",
    "Please summarise the content of this speech.",
    "Please follow the instruction in this speech."
]

audio_bytes = open("/path/to/wav/or/mp3/file", "rb").read()
audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

# Use the port number of the vLLM service.
client, model_name = get_client(base_url="http://localhost:8000/v1")

generation_parameters = dict(
    model=model_name,
    max_completion_tokens=1024,
    temperature=0.1,
    top_p=0.9,
    extra_body={
        "repetition_penalty": 1.1,
        "top_k": 50,
        "length_penalty": 1.0
    },
    seed=42
)

response_obj = get_response(possible_text_inputs[0], audio_base64, **generation_parameters)
print(response_obj.choices[0].message.content)
```
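
Because `get_response` forwards `**params` to `chat.completions.create`, the same helper can also stream tokens as they are generated. A brief sketch, assuming the server started as above supports streaming responses:

```python
# Stream the completion and print tokens as they arrive.
stream = get_response(possible_text_inputs[0], audio_base64, stream=True, **generation_parameters)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
print()
```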
|
|
|
Alternatively, you can call the server with `curl`; see the example below. We recommend passing the generation config in the JSON body to fully reproduce the reported performance.
|
|
|
```bash
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION",
        "messages": [
            {"role": "user",
             "content": [
                {"type": "text", "text": "Text instruction: <your-instruction>"},
                {"type": "audio_url", "audio_url": {"url": "data:audio/ogg;base64,<your-audio-base64-string>"}}
             ]
            }
        ],
        "max_completion_tokens": 1024,
        "temperature": 0.1,
        "top_p": 0.9,
        "seed": 42
    }'
```