import torch
import tiktoken
import gradio as gr
from torch.nn import functional as F

from model import GPT, GPTConfig

# Pick the best available device: CUDA, then Apple MPS, then CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"

num_return_sequences = 1

# Load the trained weights and set the model to evaluation mode.
model = GPT(GPTConfig())
model.to(device)
model.load_state_dict(torch.load('./checkpoints/final_model.pth', map_location=device))
model.eval()


def generate(text, max_length):
    enc = tiktoken.get_encoding("gpt2")
    tokens = enc.encode(text)
    tokens = torch.tensor(tokens, dtype=torch.long)  # (len,)
    tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)  # (B, len)
    x = tokens.to(device)

    # Autoregressively sample one token at a time until max_length is reached.
    while x.size(1) < max_length:
        # forward the model to get the logits
        with torch.no_grad():
            logits = model(x)[0]  # (B, T, vocab_size)
        # take the logits at the last position
        logits = logits[:, -1, :]  # (B, vocab_size)
        # get the probabilities
        probs = F.softmax(logits, dim=-1)
        # top-k sampling with k=50 (the Hugging Face pipeline default);
        # topk_probs and topk_indices are both (B, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        # select a token from the top-k probabilities
        # note: multinomial does not demand the input to sum to 1
        ix = torch.multinomial(topk_probs, 1)  # (B, 1)
        # gather the corresponding token indices
        xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)
        # append to the sequence
        x = torch.cat((x, xcol), dim=1)

    # Decode and return the first generated sequence.
    for i in range(num_return_sequences):
        tokens = x[i, :max_length].tolist()
        decoded = enc.decode(tokens)
        return decoded


title = "Shakespeare Poem Generation Using a GPT 121M Model"
description = "A simple Gradio interface to demo generation of Shakespeare-style poems."
# Each example supplies one value per input component (prompt, token length);
# the token-length values here are arbitrary defaults.
examples = [
    ["Let us kill him, and we'll have corn at our own price.", 50],
    ["Would you proceed especially against Caius Marcius?", 50],
    ["Nay, but speak not maliciously.", 50],
]

demo = gr.Interface(
    generate,
    inputs=[
        gr.TextArea(label="Enter text"),
        gr.Slider(10, 100, value=10, step=1, label="Token Length"),
    ],
    outputs=[
        gr.TextArea(label="Generated Text"),
    ],
    title=title,
    description=description,
    examples=examples,
    cache_examples=False,
    live=True,
)

demo.launch()
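
# Note: demo.launch() blocks while the server is running. For a quick sanity
# check of generate() without the Gradio UI, you could comment out the launch
# call above and run something like this instead (prompt taken from the
# examples list; the token length of 50 is an arbitrary choice):
#
#     print(generate("Nay, but speak not maliciously.", max_length=50))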