lorocksUMD committed
Commit
431dab7
1 Parent(s): dfc1ad9

Update script.py

Files changed (1)
  1. script.py +157 -0
script.py CHANGED
@@ -0,0 +1,157 @@
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+
+ from transformers import AutoTokenizer
+ from llava.model.language_model.llava_mistral import LlavaMistralForCausalLM
+ from llava.model.builder import load_pretrained_model
+ from llava.mm_utils import (
+     process_images,
+     tokenizer_image_token,
+     get_model_name_from_path,
+ )
+ from llava.constants import (
+     IMAGE_TOKEN_INDEX,
+     DEFAULT_IMAGE_TOKEN,
+     DEFAULT_IM_START_TOKEN,
+     DEFAULT_IM_END_TOKEN,
+     IMAGE_PLACEHOLDER,
+ )
+ from llava.conversation import conv_templates, SeparatorStyle
+
+ import torch
+ import requests
+ from PIL import Image
+ from io import BytesIO
+ import re
+
+
+
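+ # Note: `gradio` and the `InferenceClient` below are not used by the LLaVA
+ # inference code further down in this script.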
+ """
+ For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+ """
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+ # Functions for inference
+ def image_parser(args):
+     out = args.image_file.split(args.sep)
+     return out
+
+
+ def load_image(image_file):
+     if image_file.startswith("http") or image_file.startswith("https"):
+         response = requests.get(image_file)
+         image = Image.open(BytesIO(response.content)).convert("RGB")
+     else:
+         image = Image.open(image_file).convert("RGB")
+     return image
+
+
+ def load_images(image_files):
+     out = []
+     for image_file in image_files:
+         image = load_image(image_file)
+         out.append(image)
+     return out
+
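+ # Checkpoint to run. The commented-out block below keeps the alternative of
+ # loading the weights directly with LlavaMistralForCausalLM instead of
+ # going through LLaVA's load_pretrained_model builder.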
+ model_path = "liuhaotian/llava-v1.6-mistral-7b"
+ model_name = get_model_name_from_path(model_path)
+ # tokenizer = AutoTokenizer.from_pretrained(model_path)
+ # model = LlavaMistralForCausalLM.from_pretrained(
+ #     model_path,
+ #     low_cpu_mem_usage=True,
+ #     # offload_folder="/content/sample_data"
+ # )
+
+ prompt = "What are the things I should be cautious about when I visit here?"
+ image_file = "Great-Room-4.jpg"
+
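+ # Stand-in for the argparse namespace used by LLaVA's CLI: a throwaway class
+ # whose attributes mirror the usual command-line arguments.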
+ args = type('Args', (), {
+     "model_path": model_path,
+     "model_base": None,
+     "model_name": get_model_name_from_path(model_path),
+     "query": prompt,
+     "conv_mode": None,
+     "image_file": image_file,
+     "sep": ",",
+     "temperature": 0,
+     "top_p": None,
+     "num_beams": 1,
+     "max_new_tokens": 512
+ })()
+
+
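+ # Load the tokenizer, model and image processor; device_map="cpu" keeps
+ # everything on the CPU.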
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
+     model_path, None, model_name, device_map="cpu"
+ )
+
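+ # Insert the image token(s) into the query so the prompt marks where the
+ # image features are injected.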
+ qs = args.query
+ image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+ if IMAGE_PLACEHOLDER in qs:
+     if model.config.mm_use_im_start_end:
+         qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
+     else:
+         qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
+ else:
+     if model.config.mm_use_im_start_end:
+         qs = image_token_se + "\n" + qs
+     else:
+         qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
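+ # Choose a conversation template from the model name; "llava-v1.6-mistral-7b"
+ # hits the "mistral" branch and uses "mistral_instruct".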
+ if "llama-2" in model_name.lower():
+     conv_mode = "llava_llama_2"
+ elif "mistral" in model_name.lower():
+     conv_mode = "mistral_instruct"
+ elif "v1.6-34b" in model_name.lower():
+     conv_mode = "chatml_direct"
+ elif "v1" in model_name.lower():
+     conv_mode = "llava_v1"
+ elif "mpt" in model_name.lower():
+     conv_mode = "mpt"
+ else:
+     conv_mode = "llava_v0"
+
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
+     print(
+         "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
+             conv_mode, args.conv_mode, args.conv_mode
+         )
+     )
+ else:
+     args.conv_mode = conv_mode
+
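+ # Build the final prompt string from the selected conversation template.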
+ conv = conv_templates[args.conv_mode].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
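+ # Load and preprocess the image(s); image_sizes is passed to generate() for
+ # LLaVA-1.6's multi-resolution handling. The float16 cast mirrors the original
+ # GPU script and may need to match the model's dtype when running on CPU
+ # (assumption).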
+ image_files = image_parser(args)
+ images = load_images(image_files)
+ image_sizes = [x.size for x in images]
+ images_tensor = process_images(
+     images,
+     image_processor,
+     model.config
+ ).to(model.device, dtype=torch.float16)
+
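+ # Tokenize the prompt, mapping the image placeholder to IMAGE_TOKEN_INDEX and
+ # adding a batch dimension; the tensor stays on the CPU (.cuda() is disabled).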
+ input_ids = (
+     tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+     .unsqueeze(0)
+     # .cuda()
+ )
+
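+ # Greedy decoding: temperature is 0, so do_sample resolves to False.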
+ with torch.inference_mode():
+     output_ids = model.generate(
+         input_ids,
+         images=images_tensor,
+         image_sizes=image_sizes,
+         do_sample=True if args.temperature > 0 else False,
+         temperature=args.temperature,
+         top_p=args.top_p,
+         num_beams=args.num_beams,
+         max_new_tokens=args.max_new_tokens,
+         use_cache=True,
+     )
+
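+ # Decode the generated ids and print the model's answer.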
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+ print(outputs)
+ # End Llava inference