|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import base64
import io
import os
import random
import wave
from collections import OrderedDict
from statistics import mean
from time import time

import torch
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from source import constants
from source.logging import create_logger
from source.tokensequence import token_sequence_to_audio, token_sequence_to_image
|
|
|
# Module-wide logger used by the functions in this file.
logger = create_logger(__name__)

# Access token passed to ``from_pretrained`` when loading the model; taken
# from the "authtoken" environment variable and None when that is unset.
auth_token = os.environ.get("authtoken")
|
|
|
|
|
# Hub id of the checkpoint; both the tokenizer and the model come from it.
_MODEL_NAME = "ai-guru/lakhclean_mmmtrack_4bars_d-2048"

logger.info("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME, use_auth_token=auth_token)
model = AutoModelForCausalLM.from_pretrained(_MODEL_NAME, use_auth_token=auth_token)
logger.info("Done.")
|
|
|
|
|
|
|
logger.info("Creating app...")
app = FastAPI(
    docs_url=None,  # no Swagger UI
    redoc_url=None,  # no ReDoc
)
# Files under ./static are served at /static.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
logger.info("Done.")
|
|
|
|
|
class Options(BaseModel):
    """Three string-valued composition options, all required.

    NOTE(review): this model is not referenced anywhere in this module;
    /task/create accepts NewTask instead. Possibly dead code — confirm
    before removing.
    """

    # Style label, presumably resolved like NewTask.music_style — confirm.
    music_style: str
    # Density label, presumably resolved like NewTask.density — confirm.
    density: str
    # Temperature label, presumably resolved like NewTask.temperature — confirm.
    temperature: str
|
|
|
|
|
class NewTask(BaseModel):
    """Request body for POST /task/create; every field has a default.

    The values are labels, not raw numbers: compose() resolves them through
    constants.get_instruments / get_density / get_temperature.
    """

    # Explicit ``str`` annotations: Pydantic v2 rejects non-annotated class
    # attributes, and Pydantic v1 already inferred ``str`` from these string
    # defaults, so validation is unchanged on v1.
    music_style: str = "synth"
    density: str = "medium"
    temperature: str = "medium"
|
|
|
|
|
def get_place_in_queue(task_id):
    """Return the 1-based position of a task among unfinished tasks.

    "Unfinished" means status "queued" or "processing"; those tasks are
    ordered by ``created_at``.

    Args:
        task_id: Key of the task in the module-level ``tasks`` registry.

    Returns:
        The task's position (1 = first), or 0 when the task is not queued or
        processing (completed, failed, abandoned, or unknown).
    """
    pending = sorted(
        (task for task in tasks.values() if task["status"] in ("queued", "processing")),
        key=lambda task: task["created_at"],
    )
    pending_ids = [task["task_id"] for task in pending]

    try:
        return pending_ids.index(task_id) + 1
    except ValueError:
        # list.index raises ValueError when the id is absent; catching only
        # that keeps unrelated errors (e.g. a malformed task record) visible.
        return 0
|
|
|
|
|
def calculate_eta(task_id, default_duration=35):
    """Estimate how many seconds until a task's result is ready.

    The estimate is the task's *initial* place in the queue multiplied by the
    mean processing time of all completed tasks so far, or by
    ``default_duration`` when no task has completed yet.

    Args:
        task_id: Key of the task in the module-level ``tasks`` registry.
        default_duration: Seconds per task assumed before any task has
            completed (defaults to 35, the previously hard-coded value).

    Returns:
        The estimate in seconds, rounded to one decimal place.

    Raises:
        KeyError: If ``task_id`` is not in ``tasks``.
    """
    # process_task sets status "completed" before next_task stamps
    # completed_at, and polls may run concurrently with processing (FastAPI
    # runs sync handlers in a thread pool). So skip any record whose
    # timestamps are not both set instead of subtracting None.
    durations = [
        task["completed_at"] - task["started_at"]
        for task in tasks.values()
        if task["status"] == "completed"
        and task.get("completed_at") is not None
        and task.get("started_at") is not None
    ]

    per_task = mean(durations) if durations else default_duration

    # NOTE(review): based on the initial place, so the estimate does not
    # shrink as the queue advances; presumably the client counts down from it.
    eta = tasks[task_id]["initial_place_in_queue"] * per_task
    return round(eta, 1)
|
|
|
|
|
def next_task(task_id):
    """Stamp ``task_id`` as finished and start the oldest queued task, if any.

    ``completed_at`` is recorded whatever the task's final status is
    (completed, failed or abandoned). The first task still in status
    "queued", in ``tasks`` insertion order, is then handed to
    ``process_task``.

    Args:
        task_id: Key of the task that just finished.
    """
    finished = tasks[task_id]
    finished["completed_at"] = time()

    waiting = [task for task in tasks.values() if task["status"] == "queued"]
    if not waiting:
        return

    print(f"{task_id} {finished['status']}. Task/s remaining: {len(waiting)}")
    # NOTE(review): process_task calls back into next_task from its
    # ``finally`` block, so draining a long backlog grows the call stack by a
    # few frames per task and could eventually hit the recursion limit.
    process_task(waiting[0]["task_id"])
|
|
|
|
|
def process_task(task_id):
    """Run one queued task's composition, then advance the queue.

    Only one task runs at a time. If any task is already "processing", this
    returns immediately and relies on that task's ``next_task`` call to pick
    up the queue afterwards. The call also returns immediately if
    ``task_id`` is no longer "queued".

    A task whose client polled at least once, but not within the last 30
    seconds, is marked "abandoned" and skipped. Otherwise the task moves to
    "processing", ``compose`` runs, and the task ends either "completed"
    (``output`` set) or "failed" (``error`` set to the exception's repr). In
    every case ``next_task`` is then called to start the next queued task.

    Args:
        task_id: Key of the task in the module-level ``tasks`` registry.
    """
    # NOTE(review): this check and the status update below are not atomic;
    # if two background calls run in parallel threads, both could pass it.
    if any(task["status"] == "processing" for task in tasks.values()):
        return

    task = tasks[task_id]

    # The background call scheduled by create_task may arrive after a
    # next_task() chain has already handled this task; running it again
    # would redo the work and overwrite its result.
    if task["status"] != "queued":
        return

    last_poll = task["last_poll"]
    if last_poll and time() - last_poll > 30:
        task["status"] = "abandoned"
        next_task(task_id)
        # Stop here: the abandoned task must not be composed (and its status
        # overwritten) after the queue has moved on.
        return

    task["status"] = "processing"
    task["started_at"] = time()
    print(f"Processing {task_id}")

    try:
        task["output"] = compose(
            task["music_style"],
            task["density"],
            task["temperature"],
        )
    except Exception as ex:
        # Boundary handler: the failure is reported to the client through the
        # task record (see /task/poll) and logged server-side.
        task["status"] = "failed"
        task["error"] = repr(ex)
        logger.error(f"Task {task_id} failed: {task['error']}")
    else:
        task["status"] = "completed"
    finally:
        next_task(task_id)
|
|
|
|
|
def _wav_to_base64(sample_rate, audio_data):
    """Wrap mono 16-bit audio frames in a WAV container; return it base64-encoded."""
    # NOTE(review): audio_data is whatever token_sequence_to_audio returns;
    # wave.writeframes needs a bytes-like object — presumably it is one.
    with io.BytesIO() as buffer:
        # Closing the writer (on exit of this ``with``) patches the WAV
        # header; a caller-supplied buffer is left open, so it can still be
        # read afterwards.
        with wave.open(buffer, "wb") as wave_file:
            wave_file.setframerate(sample_rate)
            wave_file.setnchannels(1)
            wave_file.setsampwidth(2)
            wave_file.writeframes(audio_data)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _png_to_base64(image):
    """Encode an image as PNG; return the bytes base64-encoded."""
    with io.BytesIO() as buffer:
        # PNG is lossless; Pillow's PNG encoder has no "quality" option
        # (compression is set via compress_level), so none is passed.
        image.save(buffer, "PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


def compose(music_style, density, temperature):
    """Generate a piece and return its tokens, audio and image.

    The three arguments are labels resolved through the ``constants`` module.
    Side effect: the image is also written to "compose.png" in the current
    working directory.

    Args:
        music_style: Style label, mapped to instruments by constants.get_instruments.
        density: Density label, mapped by constants.get_density.
        temperature: Temperature label, mapped by constants.get_temperature.

    Returns:
        A dict with:
        - "tokens": the generated token string;
        - "audio": a "data:audio/wav;base64,..." URL (mono, 16-bit WAV);
        - "image": a "data:image/png;base64,..." URL;
        - "status": "OK".
    """
    instruments = constants.get_instruments(music_style)
    density = constants.get_density(density)
    temperature = constants.get_temperature(temperature)
    print(f"instruments: {instruments} density: {density} temperature: {temperature}")

    logger.info("Generating token sequence...")
    generated_sequence = generate_sequence(instruments, density, temperature)
    logger.info(f"Generated token sequence: {generated_sequence}")

    logger.info("Generating audio...")
    sample_rate, audio_data = token_sequence_to_audio(generated_sequence)
    logger.info(f"Done. Audio data: {len(audio_data)}")
    audio_data_base64 = _wav_to_base64(sample_rate, audio_data)

    image = token_sequence_to_image(generated_sequence)
    logger.debug(f"Saving image to harddrive... {type(image)}")
    # NOTE(review): every request overwrites this same file and nothing here
    # reads it back — looks like a debugging aid; confirm before removing.
    image.save("compose.png", "PNG")
    image_data_base64 = _png_to_base64(image)

    return {
        "tokens": generated_sequence,
        "audio": "data:audio/wav;base64," + audio_data_base64,
        "image": "data:image/png;base64," + image_data_base64,
        "status": "OK",
    }
|
|
|
|
|
def generate_sequence(instruments, density, temperature):
    """Sample a multi-track token sequence from the model, one track per instrument.

    The instruments are shuffled (on a copy). The sequence starts as
    "PIECE_START". For each instrument, a "TRACK_START INST=<instrument>
    DENSITY=<density>" header is appended, and the model continues sampling
    until it emits the TRACK_END token or the whole sequence reaches 2048
    tokens. Each track's output becomes the prompt for the next track.

    Args:
        instruments: Iterable of instrument labels; it is not modified.
        density: Value written into each track header.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        The decoded token sequence as a string.
    """
    # Copy before shuffling so the caller's sequence is untouched. list()
    # also accepts tuples and other iterables, which random.shuffle cannot
    # shuffle in place.
    instruments = list(instruments)
    random.shuffle(instruments)

    # Loop-invariant: the id that ends a track and stops each generate call.
    track_end_id = tokenizer.encode("TRACK_END")[0]

    generated_ids = tokenizer.encode("PIECE_START", return_tensors="pt")[0]
    for instrument in instruments:
        header_ids = tokenizer.encode(
            f"TRACK_START INST={instrument} DENSITY={density}", return_tensors="pt"
        )[0]
        # Add a batch dimension for generate(), then take the single
        # returned sequence back out with [0].
        prompt = torch.cat((generated_ids, header_ids)).unsqueeze(0)
        generated_ids = model.generate(
            prompt,
            max_length=2048,  # total length, including the prompt
            do_sample=True,
            temperature=temperature,
            eos_token_id=track_end_id,
        )[0]

    generated_sequence = tokenizer.decode(generated_ids)
    print("GENERATING COMPLETE")
    print(generated_sequence)
    return generated_sequence
|
|
|
|
|
# In-memory registry of every task created by this process, keyed by task_id.
# Insertion order equals creation order; next_task relies on it to pick the
# oldest queued task. Nothing in this module removes entries, so memory
# (including each task's base64 audio/image output) grows with requests served.
tasks = OrderedDict()
|
|
|
|
|
@app.api_route("/", methods=["GET", "HEAD"], include_in_schema=False)
def index(request: Request):
    """Serve the single-page front end (static/index.html) for GET and HEAD.

    A single FastAPI route is registered for both methods. The deprecated
    Starlette ``app.route`` decorator registered a route with no method
    restriction, which shadowed a separate ``@app.head`` route.
    ``request`` is unused; it is kept so the signature stays the same.
    """
    return FileResponse(path="static/index.html", media_type="text/html")
|
|
|
|
|
@app.post("/task/create")
def create_task(background_tasks: BackgroundTasks, new_task: NewTask):
    """Register a composition request and schedule it for processing.

    The new task is stored in ``tasks`` with status "queued". Its initial
    place in the queue and its ETA are computed immediately, and
    ``process_task`` is scheduled as a background task (it returns early if
    another task is already processing).

    Returns:
        The stored task record.
    """
    now = time()
    new_id = f"{str(now)}_{new_task.music_style}"

    # Keyword order fixes the key order of the record (and of the JSON
    # response built from it).
    record = OrderedDict(
        task_id=new_id,
        status="queued",
        eta=None,
        created_at=now,
        started_at=None,
        completed_at=None,
        last_poll=None,
        poll_count=0,
        initial_place_in_queue=None,
        place_in_queue=None,
        music_style=new_task.music_style,
        density=new_task.density,
        temperature=new_task.temperature,
        output=None,
    )
    tasks[new_id] = record

    record["initial_place_in_queue"] = get_place_in_queue(new_id)
    record["eta"] = calculate_eta(new_id)

    background_tasks.add_task(process_task, new_id)
    return record
|
|
|
|
|
@app.get("/task/poll")
def poll_task(task_id: str):
    """Refresh and return a task's status record.

    Updates the task's current place in the queue and its ETA, stamps
    ``last_poll`` (process_task uses it to detect abandoned tasks), and
    increments ``poll_count``.

    Raises:
        HTTPException: 404 if ``task_id`` is unknown. Without this check the
            lookup raised KeyError, which the client saw as a 500.
    """
    task = tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    task["place_in_queue"] = get_place_in_queue(task_id)
    task["eta"] = calculate_eta(task_id)
    task["last_poll"] = time()
    task["poll_count"] += 1

    return task
|
|