Spaces:
Running
on
A10G
Running
on
A10G
File size: 3,662 Bytes
96ea36d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import sys
sys.path.append('../AudioJourney')
import os
import yaml
import logging
import torchaudio
from APIs import WRITE_AUDIO, LOUDNESS_NORM
from utils import fade
from flask import Flask, request, jsonify
# Shared service configuration (e.g. 'Text-to-Music' -> 'model_size', and the
# per-service 'service-port' entries). The path is relative, so it is resolved
# against the current working directory. yaml.safe_load is used so the file
# cannot construct arbitrary Python objects. An explicit encoding keeps parsing
# independent of the host locale.
with open('config.yaml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)
# Configure logging for this service.
# - basicConfig sets the root level to INFO and, if the root logger has no
#   handlers yet, adds a stream handler (stderr by default).
# - A FileHandler additionally writes to services_logs/Text-to-Audio-Music.log.
#   mode='w' truncates the file on every start.
# A single format string is shared by both outputs so they cannot drift apart.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

os.makedirs('services_logs', exist_ok=True)
log_filename = 'services_logs/Text-to-Audio-Music.log'
# User prompts are logged verbatim and may contain non-ASCII text, so pin the
# file encoding instead of relying on the locale default.
file_handler = logging.FileHandler(log_filename, mode='w', encoding='utf-8')
file_handler.setFormatter(logging.Formatter(LOG_FORMAT))
# Attach to the root logger so the module-level logging.info(...) calls used
# throughout this file are also written to the log file.
logging.getLogger().addHandler(file_handler)
# Load both Audiocraft generators once, when this module is imported. The
# route handlers reference these module-level objects, so every request
# reuses the same in-memory models.
from audiocraft.models import AudioGen, MusicGen

# Text-to-Audio: a fixed AudioGen checkpoint.
tta_model = AudioGen.get_pretrained(
    'facebook/audiogen-medium'
)
logging.info('AudioGen is loaded ...')

# Text-to-Music: the MusicGen checkpoint size comes from config.yaml.
model_size = config['Text-to-Music']['model_size']
ttm_checkpoint = f'facebook/musicgen-{model_size}'
ttm_model = MusicGen.get_pretrained(
    ttm_checkpoint
)
logging.info(f'MusicGen ({model_size}) is loaded ...')

# One Flask application serves both the TTA and the TTM endpoints.
app = Flask(
    __name__
)
@app.route('/generate_audio', methods=['POST'])
def generate_audio():
    """Generate a sound effect from a text prompt with AudioGen.

    Expects a JSON object body with:
      - ``text`` (required): the prompt passed to ``tta_model.generate``.
      - ``length`` (optional, default 5.0): generation duration in seconds.
      - ``volume`` (optional, default -35): target level passed to
        ``LOUDNESS_NORM`` (the log line labels it dB).
      - ``output_wav`` (optional, default 'out.wav'): name handed to
        ``WRITE_AUDIO``.

    Returns:
        - 200 with ``{'message': ..., 'file': output_wav}`` on success.
        - 400 with ``{'API error': ...}`` if the request body is malformed.
        - 500 with ``{'API error': ...}`` if generation or writing fails.
    """
    # silent=True makes a missing or non-JSON body yield None instead of
    # raising, so the client gets a JSON 400 rather than an HTML error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text' not in data:
        return jsonify({'API error': "Request body must be a JSON object with a 'text' field"}), 400
    text = data['text']
    try:
        length = float(data.get('length', 5.0))
        volume = float(data.get('volume', -35))
    except (TypeError, ValueError) as e:
        return jsonify({'API error': f'Invalid length/volume: {e}'}), 400
    # NOTE(review): output_wav comes straight from the client and is used as
    # an output path. Any caller can therefore choose where the file is
    # written, so only expose this service to trusted callers.
    output_wav = data.get('output_wav', 'out.wav')
    logging.info(f'TTA (AudioGen): Prompt: {text}, length: {length} seconds, volume: {volume} dB')
    try:
        # NOTE(review): tta_model is a shared global. If requests are served
        # concurrently, another request's set_generation_params() can run
        # between these two calls. Consider a lock if that matters.
        tta_model.set_generation_params(duration=length)
        wav = tta_model.generate([text])
        # Resample from AudioGen's native rate to 32 kHz. The target presumably
        # matches what WRITE_AUDIO expects and MusicGen's output rate; confirm.
        wav = torchaudio.functional.resample(wav, orig_freq=tta_model.sample_rate, new_freq=32000)
        wav = wav.squeeze().detach().cpu().numpy()
        wav = fade(LOUDNESS_NORM(wav, volumn=volume))
        WRITE_AUDIO(wav, name=output_wav)
    except Exception as e:
        # Service boundary: keep the full traceback in the server log and
        # report only the message to the client.
        logging.exception('TTA (AudioGen) generation failed')
        return jsonify({'API error': str(e)}), 500
    # Return success message and the filename of the generated audio
    return jsonify({'message': f'Text-to-Audio generated successfully | {text}', 'file': output_wav})
@app.route('/generate_music', methods=['POST'])
def generate_music():
    """Generate music from a text prompt with MusicGen.

    Expects a JSON object body with:
      - ``text`` (required): the prompt passed to ``ttm_model.generate``.
      - ``length`` (optional, default 5.0): generation duration in seconds.
      - ``volume`` (optional, default -35): target level passed to
        ``LOUDNESS_NORM`` (the log line labels it dB).
      - ``output_wav`` (optional, default 'out.wav'): name handed to
        ``WRITE_AUDIO``.

    Returns:
        - 200 with ``{'message': ..., 'file': output_wav}`` on success.
        - 400 with ``{'API error': ...}`` if the request body is malformed.
        - 500 with ``{'API error': ...}`` if generation or writing fails.
    """
    # silent=True makes a missing or non-JSON body yield None instead of
    # raising, so the client gets a JSON 400 rather than an HTML error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text' not in data:
        return jsonify({'API error': "Request body must be a JSON object with a 'text' field"}), 400
    text = data['text']
    try:
        length = float(data.get('length', 5.0))
        volume = float(data.get('volume', -35))
    except (TypeError, ValueError) as e:
        return jsonify({'API error': f'Invalid length/volume: {e}'}), 400
    # NOTE(review): output_wav comes straight from the client and is used as
    # an output path. Any caller can therefore choose where the file is
    # written, so only expose this service to trusted callers.
    output_wav = data.get('output_wav', 'out.wav')
    logging.info(f'TTM (MusicGen): Prompt: {text}, length: {length} seconds, volume: {volume} dB')
    try:
        # NOTE(review): ttm_model is a shared global. If requests are served
        # concurrently, another request's set_generation_params() can run
        # between these two calls. Consider a lock if that matters.
        ttm_model.set_generation_params(duration=length)
        wav = ttm_model.generate([text])
        # Take the first item of the two leading dims. This assumes generate()
        # returns (batch, channels, samples), i.e. the single prompt's first
        # channel; confirm. No resampling is applied, unlike the TTA path.
        wav = wav[0][0].detach().cpu().numpy()
        wav = fade(LOUDNESS_NORM(wav, volumn=volume))
        WRITE_AUDIO(wav, name=output_wav)
    except Exception as e:
        # Service boundary: keep the full traceback in the server log and
        # report only the message to the client.
        logging.exception('TTM (MusicGen) generation failed')
        return jsonify({'API error': str(e)}), 500
    # Return success message and the filename of the generated audio
    return jsonify({'message': f'Text-to-Music generated successfully | {text}', 'file': output_wav})
if __name__ == '__main__':
    # `config` was already parsed from config.yaml at import time, so reuse it
    # instead of re-importing yaml and re-reading the file.
    tta_service_port = config['Text-to-Audio']['service-port']
    # Bug fix: the TTM port used to be read from the 'Text-to-Audio' section,
    # which made the consistency check below always pass. If the TTM section
    # has no explicit port, fall back to the TTA port. Older config files that
    # only set one port therefore keep working.
    ttm_service_port = config['Text-to-Music'].get('service-port', tta_service_port)
    # This single Flask app serves both endpoints, so it can only listen on
    # one port. Refuse to start if the config asks for two different ports.
    if tta_service_port != ttm_service_port:
        msg = 'Ports of TTA and TTM should be same if you are using Audiocraft ...'
        logging.error(msg)
        raise ValueError(msg)
    app.run(debug=False, port=tta_service_port)
|