Update handler.py
Browse files- handler.py +9 -4
handler.py
CHANGED
@@ -3,7 +3,7 @@ from tangoflux import TangoFluxInference
|
|
3 |
import torchaudio
|
4 |
|
5 |
from huggingface_inference_toolkit.logging import logger
|
6 |
-
|
7 |
|
8 |
class EndpointHandler():
|
9 |
def __init__(self, path=""):
|
@@ -40,7 +40,12 @@ class EndpointHandler():
|
|
40 |
duration = parameters.get("duration", 10)
|
41 |
guidance_scale = parameters.get("guidance_scale", 3.5)
|
42 |
|
43 |
-
|
44 |
-
return self.model.generate(prompt,steps=num_inference_steps,
|
45 |
duration=duration,
|
46 |
-
guidance_scale=guidance_scale)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import torchaudio
|
4 |
|
5 |
from huggingface_inference_toolkit.logging import logger
|
6 |
+
import io
|
7 |
|
8 |
class EndpointHandler():
|
9 |
def __init__(self, path=""):
|
|
|
40 |
duration = parameters.get("duration", 10)
|
41 |
guidance_scale = parameters.get("guidance_scale", 3.5)
|
42 |
|
43 |
+
audio= self.model.generate(prompt,steps=num_inference_steps,
|
|
|
44 |
duration=duration,
|
45 |
+
guidance_scale=guidance_scale)
|
46 |
+
|
47 |
+
buffer = io.BytesIO()
|
48 |
+
torchaudio.save(buffer, audio, 44100, format="wav")
|
49 |
+
buffer.seek(0)
|
50 |
+
|
51 |
+
return buffer.read()
|