Upload 2 files
README.md
CHANGED
@@ -5,7 +5,7 @@ language:
 ---
 # Image Captioning App
 
-This is a mod of [Wi-zz/joy-caption-pre-alpha](https://huggingface.co/Wi-zz/joy-caption-pre-alpha) and [fancyfeast/joy-caption-alpha-two](https://huggingface.co/spaces/fancyfeast/joy-caption-alpha-two). Thanks to [dominic1021](https://huggingface.co/dominic1021).
+This is a mod of [Wi-zz/joy-caption-pre-alpha](https://huggingface.co/Wi-zz/joy-caption-pre-alpha) and [fancyfeast/joy-caption-alpha-two](https://huggingface.co/spaces/fancyfeast/joy-caption-alpha-two). Thanks to [dominic1021](https://huggingface.co/dominic1021), [IceHibiki](https://huggingface.co/IceHibiki).
 
 # Notice: I will contribute to Wi-zz after shaping the code.
 
app.py
CHANGED
@@ -12,7 +12,7 @@ from torch import nn
 from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
 from typing import List, Union
 import torchvision.transforms.functional as TVF
-from peft import
+from peft import PeftModel
 import gc
 
 # Constants
@@ -127,6 +127,7 @@ class ImageAdapter(nn.Module):
 
 # Global Variables
 IS_NF4 = True
+IS_LORA = True
 MODEL_PATH = DEFAULT_MODEL_PATH
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Running on {device}")
@@ -196,7 +197,7 @@ class ImageAdapter(nn.Module):
         return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
 
 def load_models():
-    global MODEL_PATH, IS_NF4
+    global MODEL_PATH, IS_NF4, IS_LORA
     try:
         if IS_NF4:
             from transformers import BitsAndBytesConfig
@@ -222,11 +223,11 @@ def load_models():
             print(f"Loading LLM: {MODEL_PATH} 🤖")
             text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
 
-            if LORA_PATH.exists():
+            if False and IS_LORA and LORA_PATH.exists(): # omitted
                 print("Loading VLM's custom text model 🤗")
-
-                text_model.
-
+                text_model = PeftModel.from_pretrained(model=text_model, model_id=LORA_PATH, device_map=device, quantization_config=nf4_config)
+                text_model = text_model.merge_and_unload(safe_merge=True) # to avoid PEFT bug https://github.com/huggingface/transformers/issues/28515
+            else: print("VLM's custom text model isn't loaded 🤗")
 
             print("Loading image adapter 🖼️")
             image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False).eval().to("cpu")
@@ -250,13 +251,13 @@ def load_models():
             assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"
 
             print(f"Loading LLM: {MODEL_PATH} 🤖")
-            text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16).eval()
+            text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16).eval() # device_map="auto" may cause LoRA issue
 
-            if LORA_PATH.exists():
+            if IS_LORA and LORA_PATH.exists():
                 print("Loading VLM's custom text model 🤗")
-
-                text_model.
-
+                text_model = PeftModel.from_pretrained(model=text_model, model_id=LORA_PATH, device_map=device)
+                text_model = text_model.merge_and_unload(safe_merge=True) # to avoid PEFT bug https://github.com/huggingface/transformers/issues/28515
+            else: print("VLM's custom text model isn't loaded 🤗")
 
             print("Loading image adapter 🖼️")
             image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False).eval().to("cpu")
@@ -461,7 +462,8 @@ def parse_arguments():
     parser.add_argument("--prompt", type=str, default="", help='Custom Prompt (optional, will override all other settings)')
     parser.add_argument("--model", type=str, default=DEFAULT_MODEL_PATH,
                         help='Huggingface LLM repo (default: "unsloth/Meta-Llama-3.1-8B-bnb-4bit")')
-    parser.add_argument("--bf16", action="store_true", help="Use bfloat16 (default: NF4)")
+    parser.add_argument("--bf16", action="store_true", default=False, help="Use bfloat16 (default: NF4)")
+    parser.add_argument("--nolora", action="store_true", default=False, help="Disable VLM's custom text model (default: Enable)")
     parser.add_argument("--tokens", type=int, default=300, help="Max tokens (default: 300)")
     parser.add_argument("--topp", type=float, default=0.9, help="Top-P (default: 0.9)")
     parser.add_argument("--temp", type=float, default=0.6, help="Temperature (default: 0.6)")
@@ -480,7 +482,7 @@ def is_valid_repo(repo_id):
         return False
 
 def main():
-    global MODEL_PATH, IS_NF4
+    global MODEL_PATH, IS_NF4, IS_LORA
    args = parse_arguments()
     input_paths = [Path(input_path) for input_path in args.input]
     batch_size = args.bs
@@ -492,8 +494,8 @@ def main():
     max_new_tokens = args.tokens
     top_p = args.topp
     temperature = args.temp
-    if args.bf16
-
+    IS_NF4 = False if args.bf16 else True
+    IS_LORA = False if args.nolora else True
     if is_valid_repo(args.model): MODEL_PATH = args.model
     else: sys.exit(1)
     models = load_models()
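The hunks above pass an `nf4_config` into `from_pretrained` and attach the LoRA with `PeftModel` before merging it, but neither `nf4_config` nor `LORA_PATH` is defined inside the changed lines. Below is a minimal sketch of both pieces using the public `transformers`/`peft` APIs, with placeholder model and adapter ids; it illustrates the pattern, not the repository's exact settings.

```python
# Illustrative sketch only: app.py's real nf4_config, MODEL_PATH and LORA_PATH
# live outside the changed lines and may use different values.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# (1) A typical NF4 quantization config of the kind the IS_NF4 branch passes in.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# (2) The bfloat16 branch's pattern: load the base LLM, attach the LoRA adapter,
#     then merge it so downstream code sees an ordinary model again.
text_model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-llm",            # placeholder for MODEL_PATH
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()
text_model = PeftModel.from_pretrained(text_model, "path/to/lora")  # placeholder for LORA_PATH
text_model = text_model.merge_and_unload(safe_merge=True)           # fold the adapter into the base weights
```

Note that the commit keeps this merge disabled in the NF4 branch (`if False and IS_LORA and ...` with an `# omitted` comment), so only the bfloat16 branch actually loads and merges the LoRA.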
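`main()` now derives both globals from the new flags and still gates `MODEL_PATH` behind `is_valid_repo()`, whose body is outside the changed lines. A minimal sketch of such a check with `huggingface_hub`, assuming "valid" simply means the repo exists on the Hub; this is not necessarily what app.py does:

```python
# Hedged sketch of a repo-existence check for gating --model values.
# app.py's actual is_valid_repo() is outside this diff and may differ.
from huggingface_hub import HfApi
from huggingface_hub.utils import HfHubHTTPError

def is_valid_repo(repo_id: str) -> bool:
    try:
        HfApi().model_info(repo_id)  # raises for missing or inaccessible repos
        return True
    except HfHubHTTPError:
        return False
```

With the defaults (`IS_NF4 = True`, `IS_LORA = True`) nothing changes for existing users; passing `--bf16` and/or `--nolora` flips the corresponding global before `load_models()` runs.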