nroggendorff committed on
Commit
411ad3b
1 Parent(s): 06ebaba

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +27 -12
train.py CHANGED
@@ -56,6 +56,25 @@ def load_data():
56
  dataset = Dataset.from_dict({'text': [example['text'] for example in data_list]})
57
  return dataset
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def create_tokenizer(training_corpus):
60
  tokenizer = ByteLevelBPETokenizer()
61
  special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
@@ -88,22 +107,16 @@ def format_prompts(examples, tokenizer, isinst):
88
  conversation.append({"role": "user", "content": prompt})
89
  conversation.append({"role": "assistant", "content": response})
90
  formatted_conversation = tokenizer.apply_chat_template(conversation, tokenize=False)
91
- texts.append(formatted_conversation)
 
92
  else:
93
- texts.append(tokenizer.bos_token + text + tokenizer.eos_token)
94
  else:
95
  print('Found empty entry in examples. Moving on..')
96
  continue
97
- tokenized_texts = tokenizer(
98
- texts,
99
- padding="max_length",
100
- truncation=True,
101
- max_length=MAX_SEQ_LENGTH,
102
- return_tensors="pt"
103
- ).input_ids
104
- decoded_texts = tokenizer.batch_decode(tokenized_texts)
105
-
106
- return {'text': decoded_texts}
107
 
108
  def create_model(tokenizer):
109
  config = LlamaConfig(
@@ -146,6 +159,8 @@ def configure_tokenizer(tokenizer):
146
  chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
147
  tokenizer.chat_template = chat_template
148
 
 
 
149
  def update_tokenizer(tokenizer, dataset, batch_size=1000):
150
  existing_vocab = tokenizer.get_vocab()
151
  oov_tokens = set()
 
56
  dataset = Dataset.from_dict({'text': [example['text'] for example in data_list]})
57
  return dataset
58
 
59
def encode_decode(text, tok):
    """Round-trip *text* through the tokenizer *tok* and return the decoded result.

    The input is tokenized with padding/truncation to ``MAX_SEQ_LENGTH``
    (a module-level constant) and immediately decoded back to text, which
    normalizes every entry to a fixed-length, tokenizer-consistent string.

    Args:
        text: a single string or a list of strings to normalize.
        tok: the tokenizer (callable, with ``batch_decode``) to use.

    Returns:
        A list of decoded strings when more than one sequence was produced,
        otherwise the single decoded string.
    """
    # BUG FIX: the original body referenced the undefined names `tokenizer`
    # and `texts` (left over from the pre-refactor inline code in
    # format_prompts), which made every call raise NameError. The function's
    # own parameters `tok` and `text` are used instead.
    tokenized_texts = tok(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        return_tensors="pt"
    ).input_ids

    # With return_tensors="pt" the result is always a 2-D tensor, so this
    # guard is effectively always true; kept for parity with the original.
    if tokenized_texts.dim() >= 1:
        decoded_texts = tok.batch_decode(tokenized_texts)
    else:
        print('Found invalid entry in examples. Returning dummy..')
        decoded_texts = ['Nothing to see here.']

    # A single decoded sequence is unwrapped to a plain string.
    islist = not len(decoded_texts) == 1

    return decoded_texts if islist else decoded_texts[0]
77
+
78
  def create_tokenizer(training_corpus):
79
  tokenizer = ByteLevelBPETokenizer()
80
  special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
 
107
  conversation.append({"role": "user", "content": prompt})
108
  conversation.append({"role": "assistant", "content": response})
109
  formatted_conversation = tokenizer.apply_chat_template(conversation, tokenize=False)
110
+ coded_text = tokenizer.code(formatted_conversation)
111
+ texts.append(coded_text)
112
  else:
113
+ texts.append(tokenizer.bos_token + tokenizer.code(text) + tokenizer.eos_token)
114
  else:
115
  print('Found empty entry in examples. Moving on..')
116
  continue
117
+
118
+ coded_texts = tokenizer.code(texts)
119
+ return {'text': coded_texts}
 
 
 
 
 
 
 
120
 
121
  def create_model(tokenizer):
122
  config = LlamaConfig(
 
159
  chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
160
  tokenizer.chat_template = chat_template
161
 
162
+ tokenizer.code = lambda example: encode_decode(example, tokenizer)
163
+
164
  def update_tokenizer(tokenizer, dataset, batch_size=1000):
165
  existing_vocab = tokenizer.get_vocab()
166
  oov_tokens = set()