feat: added docstrings
modeling_lora.py CHANGED (+30 -2)
@@ -65,6 +65,8 @@ class LoRAParametrization(nn.Module):
         fan_in_fan_out = layer_type == "embedding"
         self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
 
+        # For the officially "correct" LoRA initialization, check here: https://github.com/microsoft/LoRA
+        # TODO: Ensure that the initialization here is correct
         if layer_type == "linear":
             self.lora_A = nn.Parameter(
                 initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
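Note on the TODO above: in the reference implementation linked in the comment, the `A` matrix of a linear layer is drawn from a Kaiming-uniform distribution while `B` starts at zero, so the low-rank update `B @ A` contributes nothing before training. A minimal sketch of that convention (the helper name is illustrative, not this repository's `initialized_weights`):

```python
import math
import torch
import torch.nn as nn

# Standard LoRA initialization for a linear layer, as in microsoft/LoRA:
# A ~ Kaiming-uniform, B = 0, so the adapted layer initially behaves
# exactly like the frozen base layer. (Hypothetical helper, for illustration.)
def init_lora_pair(fan_in: int, fan_out: int, rank: int):
    lora_A = nn.Parameter(torch.empty(rank, fan_in))
    lora_B = nn.Parameter(torch.zeros(fan_out, rank))
    nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))
    return lora_A, lora_B
```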
@@ -225,7 +227,15 @@ class BertLoRA(BertPreTrainedModel):
         return self._main_params_trainable
 
     @main_params_trainable.setter
-    def main_params_trainable(self, val):
+    def main_params_trainable(self, val: bool):
+        """Whether the main parameters (i.e. those that are not LoRA) should be trainable.
+
+        This method sets the `requires_grad_` attribute of the main weights
+        and controls which parameters are returned in `self.parameters()`.
+
+        :param val: Whether or not to make the parameters trainable.
+        :return: None
+        """
         self._main_params_trainable = val
         for name, param in super().named_parameters():
             if "lora" not in name:
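As the new docstring states, the setter toggles `requires_grad_` on every parameter whose name does not contain "lora", the usual way to freeze the base model so that only the adapter matrices receive gradients. An illustrative stand-alone version of that pattern (hypothetical helper, not the class's actual method):

```python
import torch.nn as nn

# Toggle requires_grad on all non-LoRA parameters of a module, so that an
# optimizer built from the trainable parameters only updates the adapter
# weights. Sketch of the effect of the main_params_trainable setter.
def set_main_params_trainable(module: nn.Module, val: bool) -> None:
    for name, param in module.named_parameters():
        if "lora" not in name:
            param.requires_grad_(val)
```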
@@ -259,7 +269,13 @@ class BertLoRA(BertPreTrainedModel):
         use_safetensors: bool = None,
         **kwargs,
     ):
-
+        """
+        TODO: choose between from_bert and super().from_pretrained
+
+        We want to be able to load both a pretrained BertModel, and a trained
+        BertLoRA via this method. To this end, we need to check which of these
+        models we are expected to load.
+        """
         return cls.from_bert(pretrained_model_name_or_path)
 
     def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
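The TODO in this docstring asks how to tell a plain `BertModel` checkpoint apart from an already-trained `BertLoRA` checkpoint. One way it could be resolved (a sketch under assumptions, not the repository's implementation) is to inspect the loaded state dict and branch on whether LoRA weights are present:

```python
# Hypothetical helper: report whether a state dict already contains LoRA
# weights. from_pretrained could then call cls.from_bert(...) for a plain
# BertModel checkpoint and fall back to super().from_pretrained(...) when
# LoRA parameters are found.
def checkpoint_has_lora(state_dict: dict) -> bool:
    return any("lora" in key for key in state_dict)
```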
@@ -275,14 +291,26 @@ class BertLoRA(BertPreTrainedModel):
 
     @property
     def current_task(self):
+        """Which LoRA is currently selected
+        :return: Integer or None (when LoRA is disabled)
+        """
         return self._task_idx
 
     @current_task.setter
     def current_task(self, task_idx: Union[None, int]):
+        """Set the LoRA that is to be used.
+
+        The LoRA is specified by `task_idx`, which may be an integer >= 0,
+        indexing the available LoRAs. If it is None, no LoRA is used.
+
+        :param task_idx: Which LoRA to use
+        :return:
+        """
         if self._is_merged:
             raise Exception('LoRA has been merged, cannot select new task')
         assert task_idx is None or 0 <= task_idx < self._num_adaptions
         if self._task_idx != task_idx:
+            # In this case, we need to update the LoRAs everywhere
             self._task_idx = task_idx
             self.apply(
                 partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
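The setter propagates the selected task to every parametrized layer through `nn.Module.apply`, with `functools.partial` binding the chosen task index. A self-contained illustration of that propagation pattern (the layer class below is a stand-in, not the real `LoRAParametrization`):

```python
from functools import partial
import torch.nn as nn

# Dummy layer standing in for a LoRA-parametrized module; apply() visits
# every submodule, and the bound task_idx tells each one which adapter to
# activate (or None to disable LoRA).
class DummyLoRALayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.current_task = None

    @staticmethod
    def select_task_for_layer(module, task_idx=None):
        if isinstance(module, DummyLoRALayer):
            module.current_task = task_idx

model = nn.Sequential(DummyLoRALayer(), nn.Linear(4, 4), DummyLoRALayer())
model.apply(partial(DummyLoRALayer.select_task_for_layer, task_idx=1))
# every DummyLoRALayer in `model` now has current_task == 1
```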