# rubber-duck / codegen.py
from transformers import AutoTokenizer, AutoModelForCausalLM
DEVICE = 'cpu'

# Populated by setup(); generate() raises until both are loaded.
TOKENIZER = None
MODEL = None


def setup(model: str, setup_torch: bool = False):
    global TOKENIZER, MODEL, DEVICE
    if setup_torch:
        try:
            import torch
            # Prefer the GPU, but fall back to CPU when one isn't available.
            DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
            if DEVICE != 'cpu':
                torch.set_default_tensor_type(torch.cuda.FloatTensor)
        except Exception:
            print("ERROR: Can't set default tensor type to torch.cuda.FloatTensor")
    TOKENIZER = AutoTokenizer.from_pretrained(model)
    # Move the model to the selected device so it matches the inputs in generate().
    MODEL = AutoModelForCausalLM.from_pretrained(model).to(DEVICE)


def generate(token: str) -> str:
    """
    Generate code with the loaded model from the given input.

    :param token: The prompt text that is tokenized and passed to the model.
    :return: The generated string output.
    """
    if TOKENIZER is None or MODEL is None:
        raise RuntimeError("Model and tokenizer have not been set up; call setup() first.")
    inputs = TOKENIZER(token, return_tensors='pt').to(DEVICE)
    sample = MODEL.generate(**inputs, max_length=128)
    # truncate_before_pattern is a CodeGen-tokenizer extension to decode() that
    # trims the output at the first match of any of the patterns below.
    return TOKENIZER.decode(sample[0], truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"])
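

# Minimal usage sketch. The checkpoint name is an assumption for illustration;
# any causal-LM checkpoint should load, though truncate_before_pattern in
# generate() expects a CodeGen-style tokenizer (e.g. the Salesforce/codegen family).
if __name__ == '__main__':
    setup('Salesforce/codegen-350M-mono', setup_torch=True)
    print(generate('def fibonacci(n):'))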