File size: 1,070 Bytes
3494c6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6682b6
3494c6b
 
 
 
 
 
 
 
 
 
d6682b6
3494c6b
 
 
 
 
d6682b6
3494c6b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from typing import Any, Dict, List, Tuple
import torch
from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer
from .GRACE import GRACE
from .grace_hparams import GraceHyperParams
from .utils import tokenize
from ...util import nethook
import gradio as gr


def apply_grace_to_model(
        model: AutoModelForCausalLM,
        tok: AutoTokenizer,
        requests: List[Dict],
        hparams: GraceHyperParams,
        num_steps: int,
        edit_lr: float,
        copy=False,
        return_orig_weights=False,
        keep_original_weight=False,
        **kwargs: Any,
) -> GRACE:
    """Apply a GRACE edit to ``model`` and return the resulting editor.

    Editing is pinned to the CPU. ``hparams`` is mutated in place: its
    ``edit_lr`` attribute is overwritten with ``edit_lr`` before the editor
    is constructed.

    Args:
        model: Model to edit. The editor wraps it directly unless ``copy``
            is true.
        tok: Tokenizer handed to ``tokenize`` to encode ``requests``.
        requests: Edit request(s), forwarded unchanged to ``tokenize``.
            NOTE(review): presumably prompt/target dicts — confirm against
            ``tokenize``.
        hparams: GRACE hyper-parameters. Used both to build the editor and as
            the ``config`` passed to ``editor.edit``.
        num_steps: Currently unused. NOTE(review): presumably meant to set
            the number of edit iterations on ``hparams``. Confirm the field
            name in ``GraceHyperParams`` and wire it up.
        edit_lr: Learning rate, stored on ``hparams.edit_lr``.
        copy: If true, the editor wraps a ``deepcopy`` of ``model`` rather
            than the caller's model.
        return_orig_weights: Unused; kept for signature compatibility.
        keep_original_weight: Unused; kept for signature compatibility.
        **kwargs: Ignored.

    Returns:
        The ``GRACE`` editor after ``edit`` has run. This is the editor
        itself, not a ``(model, weights)`` tuple.
    """
    if copy:
        model = deepcopy(model)

    # Both the editor and the tokenized inputs are placed on the CPU.
    device = torch.device('cpu')
    hparams.edit_lr = edit_lr
    editor = GRACE(model=model, config=hparams, device=device)

    tokens = tokenize(requests, tokenizer=tok, device=device)
    editor.edit(config=hparams, tokens=tokens)

    # Notify the Gradio UI that the edit finished.
    gr.Info("Completed editing via GRACE!")
    return editor