zetavg committed on
Commit
49b832f
1 Parent(s): d047590

cache loaded_base_model_with_lora

Browse files
Files changed (2) hide show
  1. llama_lora/globals.py +2 -0
  2. llama_lora/models.py +11 -0
llama_lora/globals.py CHANGED
@@ -27,6 +27,8 @@ class Global:
27
 
28
  # Model related
29
  model_has_been_used = False
 
 
30
 
31
  # GPU Info
32
  gpu_cc = None # GPU compute capability
 
27
 
28
  # Model related
29
  model_has_been_used = False
30
+ loaded_base_model_with_lora = None
31
+ loaded_base_model_with_lora_name = None
32
 
33
  # GPU Info
34
  gpu_cc = None # GPU compute capability
llama_lora/models.py CHANGED
@@ -34,6 +34,9 @@ def get_base_model():
34
  def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
35
  Global.model_has_been_used = True
36
 
 
 
 
37
  if device == "cuda":
38
  model = PeftModel.from_pretrained(
39
  get_base_model(),
@@ -65,6 +68,9 @@ def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
65
  model.eval()
66
  if torch.__version__ >= "2" and sys.platform != "win32":
67
  model = torch.compile(model)
 
 
 
68
  return model
69
 
70
 
@@ -121,6 +127,11 @@ def unload_models():
121
  del Global.loaded_tokenizer
122
  Global.loaded_tokenizer = None
123
 
 
 
 
 
 
124
  clear_cache()
125
 
126
  Global.model_has_been_used = False
 
34
  def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
35
  Global.model_has_been_used = True
36
 
37
+ if Global.loaded_base_model_with_lora and Global.loaded_base_model_with_lora_name == lora_weights:
38
+ return Global.loaded_base_model_with_lora
39
+
40
  if device == "cuda":
41
  model = PeftModel.from_pretrained(
42
  get_base_model(),
 
68
  model.eval()
69
  if torch.__version__ >= "2" and sys.platform != "win32":
70
  model = torch.compile(model)
71
+
72
+ Global.loaded_base_model_with_lora = model
73
+ Global.loaded_base_model_with_lora_name = lora_weights
74
  return model
75
 
76
 
 
127
  del Global.loaded_tokenizer
128
  Global.loaded_tokenizer = None
129
 
130
+ del Global.loaded_base_model_with_lora
131
+ Global.loaded_base_model_with_lora = None
132
+
133
+ Global.loaded_base_model_with_lora_name = None
134
+
135
  clear_cache()
136
 
137
  Global.model_has_been_used = False