Codex_Prime / src /model.py
dnnsdunca's picture
Update src/model.py
cd17556 verified
raw
history blame
579 Bytes
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
def get_model_and_tokenizer(config):
    """Load a causal LM and its tokenizer, then wrap the model with LoRA adapters.

    Args:
        config: Mapping-like configuration. ``config['model']['name']`` is the
            model identifier/path passed to both ``from_pretrained`` calls.
            An optional ``config['lora']`` mapping overrides the default LoRA
            settings (``r=8``, ``lora_alpha=32``, ``lora_dropout=0.1``) and may
            supply any other ``LoraConfig`` keyword, e.g. ``target_modules``.
            ``task_type`` and ``inference_mode`` are always forced to causal-LM
            training.

    Returns:
        A ``(model, tokenizer)`` tuple, where ``model`` is the object returned
        by ``get_peft_model`` wrapping the loaded base model.
    """
    model_name = config["model"]["name"]
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Many causal-LM tokenizers ship without a pad token, which makes any
    # padded training batch fail. Reuse EOS as the pad token in that case.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # Optional user overrides for the LoRA adapter; absent -> old defaults.
    # Only __getitem__ is required of `config`, matching the lookup above.
    try:
        lora_overrides = config["lora"]
    except KeyError:
        lora_overrides = None

    lora_kwargs = {"r": 8, "lora_alpha": 32, "lora_dropout": 0.1}
    lora_kwargs.update(lora_overrides or {})
    # This function always prepares a causal LM for training, so these two
    # settings are not overridable (and cannot clash as duplicate kwargs).
    lora_kwargs.update(task_type=TaskType.CAUSAL_LM, inference_mode=False)

    peft_config = LoraConfig(**lora_kwargs)
    model = get_peft_model(model, peft_config)
    return model, tokenizer