# MtGPT2 / mtgpt2.py
# Uploaded by MixoMax (commit fd40de5, verified); 4.45 kB.
# -*- coding: utf-8 -*-
"""MtGPT2
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1HMq9Cp_jhqc9HlUipLXi8SC1mHQhSgGn
"""
#@title Setup + Download Model
# Install and load dependencies.
import locale
import subprocess
import sys

# NOTE(review): presumably a Colab workaround; makes every later
# locale.getpreferredencoding() call report UTF-8.
locale.getpreferredencoding = lambda: "UTF-8"

# Shell magics such as "!pip install ..." are not valid Python and make this
# .py file unparsable, so pip is run through the current interpreter instead.
# This works both in a notebook and as a plain script, and fails loudly
# (CalledProcessError) if an install fails. The installs run in order so the
# pytorch-lightning==1.7.0 pin is applied last.
subprocess.check_call([sys.executable, "-m", "pip", "install", "aitextgen"])
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "pytorch-lightning==1.7.0"]
)

import torch
from jinja2 import Template
from aitextgen import aitextgen

try:
    from google.colab import files  # browser download helper, Colab only
except ImportError:
    files = None  # not running in Colab; no browser download available

# Download and load the model. Use the GPU only when one is available:
# hard-coding to_gpu=True does not work on a CPU-only machine.
ai = aitextgen(
    model="minimaxir/magic-the-gathering",
    to_gpu=torch.cuda.is_available(),
)
# Jinja2 template that renders one generated card dict (exposed to the
# template as ``c``) as plain text, similar to Scryfall card formatting:
#   line 1: the name, followed by the mana cost only when c.manaCost is truthy
#   line 2: the type line (c.type)
#   line 3: the card text (c.text)
#   then "power/toughness" only when c.power is truthy, and
#   "Loyalty: N" only when c.loyalty is truthy.
# The whitespace inside the string literal IS the output layout; edit with care.
TEMPLATE = Template(
"""{{ c.name }}{% if c.manaCost %} {{ c.manaCost }}{% endif %}
{{ c.type }}
{{ c.text }}{% if c.power %}
{{ c.power }}/{{ c.toughness }}{% endif %}{% if c.loyalty %}
Loyalty: {{ c.loyalty }}{% endif %}"""
)
def render_card(card_dict):
    """Render one generated card dict as Scryfall-style plain text.

    The dict is formatted through ``TEMPLATE``, then every ``~`` in the result
    is replaced by the card's name (presumably the training data's placeholder
    for a card referring to itself).

    NOTE(review): the dict comes from the model's schema output, which is
    presumably not guaranteed to contain every field. ``name`` is therefore
    read with ``.get``, so a nameless card is still rendered (with ``~`` left
    as-is) instead of raising ``KeyError``.

    Args:
        card_dict: Mapping of card fields (name, manaCost, type, text, ...).

    Returns:
        The rendered card text.
    """
    card = TEMPLATE.render(c=card_dict)
    name = card_dict.get("name")
    if name:
        card = card.replace("~", name)
    return card
# Demo run (Colab form). The "#@param" annotations render these three
# assignments as form widgets, so each must stay on a single line.
prompt="\u003C|type|>Creature - Human \u003C|name|>Rezo the Destroyer of Politics \u003C|manaCost|> 2 {B/G}" #@param {type:"string"}
temperature = 0.7 #@param {type:"slider", min:0.1, max:1.2, step:0.1}
to_file = False #@param {type:"boolean"}

# A large batch when writing to disk, a short preview when printing.
if to_file:
    n = 100
else:
    n = 8

raw_cards = ai.generate(
    n=n,
    schema=True,
    prompt=prompt,
    temperature=temperature,
    return_as_list=True,
)
cards = [render_card(raw) for raw in raw_cards]

divider = "=" * 20
if not to_file:
    # Dividers only go *between* cards when printing.
    print(f"\n{divider}\n".join(cards))
else:
    file_path = "cards.txt"
    with open(file_path, "w", encoding="utf-8") as out_file:
        # In the file, every card (including the last) is followed by a divider.
        for card in cards:
            out_file.write(f"{card}\n{divider}\n")
    # Offer the file as a browser download when running inside Colab.
    if "google.colab" in sys.modules:
        files.download(file_path)
# Emoji shown in place of colored-mana and tap symbols in the web UI output.
# Generic-mana digits ({1}, {2}, ...) are intentionally left as plain text.
_SYMBOL_EMOJI = {
    "{G}": "🌲",
    "{U}": "🌊",
    "{R}": "🔥",
    "{B}": "💀",
    "{W}": "☀️",
    "{T}": "↩️",
}


def _format_mana_cost(mana_cost):
    """Convert shorthand like "2UG" into brace notation "{2}{U}{G}".

    * A run of digits is one generic-mana symbol: "10U" -> "{10}{U}".
    * Each letter is its own symbol, upper-cased: "2ug" -> "{2}{U}{G}".
    * Whitespace and any other characters are ignored.
    * Input that already contains braces, e.g. "{2}{b/g}", is treated as
      brace notation: whitespace is removed and it is upper-cased
      ("{2}{B/G}"), which lets hybrid costs be entered directly.

    Returns "" for empty or None input.
    """
    import re

    if not mana_cost:
        return ""
    cost = "".join(str(mana_cost).split()).upper()
    if "{" in cost:
        return cost
    return "".join(f"{{{symbol}}}" for symbol in re.findall(r"[0-9]+|[A-Z]", cost))


def generate_cards(
    n_cards: int = 8,
    temperature: float = 0.75,
    name: str = "",
    manaCost: str = "",
    type: str = "",  # shadows the builtin; kept because it is part of the signature
    text: str = "",
    power: str = "",
    toughness: str = "",
    loyalty: str = "",
) -> str:
    """Generate MTG cards with the loaded model and return them as display text.

    Every non-empty field becomes a ``<|field|>value`` prompt token, so the
    model completes cards that match what was specified. Empty fields are left
    for the model to invent.

    Args:
        n_cards: Number of cards to generate. It is coerced with ``int()``
            (the value may arrive as a float) and clamped to at least 1.
        temperature: Sampling temperature passed to ``ai.generate``.
        name, type, text, power, toughness, loyalty: Optional card fields.
            Surrounding whitespace is stripped before use.
        manaCost: Optional mana cost, either as shorthand ("2UG", "10U") or in
            brace notation ("{2}{B/G}"). See ``_format_mana_cost``.

    Returns:
        The rendered cards separated by "=====" lines, with colored-mana and
        tap symbols replaced by emoji.
    """
    # Ensure n_cards is never 0 or negative.
    n_cards = max(1, int(n_cards))

    fields = {
        "<|name|>": name,
        "<|manaCost|>": _format_mana_cost(manaCost),
        "<|type|>": type,
        "<|text|>": text,
        "<|power|>": power,
        "<|toughness|>": toughness,
        "<|loyalty|>": loyalty,
    }
    prompt_parts = []
    for token, value in fields.items():
        # Strip stray whitespace from the inputs so it does not leak into the prompt.
        value = str(value).strip() if value else ""
        if value:
            prompt_parts.append(f"{token}{value}")
    prompt_str = "".join(prompt_parts)

    cards = ai.generate(
        n=n_cards,
        schema=True,
        prompt=prompt_str,
        temperature=temperature,
        return_as_list=True,
    )
    out_str = "\n=====\n".join(render_card(card) for card in cards)

    for symbol, emoji in _SYMBOL_EMOJI.items():
        out_str = out_str.replace(symbol, emoji)
    return out_str
# Commented out IPython magic to ensure Python compatibility.
# %pip install gradio
import gradio as gr
# Web UI. Gradio passes the input components' values to generate_cards in
# list order, so this list must match its parameter order: n_cards,
# temperature, name, manaCost, type, text, power, toughness, loyalty.
iface = gr.Interface(
    fn=generate_cards,
    inputs=[
        gr.Slider(minimum=2, maximum=16, step=1, value=8, label="Number of cards"),
        gr.Slider(minimum=0.1, maximum=1.5, step=0.01, value=0.75, label="Temperature"),
        gr.Textbox(label="Name"),
        gr.Textbox(label="Mana cost"),
        gr.Textbox(label="Type"),
        gr.Textbox(label="Text"),
        gr.Textbox(label="Power"),
        gr.Textbox(label="Toughness"),
        gr.Textbox(label="Loyalty"),
    ],
    outputs=gr.Textbox(label="Generated cards"),
    title="GPT-2 Powered MTG Card Generator",
    description=(
        "Enter the mana cost as '2UG' for 2 generic + blue + green mana. \n\n "
        "Temperature is recommended between 0.4 and 0.9. Anything above 1 will "
        "lead to random chaos, and very low values will just be boring."
    ),
)
iface.launch()