from typing import Any, Dict

import torch
from transformers import AutoModel, AutoProcessor


class EndpointHandler:
    def __init__(self, path=""):
        # load the Bark processor and model, and move the model to the GPU
        self.processor = AutoProcessor.from_pretrained("suno/bark-small")
        self.model = AutoModel.from_pretrained(
            "suno/bark-small",
        ).to("cuda")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (:dict:):
                The payload with the text prompt and generation parameters.
        """
        # process input: "inputs" carries the text prompt,
        # "voice_preset" optionally selects a Bark speaker preset
        text = data.pop("inputs", data)
        voice_preset = data.get("voice_preset", None)

        if voice_preset:
            inputs = self.processor(
                text=[text],
                return_tensors="pt",
                voice_preset=voice_preset,
            ).to("cuda")
        else:
            inputs = self.processor(
                text=[text],
                return_tensors="pt",
            ).to("cuda")

        # generate the audio waveform with mixed precision on the GPU
        with torch.autocast("cuda"):
            outputs = self.model.generate(**inputs)

        # postprocess the prediction: move it to CPU and serialize as a nested list
        prediction = outputs.cpu().numpy().tolist()

        return {"generated_audio": prediction}