from transformers import Trainer
import torch
import os


class MambaTrainer(Trainer):
    """Trainer subclass for Mamba language models: computes the causal LM loss
    itself and saves raw weights plus a Mamba config.json on save_model."""

    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        # Next-token prediction: labels are the input ids shifted left by one.
        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return lm_loss

    def save_model(self, output_dir, _internal_call=False):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # Save the raw state dict and the tokenizer instead of going through
        # the usual save_pretrained path.
        torch.save(self.model.state_dict(), f"{output_dir}/pytorch_model.bin")
        self.tokenizer.save_pretrained(output_dir)

        # Hard-coded config matching the 130M Mamba model (d_model=768, n_layer=24).
        json_str = """
{
    "d_model": 768,
    "n_layer": 24,
    "vocab_size": 50277,
    "ssm_cfg": {},
    "rms_norm": true,
    "residual_in_fp32": true,
    "fused_add_norm": true,
    "pad_vocab_size_multiple": 8
}"""
        with open(f"{output_dir}/config.json", 'w') as f:
            f.write(json_str)
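
# Usage sketch (illustrative only, not part of the trainer itself): one way
# MambaTrainer could be wired into a small fine-tuning run. Assumptions beyond
# the code above: the mamba_ssm package is installed and exposes
# MambaLMHeadModel.from_pretrained, a CUDA device is available, and the
# EleutherAI/gpt-neox-20b tokenizer is used (it matches the vocab_size of 50277
# in the hard-coded config). Dataset, collator, and hyperparameters are placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer, TrainingArguments
    from torch.utils.data import Dataset
    from torch.nn.utils.rnn import pad_sequence
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokenizer.pad_token = tokenizer.eos_token

    model = MambaLMHeadModel.from_pretrained(
        "state-spaces/mamba-130m", dtype=torch.bfloat16, device="cuda"
    )

    # Toy dataset: each item is a dict with "input_ids", which is all that
    # compute_loss above expects from a batch.
    texts = [
        "Mamba is a selective state space model.",
        "It replaces attention with a linear-time recurrence.",
    ]
    encoded = [tokenizer(t, return_tensors="pt").input_ids[0] for t in texts]

    class ToyDataset(Dataset):
        def __len__(self):
            return len(encoded)

        def __getitem__(self, idx):
            return {"input_ids": encoded[idx]}

    def collate(batch):
        # Right-pad every sequence in the batch to the longest one.
        input_ids = pad_sequence(
            [item["input_ids"] for item in batch],
            batch_first=True,
            padding_value=tokenizer.pad_token_id,
        )
        return {"input_ids": input_ids}

    trainer = MambaTrainer(
        model=model,
        tokenizer=tokenizer,
        args=TrainingArguments(
            output_dir="mamba-finetune",
            per_device_train_batch_size=2,
            num_train_epochs=1,
            learning_rate=5e-5,
            logging_steps=1,
            save_strategy="no",  # save explicitly via save_model below
            report_to="none",
        ),
        train_dataset=ToyDataset(),
        data_collator=collate,
    )
    trainer.train()
    trainer.save_model("mamba-finetune")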