{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "a41f141c-b6a8-40d1-b72d-127d028c0592",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"model_path = os.getcwd()\n",
"print(model_path)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False)\n",
"model = AutoModelForCausalLM.from_pretrained(model_path, use_safetensors=True, local_files_only=True)\n",
"tokenizer.pad_token = tokenizer.eos_token"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "93e9ec6a-4a57-484f-a1a5-ecb6674e8f77",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LlamaTokenizerFast(name_or_path='/var/home/ngxson/jupyter/stories-15M', vocab_size=32000, model_max_length=2048, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '', 'eos_token': '', 'unk_token': ''}, clean_up_tokenization_spaces=False), added_tokens_decoder={\n",
"\t0: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
"\t1: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
"\t2: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
"}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#inputs = tokenizer('', return_tensors=\"pt\")\n",
"#outputs = model.generate(inputs['input_ids'], max_new_tokens=20, temperature=0)\n",
"#print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
"\n",
"tokenizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e570b6db-efa8-4c9f-ac71-573479b00711",
"metadata": {},
"outputs": [],
"source": [
"model.gradient_checkpointing_enable()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9345e74b-5bef-4cc9-982e-342af69b290a",
"metadata": {},
"outputs": [],
"source": [
"from peft import LoraConfig, get_peft_model\n",
"\n",
"peft_config = LoraConfig(\n",
" r=64,\n",
" lora_alpha=128,\n",
" target_modules=[\n",
" \"q_proj\",\n",
" \"k_proj\",\n",
" \"v_proj\",\n",
" \"o_proj\",\n",
" \"w1\",\n",
" \"w2\",\n",
" \"w3\",\n",
" \"lm_head\",\n",
" ],\n",
" bias=\"none\",\n",
" lora_dropout=0.05, # Conventional\n",
" task_type=\"CAUSAL_LM\",\n",
")\n",
"\n",
"model = get_peft_model(model, peft_config)\n",
"model.print_trainable_parameters()\n",
"\n",
"#print(model)"
]
},
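{
"cell_type": "code",
"execution_count": null,
"id": "f2c8a1d0-3b7e-4c15-9a42-6d1e0b5c7a21",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (a sketch added for illustration, not part of the original run):\n",
"# module names like \"w1\"/\"w2\"/\"w3\" only exist on some architectures, so list which of\n",
"# the requested LoRA targets actually appear as leaf module names in the wrapped model.\n",
"targets = {\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"w1\", \"w2\", \"w3\", \"lm_head\"}\n",
"found = {name.split(\".\")[-1] for name, _ in model.named_modules() if name.split(\".\")[-1] in targets}\n",
"print(\"found:\", sorted(found))\n",
"print(\"missing:\", sorted(targets - found))"
]
},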
{
"cell_type": "code",
"execution_count": null,
"id": "b43aec47-5fa4-48c9-8e57-9c6b233b9c7e",
"metadata": {},
"outputs": [],
"source": [
"def split_and_trim(text):\n",
" paragraphs = text.strip().split('\\n\\n')\n",
" trimmed_paragraphs = []\n",
" for para in paragraphs:\n",
" trimmed_lines = [line.lstrip() for line in para.split('\\n')]\n",
" trimmed_paragraphs.append('\\n'.join(trimmed_lines))\n",
"\n",
" return trimmed_paragraphs\n",
"\n",
"with open(\"data.txt\", \"r\") as f:\n",
" content = f.read()\n",
" dataset = split_and_trim(content)\n",
" tokenized_train_dataset = [\n",
" tokenizer(content)['input_ids'] for content in dataset\n",
" ]\n",
"#tokenized_train_dataset"
]
},
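{
"cell_type": "code",
"execution_count": null,
"id": "d4a9e6b2-58c0-4f3a-8b17-2c9e5f0a1d36",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (a sketch, assuming the cell above has been run):\n",
"# report how many training examples there are and how long they are in tokens,\n",
"# which helps when picking per_device_train_batch_size for the trainer below.\n",
"lengths = [len(ids) for ids in tokenized_train_dataset]\n",
"print(f\"examples: {len(lengths)}\")\n",
"print(f\"tokens per example: min={min(lengths)}, max={max(lengths)}, mean={sum(lengths) / len(lengths):.1f}\")"
]
},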
{
"cell_type": "code",
"execution_count": null,
"id": "09dd4848-9c7a-4a3b-9887-59652c915cc3",
"metadata": {},
"outputs": [],
"source": [
"import transformers\n",
"from datetime import datetime\n",
"\n",
"project = \"moe_shakespeare15M\"\n",
"run_name = project\n",
"output_dir = \"./\" + run_name\n",
"\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"checkpointing_args = {\"use_reentrant\": False}\n",
"trainer = transformers.Trainer(\n",
" model=model,\n",
" train_dataset=tokenized_train_dataset,\n",
" args=transformers.TrainingArguments(\n",
" output_dir=output_dir,\n",
" warmup_steps=100,\n",
" per_device_train_batch_size=50,\n",
" gradient_accumulation_steps=5,\n",
" gradient_checkpointing=True,\n",
" max_steps=500,\n",
" learning_rate=2.5e-5, # Want a small lr for finetuning\n",
" # fp16=True, \n",
" optim=\"adamw_torch\",\n",
" save_strategy=\"steps\",\n",
" save_steps=100,\n",
" logging_steps=20,\n",
" save_total_limit=4,\n",
" report_to=\"none\", \n",
" run_name=f\"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}\"\n",
" ),\n",
" data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
")\n",
"\n",
"model.config.use_cache = False # silence the warnings. Please re-enable for inference!\n",
"trainer.train()"
]
},
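{
"cell_type": "code",
"execution_count": null,
"id": "b7e3c5f1-9a24-4d68-8c0b-5e1f2a3d4c90",
"metadata": {},
"outputs": [],
"source": [
"# Quick generation check after training (a minimal sketch; the prompt is a placeholder,\n",
"# not from the original notebook). Re-enable the KV cache that was disabled for training,\n",
"# switch to eval mode, and sample a short continuation.\n",
"model.config.use_cache = True\n",
"model.eval()\n",
"\n",
"prompt = \"Once upon a time\"  # placeholder prompt\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
"outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.8)\n",
"print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
]
},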
{
"cell_type": "code",
"execution_count": null,
"id": "7f0ad783-3f3e-4812-bc4e-026f9aad1435",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}