File size: 1,550 Bytes
239290b a1025a5 239290b 7bbb8c5 239290b 7bbb8c5 858f052 7152f53 7bbb8c5 239290b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from transformers import AutoTokenizer, TFGPT2LMHeadModel, pipeline
from transformers.pipelines import TextGenerationPipeline
from typing import Union
class QuoteGenerator:
    """Generate quotes conditioned on comma-separated tags with a GPT-2 model.

    The tokenizer and the TensorFlow GPT-2 LM-head model are loaded eagerly in
    ``__init__``. The Hugging Face text-generation pipeline is built by
    ``load_generator``, or lazily on the first ``generate_quote`` call if it
    has not been built yet.
    """

    def __init__(self, model_name: str = 'gruhit-patel/quote-generator-v2'):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.model_name = model_name
        # Populated by load_generator(). Starting at None, rather than a bare
        # annotation, means the attribute always exists and can be checked.
        self.quote_generator: Union[None, TextGenerationPipeline] = None
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = TFGPT2LMHeadModel.from_pretrained(self.model_name)
        # Tags used by preprocess_tags() when it is given None.
        self.default_tags = 'love,life'
        print("Model has been loaded")

    def load_generator(self) -> None:
        """Build the 'text-generation' pipeline from the loaded model and tokenizer."""
        self.quote_generator = pipeline('text-generation', model=self.model, tokenizer=self.tokenizer)
        print("Pipeline has been generated")

    def preprocess_tags(self, tags: Union[None, str] = None) -> str:
        """Build the prompt string ``<bos_token><tags><bot>:``.

        Args:
            tags: Tag string inserted into the prompt verbatim. If None,
                ``self.default_tags`` is used.

        Returns:
            The prompt to feed to the generation pipeline.
        """
        if tags is None:
            tags = self.default_tags
        # Some tokenizers define no BOS token (bos_token is None). Use an
        # empty prefix in that case instead of failing on str + None.
        bos = self.tokenizer.bos_token or ''
        return bos + tags + '<bot>:'

    def generate_quote(self, tags: Union[None, str], max_new_tokens: int, do_sample: bool,
                       num_beams: int, top_k: int, top_p: float, temperature: float) -> str:
        """Generate a quote for ``tags``.

        The pipeline is built on first use if ``load_generator`` has not been
        called yet.

        Args:
            tags: Tag string for the prompt; None selects ``self.default_tags``.
            max_new_tokens: Maximum number of tokens to generate.
            do_sample: Whether the pipeline samples instead of decoding greedily/with beams.
            num_beams: Beam count; ``early_stopping`` is enabled when it is > 1.
            top_k: Forwarded to the pipeline as ``top_k``.
            top_p: Forwarded to the pipeline as ``top_p``.
            temperature: Forwarded to the pipeline as ``temperature``.

        Returns:
            ``generated_text`` of the first pipeline result. With the
            pipeline's defaults this presumably includes the prompt prefix;
            verify against callers.
        """
        # Previously, calling this before load_generator() hit a missing attribute.
        if self.quote_generator is None:
            self.load_generator()
        prompt = self.preprocess_tags(tags)
        output = self.quote_generator(
            prompt,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            # early_stopping only matters for beam search.
            early_stopping=num_beams > 1,
        )
        return output[0]['generated_text']
|