|
|
|
"""MtGPT2 |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1HMq9Cp_jhqc9HlUipLXi8SC1mHQhSgGn |
|
""" |
|
|
|
|
|
|
|
import locale

# Force Python-level callers of locale.getpreferredencoding() to see UTF-8
# regardless of the host locale (presumably a workaround for Colab locale
# errors when running pip — confirm). The replacement mirrors the real
# signature getpreferredencoding(do_setlocale=True); a zero-argument lambda
# would raise TypeError for callers such as getpreferredencoding(False).
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"
|
|
|
# Notebook-only dependency installation (IPython "!" shell syntax; these lines
# are not valid in a plain .py file).
# NOTE(review): pytorch-lightning is pinned, presumably for compatibility with
# the installed aitextgen release — confirm before upgrading either package.
!pip install aitextgen

!pip install pytorch-lightning==1.7.0
|
|
|
import sys |
|
from jinja2 import Template |
|
from aitextgen import aitextgen |
|
|
|
try: |
|
from google.colab import files |
|
except ImportError: |
|
pass |
|
|
|
|
|
# Load the pretrained "minimaxir/magic-the-gathering" model used by every
# generation call below. Put it on the GPU only when CUDA is actually present:
# a hard-coded to_gpu=True presumably fails on CPU-only hosts (e.g. a
# CPU-backed Gradio deployment) because it moves the model to a CUDA device.
import torch

ai = aitextgen(model="minimaxir/magic-the-gathering", to_gpu=torch.cuda.is_available())
|
|
|
|
|
# Jinja2 template that renders one generated card dict (exposed as ``c``) as
# plain text, in this order: name followed by the mana cost (omitted when
# empty), the type line, the rules text, "power/toughness" (only when c.power
# is set), and "Loyalty: N" (only when c.loyalty is set).
# NOTE(review): the literal's exact line breaks define the output layout;
# keep it byte-identical when editing around it.
TEMPLATE = Template(

"""{{ c.name }}{% if c.manaCost %} {{ c.manaCost }}{% endif %}

{{ c.type }}

{{ c.text }}{% if c.power %}

{{ c.power }}/{{ c.toughness }}{% endif %}{% if c.loyalty %}

Loyalty: {{ c.loyalty }}{% endif %}"""

)
|
|
|
def render_card(card_dict):
    """Render one generated card dict as plain text.

    The dict is formatted with ``TEMPLATE``. Then every ``~`` in the result
    (presumably the placeholder card text uses for the card's own name) is
    replaced with the card's name, when one is present.

    Args:
        card_dict: Mapping of schema field name (``name``, ``manaCost``,
            ``type``, ``text``, ``power``, ``toughness``, ``loyalty``) to the
            generated string. Any field may be absent if the model did not
            emit it.

    Returns:
        The rendered card text.
    """
    card = TEMPLATE.render(c=card_dict)
    # .get(): a missing "name" key used to raise KeyError here, even though the
    # template itself renders missing keys as empty.
    name = card_dict.get("name")
    if name:
        card = card.replace("~", name)
    return card
|
|
|
# --- Quick demo run: generate a handful of cards from a fixed prompt. ---
# Schema prompt: each field is introduced by its "<|field|>" token. Fields are
# concatenated without padding spaces, matching generate_cards, because the
# text after a token presumably becomes that field's value verbatim. Mana
# costs use one brace group per symbol, e.g. "{2}{B/G}".
prompt = "<|type|>Creature - Human<|name|>Rezo the Destroyer of Politics<|manaCost|>{2}{B/G}"
temperature = 0.7
to_file = False  # True: write 100 cards to cards.txt (and download it in Colab)

n = 100 if to_file else 8

cards = ai.generate(
    n=n,
    schema=True,
    prompt=prompt,
    temperature=temperature,
    return_as_list=True,
)
cards = [render_card(card) for card in cards]

separator = "=" * 20
if to_file:
    file_path = "cards.txt"
    with open(file_path, "w", encoding="utf-8") as f:
        # Each card is followed by its own separator line.
        f.writelines(f"{card}\n{separator}\n" for card in cards)
    # `files` only exists when the google.colab import at the top succeeded.
    if "google.colab" in sys.modules:
        files.download(file_path)
else:
    print(f"\n{separator}\n".join(cards))
|
|
|
# Substituted for common mana/tap symbols in the final output, for readability.
_SYMBOL_EMOJI = {
    "{G}": "🌲",
    "{U}": "🌊",
    "{R}": "🔥",
    "{B}": "💀",
    "{W}": "☀️",
    "{T}": "↩️",
}


def _format_mana_cost(mana_cost):
    """Convert shorthand such as "2UG" into brace notation "{2}{U}{G}".

    Consecutive digits form one generic cost ("10G" -> "{10}{G}"), letters are
    upper-cased, and whitespace is ignored. Input that already contains braces
    (e.g. "{2}{B/G}") is returned unchanged, so hybrid and other
    multi-character symbols can be entered directly.
    """
    import re

    mana_cost = (mana_cost or "").strip()
    if "{" in mana_cost:
        return mana_cost
    symbols = re.findall(r"\d+|\S", mana_cost.upper())
    return "".join(f"{{{symbol}}}" for symbol in symbols)


def _build_prompt(fields):
    """Join ``<|key|>value`` schema segments for every non-empty field, in order."""
    return "".join(f"<|{key}|>{value}" for key, value in fields.items() if value)


def generate_cards(
    n_cards: int = 8,
    temperature: float = 0.75,
    name: str = "",
    manaCost: str = "",
    type: str = "",  # shadows the builtin; name kept for existing callers
    text: str = "",
    power: str = "",
    toughness: str = "",
    loyalty: str = "",
) -> str:
    """Generate MTG cards conditioned on whichever fields were filled in.

    Args:
        n_cards: Number of cards to generate; values below 1 become 1.
        temperature: Sampling temperature passed to ``ai.generate``.
        name, type, text, power, toughness, loyalty: Optional field values.
            Blank or whitespace-only fields are left for the model to fill.
        manaCost: Optional mana cost, either in shorthand ("2UG") or in brace
            notation ("{2}{U}{G}"); see ``_format_mana_cost``.

    Returns:
        The rendered cards separated by "=====" lines, with common mana/tap
        symbols replaced by emoji.
    """
    n_cards = max(1, int(n_cards))

    # Insertion order is the order the fields appear in the prompt.
    fields = {
        "name": name,
        "manaCost": _format_mana_cost(manaCost),
        "type": type,
        "text": text,
        "power": power,
        "toughness": toughness,
        "loyalty": loyalty,
    }
    # Strip so stray whitespace neither enables a field nor leaks into its value.
    prompt_str = _build_prompt(
        {key: (value or "").strip() for key, value in fields.items()}
    )

    cards = ai.generate(
        n=n_cards,
        schema=True,
        prompt=prompt_str,
        temperature=temperature,
        return_as_list=True,
    )

    out_str = "\n=====\n".join(render_card(card) for card in cards)
    for symbol, emoji in _SYMBOL_EMOJI.items():
        out_str = out_str.replace(symbol, emoji)
    return out_str
|
|
|
|
|
|
|
import gradio as gr

# Web UI. Inputs are passed positionally to generate_cards, so their order must
# match its signature: n_cards, temperature, name, manaCost, type, text, power,
# toughness, loyalty. Labels tell the user which textbox is which field.
iface = gr.Interface(
    fn=generate_cards,
    inputs=[
        gr.Slider(minimum=2, maximum=16, step=1, value=8, label="Number of cards"),
        gr.Slider(minimum=0.1, maximum=1.5, step=0.01, value=0.75, label="Temperature"),
        gr.Textbox(label="Name"),
        gr.Textbox(label="Mana cost (e.g. 2UG)"),
        gr.Textbox(label="Type (e.g. Creature - Human)"),
        gr.Textbox(label="Rules text"),
        gr.Textbox(label="Power"),
        gr.Textbox(label="Toughness"),
        gr.Textbox(label="Loyalty"),
    ],
    outputs=gr.Textbox(label="Generated cards"),
    title="GPT-2 Powered MTG Card Generator",
    description=(
        "Enter the mana cost as '2UG' for 2 generic + Blue + Green mana. "
        "All fields are optional; leave them blank to let the model decide.\n\n"
        "Temperature is recommended between 0.4 and 0.9. Anything above 1 will "
        "lead to random chaos and very low values will just be boring."
    ),
)
iface.launch()