File size: 1,981 Bytes
5c239ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c519fd
5c239ba
 
4c519fd
 
5c239ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c519fd
5c239ba
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Startup banner for the one-time, import-time model setup below.
# NOTE(review): presumably printed ahead of the imports so it shows up before
# the slow library imports start -- confirm that is why it sits here.
print("Preparing for inference...")  # noqa

from rudalle.pipelines import generate_images
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from huggingface_hub import hf_hub_url, cached_download
import torch
from io import BytesIO
import base64

print(f"GPUs available: {torch.cuda.device_count()}")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is requested only when CUDA is available.
fp16 = torch.cuda.is_available()

# Download `pytorch_model.bin` from the `minimaxir/ai-generated-pokemon-rudalle`
# Hub repo into ./models. Despite its name, `config_file_url` is the URL of the
# weights file (filename=file_name), not of a config file.
# `force_filename` is passed so the file can be opened below as
# f"{file_dir}/{file_name}".
# NOTE(review): `cached_download` is deprecated in newer huggingface_hub
# releases -- confirm the pinned version still provides it.
file_dir = "./models"
file_name = "pytorch_model.bin"
config_file_url = hf_hub_url(repo_id="minimaxir/ai-generated-pokemon-rudalle", filename=file_name)
cached_download(config_file_url, cache_dir=file_dir, force_filename=file_name)

# Build the 'Malevich' model without fetching its stock weights
# (pretrained=False), then load the downloaded state dict into it.
model = get_rudalle_model('Malevich', pretrained=False, fp16=fp16, device=device)
# map_location places the loaded tensors on the first GPU, or on the CPU when
# CUDA is unavailable.
model.load_state_dict(torch.load(f"{file_dir}/{file_name}", map_location=f"{'cuda:0' if torch.cuda.is_available() else 'cpu'}"))

# The VAE and tokenizer are module-level globals that generate_image() passes
# to generate_images() together with `model`.
vae = get_vae().to(device)
tokenizer = get_tokenizer()

print("Ready for inference")


def english_to_russian(english_string):
    """Return the Russian prompt word for a supported English keyword.

    The lookup is case-insensitive: the input is lower-cased before it is
    matched against the table.

    Args:
        english_string: One of the supported English keywords, in any case
            (for example ``"grass"`` or ``"Fire"``).

    Returns:
        The Russian string mapped to that keyword.

    Raises:
        KeyError: If the lower-cased input is not one of the supported
            keywords.
    """
    # The values are reproduced verbatim (capitalisation included), since they
    # are handed to the model as the prompt text.
    translations = dict(
        grass="трава",
        fire="Пожар",
        water="вода",
        lightning="молния",
        fighting="борьба",
        psychic="психический",
        colorless="бесцветный",
        darkness="темнота",
        metal="металл",
        dragon="Дракон",
        fairy="сказочный",
    )

    key = english_string.lower()
    return translations[key]


def generate_image(prompt):
    """Generate one image for ``prompt`` and return it as a PNG data URI.

    If ``prompt`` is (case-insensitively) one of the English keywords known
    to ``english_to_russian``, it is replaced by the mapped Russian word
    before generation. Any other prompt is passed to the model unchanged.

    Args:
        prompt: Text prompt. Either a supported English keyword or arbitrary
            text that is forwarded as-is.

    Returns:
        A string of the form ``"data:image/png;base64,<payload>"`` holding
        the first generated image encoded as PNG.
    """
    # EAFP: let english_to_russian decide which keywords are known rather
    # than repeating its key list here, so the two can never drift apart.
    try:
        prompt = english_to_russian(prompt)
    except KeyError:
        pass  # Not a known keyword: use the prompt exactly as given.

    # Only the first element of the returned pair (the images) is used.
    images, _ = generate_images(
        prompt, tokenizer, model, vae, top_k=2048, images_num=1, top_p=0.995
    )

    # Serialise the single requested image to PNG in memory, then base64 it.
    with BytesIO() as buffer:
        images[0].save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("UTF-8")

    return "data:image/png;base64," + encoded