import gradio as gr
import torch

# 'BigramLanguageModel', 'encode', and 'decode' are assumed to be defined
# in bigram_model.py, as in the training code.
from bigram_model import BigramLanguageModel, encode, decode


class GradioInterface:
    def __init__(self, model_path=None):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.load_model(model_path)
        self.model.eval()  # inference mode: disables dropout, etc.

    def load_model(self, model_path):
        model = BigramLanguageModel().to(self.device)
        if model_path:
            model.load_state_dict(torch.load(model_path, map_location=self.device))
        return model

    def generate_text(self, input_text, max_tokens=100):
        # Gradio sliders pass floats; generate() expects an integer count.
        max_tokens = int(max_tokens)
        context = torch.tensor([encode(input_text)], dtype=torch.long, device=self.device)
        with torch.no_grad():  # no gradients needed during generation
            output = self.model.generate(context, max_new_tokens=max_tokens)
        return decode(output[0].tolist())


# Load the model
model_path = "models/lafontaine_gpt_v8_241011_1307.pth"
model_interface = GradioInterface(model_path)

# Define the Gradio interface
gr_interface = gr.Interface(
    fn=model_interface.generate_text,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(50, 500, value=100, step=1, label="Max new tokens"),
    ],
    outputs="text",
    description="Bigram Language Model text generation. Enter some text, and the model will continue it.",
    # Each example must supply one value per input component.
    examples=[["Once upon a time", 100]],
)

# Launch the interface
gr_interface.launch()
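# ----------------------------------------------------------------------
# For reference: a hedged sketch of the interface that bigram_model.py is
# assumed to expose, inferred only from how it is used above. The actual
# implementation may differ.
#
#   def encode(s: str) -> list[int]: ...    # text -> token ids
#   def decode(ids: list[int]) -> str: ...  # token ids -> text
#
#   class BigramLanguageModel(torch.nn.Module):
#       def generate(self, idx, max_new_tokens):
#           # Autoregressively extends idx (shape [batch, time]) by
#           # max_new_tokens sampled token ids and returns the result.
#           ...
# ----------------------------------------------------------------------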