from transformers import AutoModel, AutoTokenizer, AutoConfig, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict
import torch.nn as nn
import torch
import wandb
from tqdm import tqdm

args_max_epoch = 1
args_batch_size = 64
args_learning_rate = 3e-5
args_num_warmup_steps = 100
args_gradient_accumulation_steps_default = 2
adapter_hidden_dim = 4096
device = 'cuda'


def main():
    wandb.init(project="MappingAdapater_training_v6", name="training_run")

    model = MappingStructure(checkpointE="sentence-transformers/stsb-roberta-large",
                             checkpointD="mistralai/Mistral-7B-Instruct-v0.1",
                             hidden_dim=adapter_hidden_dim,
                             torch_dtype=torch.float16,
                             flash_attn=True,
                             ).to(device)

    # Freeze the encoder and decoder; only the mapping adapter is trained.
    for n, p in model.named_parameters():
        p.requires_grad = 'mapping' in n

    dataset = load_dataset("sade-adrien/redpajama_v2_sample_10M")['train']
    train_dataset, val_dataset = split_dataset(dataset, train_size=.989333)
    datasets = DatasetDict({
        'train': train_dataset,
        'val': val_dataset,
    })

    train_dataloader = DataLoader(datasets['train'], batch_size=args_batch_size, shuffle=True)
    val_dataloader = DataLoader(datasets['val'], batch_size=args_batch_size, shuffle=False)

    optimizer = AdamW(model.parameters(), lr=args_learning_rate)
    scheduler = get_linear_schedule_with_warmup(optimizer, args_num_warmup_steps,
                                                args_max_epoch * len(train_dataloader))

    global_step = 0
    for epoch in range(args_max_epoch):
        # Re-create the dataloader each epoch with an epoch-dependent seed.
        train_dataloader = DataLoader(datasets['train'], batch_size=args_batch_size, shuffle=True,
                                      worker_init_fn=lambda _: torch.manual_seed(epoch))

        for batch in tqdm(train_dataloader):
            input_prompt = batch['raw_content']

            outputs = model(input_prompt=input_prompt, compute_loss=True)
            loss = outputs['loss']

            # Gradient accumulation: scale the loss so accumulated gradients average out.
            loss = loss / args_gradient_accumulation_steps_default
            loss.backward()

            if (global_step + 1) % args_gradient_accumulation_steps_default == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            # Periodically checkpoint the adapter (the frozen backbones are not saved).
            if (global_step + 1) % 2000 == 0:
                torch.save({
                    'epoch': epoch,
                    'mapping_state_dict': model.mapping.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'global_step': global_step,
                }, f'models/mapping_adapter_checkpoint_{global_step + 1}steps.pth')

            global_step += 1

            # Periodically evaluate on the held-out split.
            val_loss = None
            if (global_step + 1) % 8000 == 0:
                model.eval()
                val_loss = 0.0
                with torch.no_grad():
                    for val_batch in tqdm(val_dataloader):
                        val_inputs = val_batch['raw_content']
                        val_outputs = model(input_prompt=val_inputs, compute_loss=True)
                        val_loss += val_outputs['loss']
                    val_loss /= len(val_dataloader)
                model.train()

            wandb.log({
                'step': global_step + 1,
                'learning_rate': scheduler.get_last_lr()[0],
                'train_loss': loss.item() * args_gradient_accumulation_steps_default,  # undo accumulation scaling for logging
                'val_loss': val_loss.item() if val_loss is not None else None,
            })


def split_dataset(dataset, train_size=.9):
    index = int(len(dataset) * train_size)
    return dataset.select(range(index)), dataset.select(range(index, len(dataset)))


class MappingAdapter(nn.Module):
    # Two-layer MLP that maps decoder hidden states into the encoder embedding space.
    def __init__(self, input_dim, output_dim, hidden_dim):
        super(MappingAdapter, self).__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.LeakyReLU(.01)

    def forward(self, x):
        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        return x


class MappingStructure(nn.Module):
    def __init__(self, checkpointE, checkpointD, hidden_dim=2048, torch_dtype=torch.float32, flash_attn=False):
        super(MappingStructure, self).__init__()

        self.configE = AutoConfig.from_pretrained(checkpointE)
        self.Encoder = AutoModel.from_pretrained(checkpointE,
                                                 low_cpu_mem_usage=True,
                                                 torch_dtype=torch_dtype,
                                                 config=self.configE,
                                                 )

        self.configD = AutoConfig.from_pretrained(checkpointD)
        if flash_attn:
            self.configD.update({'_flash_attn_2_enabled': True})
        self.Decoder = AutoModel.from_pretrained(checkpointD,
                                                 low_cpu_mem_usage=True,
                                                 torch_dtype=torch_dtype,
                                                 config=self.configD,
                                                 )

        self.mapping = MappingAdapter(self.configD.hidden_size, self.configE.hidden_size,
                                      hidden_dim=hidden_dim).to(torch_dtype)

        self._init_tokenizers(checkpointE, checkpointD)

    def _init_tokenizers(self, checkpointE, checkpointD):
        self.tokenizerE = AutoTokenizer.from_pretrained(checkpointE,
                                                        use_fast=False,
                                                        revision='main',
                                                        config=self.configE,
                                                        padding_side='left')
        self.tokenizerD = AutoTokenizer.from_pretrained(checkpointD,
                                                        use_fast=False,
                                                        revision='main',
                                                        config=self.configD,
                                                        padding_side='left')
        self.tokenizerD.pad_token_id = self.tokenizerD.unk_token_id

    def cosine_sim(self, u, v):
        assert u.shape == v.shape, "u and v must have the same shape"
        u_normalized = u / torch.norm(u, dim=1, keepdim=True)
        v_normalized = v / torch.norm(v, dim=1, keepdim=True)
        # Compute cosine similarity using dot product
        return torch.sum(u_normalized * v_normalized, dim=1)

    def mean_pooling(self, hidden_state, attention_mask):
        token_embeddings = hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def build_batch(self, input_prompt):
        # Sample a random truncation length that fits within the encoder's context window.
        size = torch.randint(1, self.configE.max_position_embeddings - 2, (1,)).item()
        targets = []
        for prompt in input_prompt:
            tokenized_input = self.tokenizerE(prompt)
            tokenized_input = {'input_ids': tokenized_input['input_ids'][:size],
                               'attention_mask': tokenized_input['attention_mask'][:size],
                               }
            targets.append(tokenized_input)
        targets = self.tokenizerE.pad(targets, padding=True, return_tensors='pt')
        return targets

    def forward(self, input_prompt, compute_loss=False):
        loss = None
        # Slice prompt if needed to fit encoder max position embeddings (hard constraint)
        if not compute_loss:
            inputs = self.tokenizerD(input_prompt, return_tensors='pt', padding=True).to(device)
            hidden_state_D = self.Decoder(**inputs).last_hidden_state
            hidden_state_D_mapped = self.mapping(hidden_state_D)
        else:
            targets = self.build_batch(input_prompt).to(device)
            input_prompt_sliced = self.tokenizerE.batch_decode(targets['input_ids'], skip_special_tokens=True)

            inputs = self.tokenizerD(input_prompt_sliced, return_tensors='pt', padding=True).to(device)
            hidden_state_D = self.Decoder(**inputs).last_hidden_state
            hidden_state_D_mapped = self.mapping(hidden_state_D)

            hidden_state_E = self.Encoder(**targets).last_hidden_state

            proj_E = self.mean_pooling(hidden_state_E, targets['attention_mask'])
            proj_D = self.mean_pooling(hidden_state_D_mapped, inputs['attention_mask'])

            # Loss: pull the mapped, mean-pooled decoder representation toward the encoder's sentence embedding.
            loss = 1 - torch.mean(self.cosine_sim(proj_E, proj_D))

            del inputs, targets, input_prompt_sliced, hidden_state_E, proj_E, proj_D
            torch.cuda.empty_cache()

        return {'loss': loss,
                'last_hidden_state': hidden_state_D,
                'last_hidden_state_mapped': hidden_state_D_mapped,
                }


if __name__ == '__main__':
    main()
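
# --- Usage sketch (commented out, not part of the training run): reloading a trained adapter ---
# A minimal sketch under stated assumptions: it assumes a checkpoint produced by the
# torch.save call in main() and the same encoder/decoder checkpoints; the checkpoint
# path below is hypothetical.
#
# model = MappingStructure(checkpointE="sentence-transformers/stsb-roberta-large",
#                          checkpointD="mistralai/Mistral-7B-Instruct-v0.1",
#                          hidden_dim=adapter_hidden_dim,
#                          torch_dtype=torch.float16,
#                          flash_attn=True).to(device)
# ckpt = torch.load('models/mapping_adapter_checkpoint_2000steps.pth', map_location=device)
# model.mapping.load_state_dict(ckpt['mapping_state_dict'])
# model.eval()
# outputs = model(input_prompt=["Example prompt."], compute_loss=False)
# mapped = outputs['last_hidden_state_mapped']  # decoder states projected into the encoder space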