# Author: Ricardo Lisboa Santos
# Creation date: 2024-01-10

import torch
# import torch_directml
from transformers import pipeline


def getDevice(DEVICE):
    # Map a device name to a torch.device; the matching dtype is noted for
    # reference but is not used by the helpers below.
    device = None
    if DEVICE == "cpu":
        device = torch.device("cpu")
        dtype = torch.float32
    elif DEVICE == "cuda":
        device = torch.device("cuda")
        dtype = torch.float16
    # elif DEVICE == "directml":
    #     device = torch_directml.device()
    #     dtype = torch.float16
    return device


def loadGenerator(device):
    # pipeline() accepts a torch.device via the device argument, which
    # replaces the earlier ".to(device)" idea.
    generator = pipeline("text-generation", device=device)
    return generator


def generate(generator, text):
    output = generator(text)
    return output


def clearCache(DEVICE, generator):
    # Persist the tokenizer and model weights locally, then release the pipeline.
    generator.tokenizer.save_pretrained("cache")
    generator.model.save_pretrained("cache")
    del generator
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    # if DEVICE == "directml":
    #     torch_directml.empty_cache()
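

# Usage sketch (an assumption, not part of the original script): shows how the
# helpers above are expected to compose. The prompt text and the "cpu" device
# choice are illustrative placeholders.
if __name__ == "__main__":
    device = getDevice("cpu")  # or "cuda" when a GPU is available
    generator = loadGenerator(device)
    # A text-generation pipeline returns a list of dicts with "generated_text".
    result = generate(generator, "Hello, world")
    print(result[0]["generated_text"])
    clearCache("cpu", generator)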