# NOTE: page-capture residue, not part of the program: "Spaces: Runtime error".
# Copyright 2022 Tristan Behrens. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# Lint as: python3 | |
from fastapi import BackgroundTasks, FastAPI | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import FileResponse | |
from pydantic import BaseModel | |
from PIL import Image | |
import os | |
import io | |
import random | |
import base64 | |
from time import time | |
from statistics import mean | |
from collections import OrderedDict | |
import torch | |
import wave | |
from source.logging import create_logger | |
from source.tokensequence import token_sequence_to_audio, token_sequence_to_image | |
from source import constants | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
logger = create_logger(__name__)

# The auth token is read from the "authtoken" environment variable;
# os.getenv yields None when the variable is not set.
auth_token = os.getenv("authtoken")

# Load the tokenizer and the causal language model from the same checkpoint.
logger.info("Loading tokenizer and model...")
_MODEL_ID = "ai-guru/lakhclean_mmmtrack_4bars_d-2048"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, use_auth_token=auth_token)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID, use_auth_token=auth_token)
logger.info("Done.")

# Create the FastAPI application (interactive docs disabled) and serve the
# "static" directory under the /static URL prefix.
logger.info("Creating app...")
app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")
logger.info("Done.")
class Options(BaseModel):
    # Schema with three required string fields describing a generation
    # request: the music style, the note density and the sampling temperature.
    music_style: str
    density: str
    temperature: str
class NewTask(BaseModel):
    # Request body for creating a composition task. Every field is optional
    # and falls back to the default given here.
    #
    # The explicit `str` annotations are required: without them pydantic v2
    # rejects the class at definition time (non-annotated attribute), while
    # pydantic v1 would merely infer the same type from the default.
    music_style: str = "synth"
    density: str = "medium"
    temperature: str = "medium"
def get_place_in_queue(task_id):
    """Return the 1-based queue position of ``task_id``.

    Only tasks whose status is "queued" or "processing" count as being in
    the queue; they are ordered by their "created_at" timestamp.

    Returns:
        The position (starting at 1) of the task among the pending tasks,
        or 0 when the task is not pending (or unknown).
    """
    pending = sorted(
        (
            task
            for task in tasks.values()
            if task["status"] in ("queued", "processing")
        ),
        key=lambda task: task["created_at"],
    )
    pending_ids = [task["task_id"] for task in pending]
    try:
        return pending_ids.index(task_id) + 1
    except ValueError:
        # list.index raises ValueError when the id is not pending.
        return 0
def calculate_eta(task_id):
    """Estimate the waiting time for ``task_id``.

    The estimate is the task's "initial_place_in_queue" multiplied by the
    mean processing duration (completed_at - started_at) of all completed
    tasks. When no completed task with both timestamps exists yet, a fixed
    fallback of 35 per task is used (presumably seconds, matching time()).

    Returns:
        The estimate rounded to one decimal place.

    Raises:
        KeyError: if ``task_id`` is not a known task.
    """
    # Task records are created with "started_at"/"completed_at" set to None,
    # so key presence alone says nothing; require real timestamps to avoid
    # subtracting None for a task whose status was flipped to "completed"
    # before its completion time was recorded.
    durations = [
        task["completed_at"] - task["started_at"]
        for task in tasks.values()
        if task["status"] == "completed"
        and task.get("completed_at") is not None
        and task.get("started_at") is not None
    ]
    initial_place_in_queue = tasks[task_id]["initial_place_in_queue"]
    per_task = mean(durations) if durations else 35
    return round(initial_place_in_queue * per_task, 1)
def next_task(task_id):
    """Stamp ``task_id`` as finished and hand over to the next queued task.

    Records the current time as the task's "completed_at", then, if any task
    still has status "queued", reports the remaining count and starts
    processing the first such task in registry order.
    """
    tasks[task_id]["completed_at"] = time()
    waiting = [entry for entry in tasks.values() if entry["status"] == "queued"]
    if not waiting:
        return
    print(
        f"{task_id} {tasks[task_id]['status']}. Task/s remaining: {len(waiting)}"
    )
    process_task(waiting[0]["task_id"])
def process_task(task_id):
    """Run the composition for ``task_id`` unless another task is running.

    Behaviour:
      * If any task currently has status "processing", do nothing; the
        running task hands over via ``next_task`` when it finishes.
      * If the task has been polled before but not within the last 30
        seconds, mark it "abandoned", move on to the next queued task and
        stop; an abandoned task is not composed.
      * Otherwise mark it "processing", call ``compose`` and store the
        result under "output". On any exception the status becomes "failed"
        and the exception's repr is stored under "error"; on success the
        status becomes "completed". In both cases ``next_task`` is called.
    """
    if any(task["status"] == "processing" for task in tasks.values()):
        return

    task = tasks[task_id]
    last_poll = task["last_poll"]
    if last_poll and time() - last_poll > 30:
        task["status"] = "abandoned"
        next_task(task_id)
        # Without this return the abandoned task would fall through and be
        # processed anyway, overwriting its "abandoned" status.
        return

    task["status"] = "processing"
    task["started_at"] = time()
    print(f"Processing {task_id}")
    try:
        task["output"] = compose(
            task["music_style"],
            task["density"],
            task["temperature"],
        )
    except Exception as ex:
        # Boundary handler: record the failure on the task instead of
        # letting it escape the background job.
        task["status"] = "failed"
        task["error"] = repr(ex)
    else:
        task["status"] = "completed"
    finally:
        next_task(task_id)
def _wav_to_base64(sample_rate, audio_data):
    """Wrap raw mono 16-bit frames in an in-memory WAV file, base64-encoded."""
    with io.BytesIO() as buffer:
        # The context manager finalises the WAV header even if writing fails
        # part-way; the caller-supplied buffer itself stays open until the
        # outer ``with`` exits.
        with wave.open(buffer, "wb") as wave_file:
            wave_file.setframerate(sample_rate)
            wave_file.setnchannels(1)
            wave_file.setsampwidth(2)
            wave_file.writeframes(audio_data)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _png_to_base64(image):
    """Serialise a PIL image to an in-memory PNG, base64-encoded."""
    with io.BytesIO() as buffer:
        image.save(buffer, "PNG", quality=70)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


def compose(music_style, density, temperature):
    """Generate a piece of music and return it as tokens, audio and image.

    The three arguments are symbolic names that are translated through
    ``constants.get_instruments``, ``constants.get_density`` and
    ``constants.get_temperature`` before generation.

    Side effects:
        Writes the rendered image to "compose.png" in the working directory.

    Returns:
        A dict with the keys "tokens" (the generated token sequence),
        "audio" (a ``data:audio/wav;base64,`` URI), "image" (a
        ``data:image/png;base64,`` URI) and "status" (always "OK").
    """
    instruments = constants.get_instruments(music_style)
    density = constants.get_density(density)
    temperature = constants.get_temperature(temperature)
    print(f"instruments: {instruments} density: {density} temperature: {temperature}")

    # Generate with the given parameters.
    logger.info("Generating token sequence...")
    generated_sequence = generate_sequence(instruments, density, temperature)
    logger.info(f"Generated token sequence: {generated_sequence}")

    # Render the token sequence to audio samples.
    logger.info("Generating audio...")
    sample_rate, audio_data = token_sequence_to_audio(generated_sequence)
    logger.info(f"Done. Audio data: {len(audio_data)}")
    audio_data_base64 = _wav_to_base64(sample_rate, audio_data)

    # Render the token sequence to a PIL image.
    image = token_sequence_to_image(generated_sequence)

    # Keep a copy of the image on disk as PNG.
    logger.debug(f"Saving image to harddrive... {type(image)}")
    image_file_name = "compose.png"
    image.save(image_file_name, "PNG")

    image_data_base64 = _png_to_base64(image)

    return {
        "tokens": generated_sequence,
        "audio": "data:audio/wav;base64," + audio_data_base64,
        "image": "data:image/png;base64," + image_data_base64,
        "status": "OK",
    }
def generate_sequence(instruments, density, temperature):
    """Sample a token sequence from the model for the given instruments.

    The prompt consists of "PIECE_START" followed by one
    "TRACK_START INST=<instrument> DENSITY=<density>" header per instrument,
    with the instruments in random order. The model then samples up to a
    total length of 2048 tokens, using the token id of "TRACK_END" as the
    end-of-sequence id.

    Args:
        instruments: Instrument identifiers; the caller's sequence is copied
            and not modified.
        density: Value inserted into each track header.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        The decoded token sequence as a string.
    """
    # Shuffle a copy so the caller's sequence keeps its order.
    instruments = list(instruments)
    random.shuffle(instruments)

    generated_ids = tokenizer.encode("PIECE_START", return_tensors="pt")[0]
    for instrument in instruments:
        more_ids = tokenizer.encode(
            f"TRACK_START INST={instrument} DENSITY={density}", return_tensors="pt"
        )[0]
        generated_ids = torch.cat((generated_ids, more_ids))

    # model.generate expects a batch dimension.
    generated_ids = generated_ids.unsqueeze(0)
    generated_ids = model.generate(
        generated_ids,
        max_length=2048,
        do_sample=True,
        temperature=temperature,
        eos_token_id=tokenizer.encode("TRACK_END")[0],
    )[0]
    generated_sequence = tokenizer.decode(generated_ids)
    print("GENERATING COMPLETE")
    # Print the generated text (the original printed the function object).
    print(generated_sequence)
    return generated_sequence
# In-memory registry of all tasks, keyed by task id, in insertion order.
tasks: "OrderedDict[str, OrderedDict]" = OrderedDict()
# Route for the loading page.
# NOTE(review): no route decorator is visible on this function in this
# excerpt; confirm that it is registered with the app.
def index(request):
    """Serve "static/index.html" as an HTML file response."""
    return FileResponse("static/index.html", media_type="text/html")
def create_task(background_tasks: BackgroundTasks, new_task: NewTask):
    """Register a new composition task and schedule it for processing.

    The task id is the creation timestamp joined to the requested music
    style with an underscore. The record starts with status "queued"; its
    initial queue position and ETA are computed once it is in the registry,
    and ``process_task`` is scheduled as a background job.

    Returns:
        The newly created task record.
    """
    now = time()
    task_id = f"{now}_{new_task.music_style}"
    entry = OrderedDict(
        task_id=task_id,
        status="queued",
        eta=None,
        created_at=now,
        started_at=None,
        completed_at=None,
        last_poll=None,
        poll_count=0,
        initial_place_in_queue=None,
        place_in_queue=None,
        music_style=new_task.music_style,
        density=new_task.density,
        temperature=new_task.temperature,
        output=None,
    )
    # The record must be registered before its queue position is computed.
    tasks[task_id] = entry
    entry["initial_place_in_queue"] = get_place_in_queue(task_id)
    entry["eta"] = calculate_eta(task_id)
    background_tasks.add_task(process_task, task_id)
    return entry
def poll_task(task_id: str):
    """Refresh and return the record of ``task_id``.

    Updates the task's current queue position and ETA, stores the current
    time as "last_poll" and increments "poll_count".

    Raises:
        KeyError: if ``task_id`` is not a known task.
    """
    entry = tasks[task_id]
    entry["place_in_queue"] = get_place_in_queue(task_id)
    entry["eta"] = calculate_eta(task_id)
    entry["last_poll"] = time()
    entry["poll_count"] += 1
    return entry