from copy import deepcopy
from typing import Any, Dict, List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .GRACE import GRACE
from .grace_hparams import GraceHyperParams
from .utils import tokenize
from ...util import nethook


def apply_grace_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: GraceHyperParams,
    num_steps: int,
    edit_lr: float,
    copy: bool = False,
    return_orig_weights: bool = False,
    keep_original_weight: bool = False,
    **kwargs: Any,
) -> GRACE:
    """Run a GRACE edit on ``model`` and return the editor that wraps it.

    Args:
        model: Causal LM to edit. It is edited via the ``GRACE`` editor
            built around it. If ``copy`` is false, this is the caller's
            own model object.
        tok: Tokenizer handed to ``tokenize`` to build the edit inputs.
        requests: Edit requests, forwarded unchanged to ``tokenize``.
        hparams: GRACE hyperparameters. NOTE: this object is mutated in
            place. Its ``edit_lr`` attribute is overwritten with
            ``edit_lr`` before the editor is built.
        num_steps: Not used by this function.
            NOTE(review): presumably meant to control the number of
            edit iterations; confirm whether it should be written into
            ``hparams`` before editing.
        edit_lr: Learning rate written into ``hparams.edit_lr``.
        copy: If true, a deep copy of ``model`` is edited and the
            caller's model object is left untouched.
        return_orig_weights: Not used by this function.
            Presumably kept for signature parity with other editing
            entry points; confirm.
        keep_original_weight: Not used by this function. Same caveat
            as ``return_orig_weights``.
        **kwargs: Accepted and ignored.

    Returns:
        The ``GRACE`` editor after ``edit`` has run. The previous return
        annotation advertised a ``(model, dict)`` tuple, but this function
        has always returned the editor object itself.
    """
    if copy:
        model = deepcopy(model)

    # The editor and the tokenized inputs are both pinned to CPU here.
    # This function does not itself move ``model``.
    device = torch.device("cpu")

    # Side effect on the caller's hparams. GRACE receives this same
    # object both at construction and in ``edit``.
    hparams.edit_lr = edit_lr

    editor = GRACE(model=model, config=hparams, device=device)
    tokens = tokenize(requests, tokenizer=tok, device=device)
    editor.edit(config=hparams, tokens=tokens)

    # Surface a completion notice in the Gradio UI.
    gr.Info("Completed editing via GRACE!")
    return editor