# Copyright 2022 Tristan Behrens.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3

import base64
import io
import os
import random
import wave

import torch
from flask import Flask, render_template, request, send_file, jsonify, redirect, url_for
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM

from source import constants
from source.logging import create_logger
from source.tokensequence import token_sequence_to_audio, token_sequence_to_image

logger = create_logger(__name__)

# Identifier of the pretrained checkpoint used for both tokenizer and model.
MODEL_NAME = "ai-guru/lakhclean_mmmtrack_4bars_d-2048"

# JSON keys that every /compose request body must provide.
_REQUIRED_PARAMS = ("music_style", "density", "temperature")

# Read the auth token from the "authtoken" environment variable
# (None when the variable is not set).
auth_token = os.getenv("authtoken")

# Loading the model and its tokenizer.
logger.info("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=auth_token)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_auth_token=auth_token)
logger.info("Done.")

# Create the app.
logger.info("Creating app...")
app = Flask(__name__)
logger.info("Done.")


# Route for the loading page.
@app.route("/")
def index():
    """Render the landing page with the selectable styles, densities and temperatures."""
    return render_template(
        "index.html",
        compose_styles=constants.get_compose_styles_for_ui(),
        densities=constants.get_densities_for_ui(),
        temperatures=constants.get_temperatures_for_ui(),
    )


def _encode_wav_base64(sample_rate, audio_data):
    """Pack raw audio frames into an in-memory WAV file and return it base64-encoded.

    The WAV header is written as mono (1 channel) with a sample width of
    2 bytes, so ``audio_data`` is presumably int16 samples -- the producer is
    ``token_sequence_to_audio``; verify there.
    """
    with io.BytesIO() as buffer:
        # The context manager finalizes the WAV header even if writing fails,
        # and leaves the caller-supplied buffer open for reading afterwards.
        with wave.open(buffer, "wb") as wave_file:
            wave_file.setframerate(sample_rate)
            wave_file.setnchannels(1)
            wave_file.setsampwidth(2)
            wave_file.writeframes(audio_data)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _encode_png_base64(image):
    """Serialize an image object to PNG in memory and return it base64-encoded."""
    with io.BytesIO() as buffer:
        # No `quality` argument: it is a JPEG option and has no effect on PNG.
        image.save(buffer, "PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


@app.route("/compose", methods=["POST"])
def compose():
    """Generate a piece of music from the posted parameters.

    Expects a JSON object with the keys ``music_style``, ``density`` and
    ``temperature``. Responds with a JSON object holding the generated token
    sequence (``tokens``), the rendered audio as a WAV data URI (``audio``),
    an image rendering as a PNG data URI (``image``) and ``status``.

    A body that is not a JSON object, or that lacks a required key, yields a
    400 response with ``status`` set to ``"ERROR"``.
    """
    # Get the parameters as JSON and reject malformed bodies explicitly
    # instead of failing later with a TypeError/KeyError (HTTP 500).
    params = request.get_json()
    if not isinstance(params, dict):
        return jsonify({"status": "ERROR", "message": "Expected a JSON object."}), 400
    missing = [key for key in _REQUIRED_PARAMS if key not in params]
    if missing:
        message = "Missing parameters: " + ", ".join(missing)
        return jsonify({"status": "ERROR", "message": message}), 400

    # Translate the UI-level selections into the values used for generation.
    instruments = constants.get_instruments(params["music_style"])
    density = constants.get_density(params["density"])
    temperature = constants.get_temperature(params["temperature"])
    logger.info(f"instruments: {instruments} density: {density} temperature: {temperature}")

    # Generate with the given parameters.
    logger.info("Generating token sequence...")
    generated_sequence = generate_sequence(instruments, density, temperature)
    logger.info(f"Generated token sequence: {generated_sequence}")

    # Render the token sequence to audio samples.
    logger.info("Generating audio...")
    sample_rate, audio_data = token_sequence_to_audio(generated_sequence)
    logger.info(f"Done. Audio data: {len(audio_data)}")

    # Encode the audio data as a WAV file in memory, base64-encoded.
    audio_data_base64 = _encode_wav_base64(sample_rate, audio_data)

    # Render the token sequence to an image.
    image = token_sequence_to_image(generated_sequence)

    # Save the image to the working directory as PNG.
    # NOTE(review): every request overwrites the same file; this looks like a
    # debugging aid -- confirm whether anything still reads compose.png.
    logger.debug(f"Saving image to harddrive... {type(image)}")
    image_file_name = "compose.png"
    image.save(image_file_name, "PNG")

    # Encode the image as PNG in memory, base64-encoded.
    image_data_base64 = _encode_png_base64(image)

    # Return.
    return jsonify({
        "tokens": generated_sequence,
        "audio": "data:audio/wav;base64," + audio_data_base64,
        "image": "data:image/png;base64," + image_data_base64,
        "status": "OK"
    })


def generate_sequence(instruments, density, temperature):
    """Generate a token sequence with one track per requested instrument.

    The instruments are processed in random order. For each one, a
    ``TRACK_START INST=... DENSITY=...`` header is appended to the tokens
    produced so far and the model samples a continuation that stops at the
    ``TRACK_END`` token or at a total length of 2048 tokens.

    Args:
        instruments: Sequence of instrument identifiers; it is copied before
            shuffling, so the caller's sequence is left untouched.
        density: Value interpolated into the ``DENSITY=`` token.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        The decoded token sequence as a string.
    """
    shuffled_instruments = list(instruments)
    random.shuffle(shuffled_instruments)

    # The end-of-track token id is the same for every track; look it up once.
    track_end_id = tokenizer.encode("TRACK_END")[0]

    generated_ids = tokenizer.encode("PIECE_START", return_tensors="pt")[0]

    for instrument in shuffled_instruments:
        track_header = f"TRACK_START INST={instrument} DENSITY={density}"
        more_ids = tokenizer.encode(track_header, return_tensors="pt")[0]
        generated_ids = torch.cat((generated_ids, more_ids))

        # generate() takes a batch dimension; add it here and strip it again
        # via [0] so the next iteration can concatenate 1-D tensors.
        generated_ids = model.generate(
            generated_ids.unsqueeze(0),
            max_length=2048,
            do_sample=True,
            temperature=temperature,
            eos_token_id=track_end_id,
        )[0]

    generated_sequence = tokenizer.decode(generated_ids)
    return generated_sequence


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)