MtGPT2 / app.py
MixoMax's picture
Update app.py
8c528e0 verified
raw
history blame
4.24 kB
# -*- coding: utf-8 -*-
"""MtGPT2
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1HMq9Cp_jhqc9HlUipLXi8SC1mHQhSgGn
"""
#@title Setup + Download Model
# Install and load dependencies
import os
# Install the pinned runtime dependencies before anything imports them.
os.system("python3 -m pip install pytorch-lightning==1.7.0 aitextgen")

# Patch the freshly installed pytorch-lightning: its rich_progress module
# imports `_compare_version` from torchmetrics, and this script swaps that for
# `compare_version` from lightning_utilities, bound to the same local name.
# NOTE(review): presumably the installed torchmetrics no longer provides
# `_compare_version` -- confirm before removing this patch.
_RICH_PROGRESS_PATH = "/home/user/.local/lib/python3.10/site-packages/pytorch_lightning/callbacks/progress/rich_progress.py"
_OLD_IMPORT = "from torchmetrics.utilities.imports import _compare_version"
_NEW_IMPORT = "from lightning_utilities.core.imports import compare_version as _compare_version\n"

with open(_RICH_PROGRESS_PATH, "r", encoding="utf-8") as file:
    lines = file.readlines()

# Match the import by content rather than by a fixed line number, so an
# unexpected file layout can never cause an unrelated line to be overwritten.
patched_lines = [
    _NEW_IMPORT if line.strip().startswith(_OLD_IMPORT) else line
    for line in lines
]

# Only rewrite the file when the old import was actually found; an already
# patched file is left untouched.
if patched_lines != lines:
    lines = patched_lines
    with open(_RICH_PROGRESS_PATH, "w", encoding="utf-8") as file:
        file.writelines(lines)
import sys
from jinja2 import Template
from aitextgen import aitextgen
try:
from google.colab import files
except ImportError:
pass
# Download and load the model. Set to_gpu=False if running on a CPU.
# `ai` is the module-level generator object that generate_cards() calls.
ai = aitextgen(model="minimaxir/magic-the-gathering", to_gpu=False)
# This template is similar to Scryfall card formatting
# Layout: first line is the name (followed by the mana cost when one is set),
# then the type, then the rules text. A "power/toughness" line is emitted only
# when power is set, and a "Loyalty:" line only when loyalty is set.
# `c` is the card mapping handed to TEMPLATE.render(c=...).
TEMPLATE = Template(
"""{{ c.name }}{% if c.manaCost %} {{ c.manaCost }}{% endif %}
{{ c.type }}
{{ c.text }}{% if c.power %}
{{ c.power }}/{{ c.toughness }}{% endif %}{% if c.loyalty %}
Loyalty: {{ c.loyalty }}{% endif %}"""
)
def render_card(card_dict: dict) -> str:
    """Render one generated card as text and expand "~" self-references.

    Args:
        card_dict: Card fields keyed by field name (the template reads
            "name", "manaCost", "type", "text", "power", "toughness" and
            "loyalty"); it is exposed to TEMPLATE as ``c``.

    Returns:
        The rendered TEMPLATE text. When the card has a non-empty name,
        every "~" in the rendered text is replaced by that name.
    """
    card = TEMPLATE.render(c=card_dict)
    # Use .get() so a card generated without a "name" field is rendered
    # as-is instead of raising KeyError; the template itself already guards
    # its optional fields.
    name = card_dict.get("name")
    if name:
        card = card.replace("~", name)
    return card
# Emoji shown in place of the braced colour/tap symbols in the final output.
# Numeric symbols such as {1} are intentionally left in their braced form.
_SYMBOL_EMOJI = {
    "{G}": "🌲",
    "{U}": "🌊",
    "{R}": "🔥",
    "{B}": "💀",
    "{W}": "☀️",
    "{T}": "↩️",
}


def _format_mana_cost(mana_cost: str) -> str:
    """Convert shorthand mana such as "2UG" into braced symbols "{2}{U}{G}"."""
    import re

    # Tokens, in priority order: an already-braced symbol (kept verbatim), a
    # run of digits (so "10" becomes "{10}", not "{1}{0}"), or any other
    # single non-whitespace character.
    symbols = re.findall(r"\{[^{}]*\}|\d+|\S", mana_cost or "")
    return "".join(
        symbol
        if symbol.startswith("{") and symbol.endswith("}")
        else "{" + symbol + "}"
        for symbol in symbols
    )


def generate_cards(
    n_cards: int = 8,
    temperature: float = 0.75,
    name: str = "",
    manaCost: str = "",
    type: str = "",
    text: str = "",
    power: str = "",
    toughness: str = "",
    loyalty: str = ""
) -> str:
    """Generate cards with the module-level ``ai`` model and return them as text.

    Every non-empty field is added to the prompt as ``<|field|>value``; empty
    fields are omitted from the prompt. Parameter names (including ``type``
    and ``manaCost``) are part of the public interface and are kept as-is.

    Args:
        n_cards: Number of cards to generate; values below 1 are raised to 1.
        temperature: Sampling temperature forwarded to ``ai.generate``.
        name: Card name to seed the prompt with.
        manaCost: Mana cost in shorthand form, e.g. "2UG" for {2}{U}{G}.
            Already-braced symbols are accepted unchanged.
        type: Card type to seed the prompt with.
        text: Rules text to seed the prompt with.
        power: Power to seed the prompt with.
        toughness: Toughness to seed the prompt with.
        loyalty: Loyalty to seed the prompt with.

    Returns:
        The rendered cards joined by a "=====" separator line, with the
        symbols listed in ``_SYMBOL_EMOJI`` replaced by their emoji.
    """
    # ensure n_cards is never 0 or negative
    n_cards = max(1, int(n_cards))

    token_dict = {
        "<|name|>": name,
        "<|manaCost|>": _format_mana_cost(manaCost),
        "<|type|>": type,
        "<|text|>": text,
        "<|power|>": power,
        "<|toughness|>": toughness,
        "<|loyalty|>": loyalty
    }

    # Convert the token_dict into a formatted prompt string, skipping fields
    # the caller left empty.
    prompt_str = "".join(
        f"{token}{value}" for token, value in token_dict.items() if value
    )

    # Generate the cards using the prompt string and other parameters
    cards = ai.generate(
        n=n_cards,
        schema=True,
        prompt=prompt_str,
        temperature=temperature,
        return_as_list=True
    )

    out_str = "\n=====\n".join(render_card(card) for card in cards)
    for symbol, emoji in _SYMBOL_EMOJI.items():
        out_str = out_str.replace(symbol, emoji)
    return out_str
# Commented out IPython magic to ensure Python compatibility.
# %pip install gradio
import gradio as gr

# Build the web UI. Inputs are passed to generate_cards positionally, so their
# order must match its parameters: n_cards, temperature, name, manaCost, type,
# text, power, toughness, loyalty.
iface = gr.Interface(
    fn=generate_cards,
    inputs=[
        gr.Slider(minimum=2, maximum=16, step=1, value=8),
        gr.Slider(minimum=0.1, maximum=1.5, step=0.01, value=0.75),
        gr.Textbox(),
        gr.Textbox(),
        gr.Textbox(),
        gr.Textbox(),
        gr.Textbox(),
        gr.Textbox(),
        gr.Textbox(),
    ],
    outputs=gr.Textbox(),
    title="GPT-2 Powered MTG Card Generator",
    description="Enter Manacost as '2UG' for 2 colorless + Blue + Green mana. \n\n Temperature is recomended between 0.4 and 0.9. Anything above 1 will lead to random Chaos and very low values will just be boring.",
    # Flagging is configured on the Interface itself; it is documented as a
    # constructor option rather than a launch() argument.
    allow_flagging="never",
)
iface.launch(show_api=True)