import gradio as gr
import torch

from gpt_dev import GPTLanguageModel, encode, decode, generate_text
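
# Pick the GPU when one is available; otherwise fall back to the CPU.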
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
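
# Hyperparameters of the trained model. GPTLanguageModel() is constructed with
# no arguments below, so it presumably reads this configuration from gpt_dev;
# the values are kept here for reference and must match the checkpoint.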
block_size = 256  # maximum context length
n_embd = 384      # embedding dimension
n_head = 6        # attention heads per block
n_layer = 6       # transformer blocks
vocab_size = 95   # character vocabulary size
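
# Instantiate the model and move it to the target device.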
model = GPTLanguageModel()
model.to(device)
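
# Load the trained weights and switch to inference mode.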
checkpoint = torch.load("gpt_language_model.pth", map_location=device)
model.load_state_dict(checkpoint)
model.eval()
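

# Generation entry point used by the UI; it delegates to gpt_dev.generate_text,
# which is assumed to handle encoding the prompt and decoding the output.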
def generate_response(prompt, max_length=100, temperature=1.0):
    return generate_text(model, prompt, max_length=max_length, temperature=temperature)
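

# Build the web UI: a prompt box plus sliders for output length and
# sampling temperature.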
interface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="Prompt", value="Once upon a time"),
        gr.Slider(50, 240, step=1, value=75, label="Max Length"),
        gr.Slider(0.1, 2.0, step=0.1, value=1.0, label="Temperature"),
    ],
    outputs="text",
    title="Odyssey Rhyme Generator",
    description="Enter a prompt to generate text.",
)


if __name__ == "__main__":
    # launch() serves the app locally; share=True also creates a temporary
    # public Gradio link.
    interface.launch(share=True)