stibiumghost commited on
Commit
afe352a
β€’
1 Parent(s): 04e9e1a

Update text_gen.py

Browse files
Files changed (1) hide show
  1. text_gen.py +2 -3
text_gen.py CHANGED
@@ -15,8 +15,7 @@ model = [transformers.AutoModelForSeq2SeqLM.from_pretrained(model_names[0]),
15
 
16
 
17
  def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
18
- if context:
19
- text = f'{context} {text}'
20
  if 'GODEL' in model_name:
21
  text = f'Instruction: you need to response discreetly. [CONTEXT] {text}'
22
  text.replace('\t', ' EOS ')
@@ -34,4 +33,4 @@ def capitalization(line):
34
  line = f'{mark} '.join([part.strip()[0].upper() + part.strip()[1:] for part in line.split(mark) if len(part) > 1])
35
  line = ' '.join([word.capitalize() if word.translate(str.maketrans('', '', string.punctuation)) == 'i'
36
  else word for word in line.split()])
37
- return line + end
 
15
 
16
 
17
  def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
18
+ text = f'{context} {text}'
 
19
  if 'GODEL' in model_name:
20
  text = f'Instruction: you need to response discreetly. [CONTEXT] {text}'
21
  text.replace('\t', ' EOS ')
 
33
  line = f'{mark} '.join([part.strip()[0].upper() + part.strip()[1:] for part in line.split(mark) if len(part) > 1])
34
  line = ' '.join([word.capitalize() if word.translate(str.maketrans('', '', string.punctuation)) == 'i'
35
  else word for word in line.split()])
36
+ return line.replace(' i\'', ' I\'') + end