"""Gradio front end for the bigram language model.

Loads a trained BigramLanguageModel checkpoint and exposes its text
generation through a simple web UI.
"""

import gradio as gr
import torch

from bigram_model import BigramLanguageModel, encode, decode


class GradioInterface:
    """Wraps model loading and text generation for the Gradio app."""

    def __init__(self, model_path=None):
        # Prefer a GPU when available; otherwise fall back to CPU.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.load_model(model_path)
        self.model.eval()  # inference mode: disables dropout, etc.

    def load_model(self, model_path):
        model = BigramLanguageModel().to(self.device)
        if model_path:
            # map_location lets a GPU-trained checkpoint load on CPU-only machines.
            model.load_state_dict(torch.load(model_path, map_location=self.device))
        return model

    def generate_text(self, input_text, max_tokens=100):
        # Encode the prompt into token ids, shaped (1, T) for a batch of one.
        context = torch.tensor([encode(input_text)], dtype=torch.long, device=self.device)
        # Cast max_tokens to int (Gradio sliders deliver floats), and sample
        # under no_grad since gradients are not needed for generation.
        with torch.no_grad():
            output = self.model.generate(context, max_new_tokens=int(max_tokens))
        return decode(output[0].tolist())
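

# Quick smoke test without the UI (a usage sketch; assumes the checkpoint
# below exists and bigram_model is importable):
#   iface = GradioInterface("models/lafontaine_gpt_v8_241011_1307.pth")
#   print(iface.generate_text("Once upon a time", max_tokens=50))
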
if __name__ == "__main__":
    model_path = "models/lafontaine_gpt_v8_241011_1307.pth"
    model_interface = GradioInterface(model_path)

    gr_interface = gr.Interface(
        fn=model_interface.generate_text,
        inputs=[
            gr.Textbox(label="Prompt"),
            gr.Slider(50, 500, value=100, step=1, label="Max new tokens"),
        ],
        outputs=gr.Textbox(label="Generated text"),
        description="Bigram Language Model text generation. Enter some text, and the model will continue it.",
        # Each example row must supply one value per input component.
        examples=[["Once upon a time", 100]],
    )

    gr_interface.launch()
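
# Note: launch() serves on http://127.0.0.1:7860 by default; pass share=True for a
# temporary public Gradio link, or server_name="0.0.0.0" to listen on all interfaces.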