feat: added method to merge LoRA weights
Browse files- modeling_lora.py +16 -0
modeling_lora.py
CHANGED
@@ -199,6 +199,12 @@ class LoRAParametrization(nn.Module):
|
|
199 |
if isinstance(layer, LoRAParametrization):
|
200 |
layer.current_task = task_idx
|
201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
class BertLoRA(BertPreTrainedModel):
|
204 |
def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True):
|
@@ -207,6 +213,7 @@ class BertLoRA(BertPreTrainedModel):
|
|
207 |
self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
|
208 |
else:
|
209 |
self.bert = bert
|
|
|
210 |
self._num_adaptions = config.num_loras
|
211 |
self._register_lora(self._num_adaptions)
|
212 |
self.main_params_trainable = False
|
@@ -230,6 +237,13 @@ class BertLoRA(BertPreTrainedModel):
|
|
230 |
config = JinaBertConfig.from_pretrained(*args, **kwargs)
|
231 |
return cls(config, bert=bert, num_adaptions=num_adaptions)
|
232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
@classmethod
|
234 |
def from_pretrained(
|
235 |
cls,
|
@@ -265,6 +279,8 @@ class BertLoRA(BertPreTrainedModel):
|
|
265 |
|
266 |
@current_task.setter
|
267 |
def current_task(self, task_idx: Union[None, int]):
|
|
|
|
|
268 |
assert task_idx is None or 0 <= task_idx < self._num_adaptions
|
269 |
if self._task_idx != task_idx:
|
270 |
self._task_idx = task_idx
|
|
|
199 |
if isinstance(layer, LoRAParametrization):
|
200 |
layer.current_task = task_idx
|
201 |
|
202 |
+
@classmethod
|
203 |
+
def merge_lora_into_layer(cls, layer: nn.Module):
|
204 |
+
if hasattr(layer, "parametrizations"):
|
205 |
+
for attr_name in layer.parametrizations.keys():
|
206 |
+
parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=True)
|
207 |
+
|
208 |
|
209 |
class BertLoRA(BertPreTrainedModel):
|
210 |
def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True):
|
|
|
213 |
self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
|
214 |
else:
|
215 |
self.bert = bert
|
216 |
+
self._is_merged = False
|
217 |
self._num_adaptions = config.num_loras
|
218 |
self._register_lora(self._num_adaptions)
|
219 |
self.main_params_trainable = False
|
|
|
237 |
config = JinaBertConfig.from_pretrained(*args, **kwargs)
|
238 |
return cls(config, bert=bert, num_adaptions=num_adaptions)
|
239 |
|
240 |
+
def merge_lora(self):
|
241 |
+
"""Merges currently selected LoRA into main weights."""
|
242 |
+
if self._is_merged:
|
243 |
+
raise Exception('LoRA has already been merged, cannot merge again')
|
244 |
+
self._is_merged = True
|
245 |
+
self.apply(LoRAParametrization.merge_lora_into_layer)
|
246 |
+
|
247 |
@classmethod
|
248 |
def from_pretrained(
|
249 |
cls,
|
|
|
279 |
|
280 |
@current_task.setter
|
281 |
def current_task(self, task_idx: Union[None, int]):
|
282 |
+
if self._is_merged:
|
283 |
+
raise Exception('LoRA has been merged, cannot select new task')
|
284 |
assert task_idx is None or 0 <= task_idx < self._num_adaptions
|
285 |
if self._task_idx != task_idx:
|
286 |
self._task_idx = task_idx
|