|
from functools import partial |
|
from typing import Iterator, Tuple |
|
|
|
import torch |
|
from torch import nn |
|
import torch.nn.utils.parametrize as parametrize |
|
import math |
|
|
|
from torch.nn import Parameter |
|
|
|
from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig |
|
|
|
|
|
def initialized_weights(shape, num_adaptions, init='kaiming'):
    """Return ``num_adaptions`` independently initialised tensors stacked on dim 0.

    Args:
        shape: shape of each individual weight tensor.
        num_adaptions: how many tensors to create; the result has shape
            ``(num_adaptions, *shape)``.
        init: ``'kaiming'`` for ``nn.init.kaiming_uniform_`` with
            ``a=sqrt(5)``, or ``'normal'`` for ``nn.init.normal_``.

    Raises:
        NotImplementedError: if ``init`` is any other value (raised while
            building the first tensor; with ``num_adaptions == 0`` nothing is
            built and ``torch.stack`` rejects the empty list instead).
    """
    def _fresh():
        # Allocate a zero tensor, then overwrite it in place with the chosen
        # initialiser.
        buf = torch.zeros(shape)
        if init == 'normal':
            nn.init.normal_(buf)
        elif init == 'kaiming':
            nn.init.kaiming_uniform_(buf, a=math.sqrt(5))
        else:
            raise NotImplementedError
        return buf

    return torch.stack([_fresh() for _ in range(num_adaptions)], dim=0)
|
|
|
|
|
class LoRAParametrization(nn.Module):
    """Weight parametrization adding a task-selectable low-rank (LoRA) update.

    Stores ``num_adaptions`` independent factor pairs stacked along dim 0 of
    ``lora_A`` / ``lora_B``. With a task selected (see :meth:`select_task`)
    the forward pass maps a weight ``W`` to ``W + scaling * (B @ A)`` for
    linear layers, or ``W + scaling * (A @ B)`` for embeddings, using that
    task's factors. With no task selected, ``W`` is returned unchanged.

    Per-instance behaviour is driven by plain flags and methods rather than
    lambdas stored on the instance, so instances can be pickled.
    """

    def __init__(self, fan_in, fan_out, layer_type='linear', num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Create the LoRA factors for one weight.

        Args:
            fan_in: ``weight.shape[1]`` for ``'linear'``, ``weight.shape[0]``
                for ``'embedding'`` (see ``from_linear`` / ``from_embedding``).
            fan_out: ``weight.shape[0]`` for ``'linear'``, ``weight.shape[1]``
                for ``'embedding'``.
            layer_type: ``'linear'`` or ``'embedding'``.
            num_adaptions: number of independent (A, B) pairs, one per task.
            rank: inner dimension of the low-rank factors; must be positive.
            lora_dropout_p: dropout probability for the mask applied to A;
                ``0`` disables dropout.
            lora_alpha: the update is scaled by ``lora_alpha / rank``.

        Raises:
            NotImplementedError: if ``layer_type`` is not supported.
            ValueError: if ``rank`` is not positive.
        """
        super().__init__()
        if layer_type not in ('linear', 'embedding'):
            raise NotImplementedError(
                f"unsupported layer_type {layer_type!r}; expected 'linear' or 'embedding'"
            )
        if rank <= 0:
            # Without this guard, rank == 0 surfaces as a ZeroDivisionError in
            # the scaling computation below.
            raise ValueError(f"rank must be a positive integer, got {rank!r}")

        # For embeddings the factor order is reversed relative to the linear
        # case; swap() compensates when shaping the mask and the matmul.
        self.fan_in_fan_out = layer_type == 'embedding'

        if layer_type == 'linear':
            # A: (num_adaptions, rank, fan_in), B: (num_adaptions, fan_out, rank)
            self.lora_A = nn.Parameter(initialized_weights((rank, fan_in), num_adaptions, init='kaiming'))
            self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
        else:
            # A: (num_adaptions, fan_in, rank), B: (num_adaptions, rank, fan_out)
            self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
            self.lora_B = nn.Parameter(initialized_weights((rank, fan_out), num_adaptions=num_adaptions, init='normal'))

        self.lora_alpha, self.rank = lora_alpha, rank
        self.scaling = lora_alpha / rank
        self.lora_dropout_p = lora_dropout_p
        self.lora_dropout = nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else nn.Identity()
        # All-ones mask that dropout is applied to before it is multiplied
        # into A. It is non-persistent, so it is not part of the state_dict.
        self.register_buffer(
            "lora_dropout_mask",
            torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
            persistent=False,
        )
        # Index into dim 0 of lora_A / lora_B, or None for "no update".
        self.current_task = None

    def swap(self, pair):
        """Return ``pair`` reversed for embedding layers, unchanged otherwise."""
        return (pair[1], pair[0]) if self.fan_in_fan_out else pair

    def _dropout(self, A):
        # Dropout hits a ones-mask that broadcasts over A: one entry per
        # column of A for linear layers, one per row for embeddings. In eval
        # mode nn.Dropout is the identity, so the mask stays all ones.
        return A * self.lora_dropout(self.lora_dropout_mask)

    def dropout_fn(self, A):
        """Apply LoRA dropout to factor ``A`` if dropout is enabled, else return it as is."""
        return self._dropout(A) if self.lora_dropout_p > 0 else A

    @staticmethod
    def _identity(X):
        return X

    @property
    def forward_fn(self):
        """Callable used by :meth:`forward`.

        This is :meth:`lora_forward` when a task is selected, identity otherwise.
        """
        return self._identity if self.current_task is None else self.lora_forward

    def lora_forward(self, X):
        """Return ``X`` plus the scaled low-rank update for the current task.

        Raises:
            RuntimeError: if no task is selected.
        """
        if self.current_task is None:
            raise RuntimeError("lora_forward called with no task selected; call select_task(idx) first")
        B = self.lora_B[self.current_task]
        A = self.dropout_fn(self.lora_A[self.current_task])
        delta = torch.matmul(*self.swap((B, A)))
        return X + delta.view(X.shape) * self.scaling

    def forward(self, X):
        return self.forward_fn(X)

    def select_task(self, task=None):
        """Select the adaptation index to apply; ``None`` disables the update."""
        self.current_task = task

    @classmethod
    def from_linear(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Build a parametrization sized for ``layer``'s ``(fan_out, fan_in)`` weight."""
        fan_out, fan_in = layer.weight.shape
        return cls(
            fan_in, fan_out, num_adaptions=num_adaptions, layer_type='linear', rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
        )

    @classmethod
    def from_embedding(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Build a parametrization sized for ``layer``'s ``(fan_in, fan_out)`` weight."""
        fan_in, fan_out = layer.weight.shape
        return cls(
            fan_in, fan_out, num_adaptions=num_adaptions, layer_type='embedding', rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
        )

    @classmethod
    def add_to_layer(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Register a LoRA parametrization on ``layer.weight`` for Linear/Embedding layers.

        Any other module is left untouched, so this can be passed to ``Module.apply``.
        """
        if isinstance(layer, nn.Linear):
            parametrize.register_parametrization(layer, "weight", cls.from_linear(layer, num_adaptions=num_adaptions, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha))
        elif isinstance(layer, nn.Embedding):
            parametrize.register_parametrization(layer, "weight", cls.from_embedding(layer, num_adaptions=num_adaptions, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha))

    @classmethod
    def select_task_for_layer(cls, layer, task_idx=None):
        """Call ``select_task(task_idx)`` if ``layer`` is a LoRAParametrization; otherwise do nothing."""
        if isinstance(layer, LoRAParametrization):
            layer.select_task(task_idx)
|
|
|
|
|
class BertLoRA(BertPreTrainedModel):
    """BertModel wrapper with per-task LoRA adapters on every Linear/Embedding weight.

    Only LoRA parameters are meant to be trained: all other parameters are
    frozen. ``parameters()`` / ``named_parameters()`` are overridden to yield
    only parameters whose qualified name contains ``'lora'``.
    """

    def __init__(self, config: JinaBertConfig, add_pooling_layer=True, num_adaptions=1):
        """Build a freshly initialised BertModel and attach LoRA adapters.

        Args:
            config: model configuration passed to ``BertModel``.
            add_pooling_layer: forwarded to ``BertModel``.
            num_adaptions: number of independent LoRA weight sets (tasks) per
                adapted layer.
        """
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
        self._register_lora(num_adaptions)
        self._freeze_non_lora_parameters()

    def _freeze_non_lora_parameters(self):
        """Set ``requires_grad=False`` on every parameter whose name lacks ``'lora'``."""
        # Use the unfiltered Module.named_parameters: the override below hides
        # exactly the base-model parameters that need freezing.
        for name, param in super().named_parameters():
            if 'lora' not in name:
                param.requires_grad_(False)

    def from_bert(self, *args, num_adaptions=1, **kwargs):
        """Replace ``self.bert`` with ``BertModel.from_pretrained(*args, **kwargs)``.

        LoRA adapters are attached to the loaded model, and its non-LoRA
        parameters are frozen, as in ``__init__``.
        """
        self.bert = BertModel.from_pretrained(*args, **kwargs)
        self._register_lora(num_adaptions)
        # The new BertModel's weights start out trainable; freeze them so
        # only the adapters are trained.
        self._freeze_non_lora_parameters()

    def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        """Attach a LoRAParametrization to the weight of every nn.Linear / nn.Embedding submodule."""
        self.apply(partial(LoRAParametrization.add_to_layer, num_adaptions=num_adaptions, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha))

    def select_task(self, task_idx):
        """Select adaptation ``task_idx`` on every LoRA parametrization; ``None`` disables them."""
        self.apply(partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx))

    def forward(self, *args, **kwargs):
        return self.bert(*args, **kwargs)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Yield only the LoRA parameters (see :meth:`named_parameters`)."""
        for _, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(
        self,
        prefix: str = '',
        recurse: bool = True,
        remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        """Yield ``(name, param)`` pairs whose name contains ``'lora'``; base-model parameters are hidden."""
        for name, param in super().named_parameters(prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate):
            if 'lora' in name:
                yield name, param