John6666 committed
Commit f31fb52 • 1 Parent(s): 3733821

Upload 2 files

Files changed (2)
  1. README.md +1 -1
  2. app.py +17 -15
README.md CHANGED
@@ -5,7 +5,7 @@ language:
 ---
 # Image Captioning App
 
-This is a mod of [Wi-zz/joy-caption-pre-alpha](https://huggingface.co/Wi-zz/joy-caption-pre-alpha) and [fancyfeast/joy-caption-alpha-two](https://huggingface.co/spaces/fancyfeast/joy-caption-alpha-two). Thanks to [dominic1021](https://huggingface.co/dominic1021).
+This is a mod of [Wi-zz/joy-caption-pre-alpha](https://huggingface.co/Wi-zz/joy-caption-pre-alpha) and [fancyfeast/joy-caption-alpha-two](https://huggingface.co/spaces/fancyfeast/joy-caption-alpha-two). Thanks to [dominic1021](https://huggingface.co/dominic1021), [IceHibiki](https://huggingface.co/IceHibiki).
 
 # Notice: I will contribute to Wi-zz after shaping the code.
 
app.py CHANGED
@@ -12,7 +12,7 @@ from torch import nn
 from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
 from typing import List, Union
 import torchvision.transforms.functional as TVF
-from peft import PeftConfig
+from peft import PeftModel
 import gc
 
 # Constants
@@ -127,6 +127,7 @@ class ImageAdapter(nn.Module):
 
 # Global Variables
 IS_NF4 = True
+IS_LORA = True
 MODEL_PATH = DEFAULT_MODEL_PATH
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Running on {device}")
@@ -196,7 +197,7 @@ class ImageAdapter(nn.Module):
         return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
 
 def load_models():
-    global MODEL_PATH, IS_NF4
+    global MODEL_PATH, IS_NF4, IS_LORA
     try:
         if IS_NF4:
             from transformers import BitsAndBytesConfig
@@ -222,11 +223,11 @@ def load_models():
             print(f"Loading LLM: {MODEL_PATH} 🤖")
             text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
 
-            if LORA_PATH.exists():
+            if False and IS_LORA and LORA_PATH.exists(): # omitted
                 print("Loading VLM's custom text model 🤖")
-                peft_config = PeftConfig.from_pretrained(LORA_PATH, device_map=device, quantization_config=nf4_config)
-                text_model.add_adapter(peft_config)
-                text_model.enable_adapters()
+                text_model = PeftModel.from_pretrained(model=text_model, model_id=LORA_PATH, device_map=device, quantization_config=nf4_config)
+                text_model = text_model.merge_and_unload(safe_merge=True) # to avoid PEFT bug https://github.com/huggingface/transformers/issues/28515
+            else: print("VLM's custom text model isn't loaded 🤖")
 
             print("Loading image adapter 🖼️")
             image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False).eval().to("cpu")
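(The `nf4_config` passed above is built earlier in `load_models()`, outside this diff. For reference, a typical NF4 setup with `BitsAndBytesConfig` looks like the sketch below; the exact arguments in app.py may differ.)

```python
# Hypothetical sketch of an NF4 quantization config; app.py's actual
# nf4_config may use different settings.
import torch
from transformers import BitsAndBytesConfig

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4 bits
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,         # also quantize the quant constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # run compute in bfloat16
)
```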
@@ -250,13 +251,13 @@ def load_models():
             assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"
 
             print(f"Loading LLM: {MODEL_PATH} 🤖")
-            text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16).eval()
+            text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16).eval() # device_map="auto" may cause LoRA issue
 
-            if LORA_PATH.exists():
+            if IS_LORA and LORA_PATH.exists():
                 print("Loading VLM's custom text model 🤖")
-                peft_config = PeftConfig.from_pretrained(LORA_PATH, device_map=device)
-                text_model.add_adapter(peft_config)
-                text_model.enable_adapters()
+                text_model = PeftModel.from_pretrained(model=text_model, model_id=LORA_PATH, device_map=device)
+                text_model = text_model.merge_and_unload(safe_merge=True) # to avoid PEFT bug https://github.com/huggingface/transformers/issues/28515
+            else: print("VLM's custom text model isn't loaded 🤖")
 
             print("Loading image adapter 🖼️")
             image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False).eval().to("cpu")
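Both branches now attach the LoRA by wrapping the base model in a `PeftModel` and merging it immediately, replacing the previous `PeftConfig`/`add_adapter`/`enable_adapters` sequence. A minimal sketch of that pattern, with placeholder repo ids standing in for the app's real `MODEL_PATH` and `LORA_PATH`:

```python
# Sketch of the load-then-merge pattern adopted above; "base-repo" and
# "lora-repo" are placeholders, not the app's actual paths.
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("base-repo", torch_dtype=torch.bfloat16)
lora = PeftModel.from_pretrained(base, "lora-repo")
# Folding the adapter into the base weights sidesteps the generation bug
# tracked at https://github.com/huggingface/transformers/issues/28515.
text_model = lora.merge_and_unload(safe_merge=True)
```

Note that in the NF4 branch the whole block is gated off with `if False and ...`, so only the bfloat16 path currently loads the LoRA.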
@@ -461,7 +462,8 @@ def parse_arguments():
     parser.add_argument("--prompt", type=str, default="", help='Custom Prompt (optional, will override all other settings)')
     parser.add_argument("--model", type=str, default=DEFAULT_MODEL_PATH,
                         help='Huggingface LLM repo (default: "unsloth/Meta-Llama-3.1-8B-bnb-4bit")')
-    parser.add_argument("--bf16", action="store_true", help="Use bfloat16 (default: NF4)")
+    parser.add_argument("--bf16", action="store_true", default=False, help="Use bfloat16 (default: NF4)")
+    parser.add_argument("--nolora", action="store_true", default=False, help="Disable VLM's custom text model (default: Enable)")
     parser.add_argument("--tokens", type=int, default=300, help="Max tokens (default: 300)")
     parser.add_argument("--topp", type=float, default=0.9, help="Top-P (default: 0.9)")
     parser.add_argument("--temp", type=float, default=0.6, help="Temperature (default: 0.6)")
@@ -480,7 +482,7 @@ def is_valid_repo(repo_id):
     return False
 
 def main():
-    global MODEL_PATH, IS_NF4
+    global MODEL_PATH, IS_NF4, IS_LORA
     args = parse_arguments()
     input_paths = [Path(input_path) for input_path in args.input]
     batch_size = args.bs
@@ -492,8 +494,8 @@ def main():
     max_new_tokens = args.tokens
     top_p = args.topp
     temperature = args.temp
-    if args.bf16: IS_NF4 = False
-    else: IS_NF4 = True
+    IS_NF4 = False if args.bf16 else True
+    IS_LORA = False if args.nolora else True
     if is_valid_repo(args.model): MODEL_PATH = args.model
     else: sys.exit(1)
     models = load_models()
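The two `store_true` flags map straight onto the module-level toggles; a condensed sketch of the wiring (the ternary `False if args.bf16 else True` is just `not args.bf16`):

```python
# Condensed sketch of the new flag wiring; equivalent to the diff's
# ternaries, shown with argparse defaults for illustration.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--bf16", action="store_true", help="Use bfloat16 (default: NF4)")
parser.add_argument("--nolora", action="store_true", help="Disable VLM's custom text model")
args = parser.parse_args([])   # parse an empty argv: both flags off

IS_NF4 = not args.bf16         # True unless --bf16 was passed
IS_LORA = not args.nolora      # True unless --nolora was passed
print(IS_NF4, IS_LORA)         # -> True True
```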
 