Veda0718 committed
Commit 589dc09 · verified · 1 Parent(s): 3e77724

Delete LLaVA-Med

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. LLaVA-Med/.gitignore +0 -3
  2. LLaVA-Med/CODE_OF_CONDUCT.md +0 -9
  3. LLaVA-Med/LICENSE +0 -62
  4. LLaVA-Med/README.md +0 -260
  5. LLaVA-Med/SECURITY.md +0 -41
  6. LLaVA-Med/SUPPORT.md +0 -25
  7. LLaVA-Med/bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl +0 -3
  8. LLaVA-Med/data/eval/llava_med_eval_qa50_qa.jsonl +0 -0
  9. LLaVA-Med/docs/llava_med_performance.md +0 -31
  10. LLaVA-Med/download_data.sh +0 -35
  11. LLaVA-Med/images/llava_logo.png +0 -0
  12. LLaVA-Med/images/llava_med_chat.png +0 -0
  13. LLaVA-Med/images/llava_med_chat_example1.png +0 -0
  14. LLaVA-Med/images/llava_med_chat_example2.png +0 -0
  15. LLaVA-Med/images/llava_med_dataset.png +0 -0
  16. LLaVA-Med/images/llava_med_logo.png +0 -0
  17. LLaVA-Med/images/llava_med_pipeline.png +0 -0
  18. LLaVA-Med/images/llava_med_vqa.png +0 -0
  19. LLaVA-Med/llava/__init__.py +0 -0
  20. LLaVA-Med/llava/constants.py +0 -13
  21. LLaVA-Med/llava/conversation.py +0 -439
  22. LLaVA-Med/llava/eval/eval_multimodal_chat_gpt_score.py +0 -112
  23. LLaVA-Med/llava/eval/llm.py +0 -134
  24. LLaVA-Med/llava/eval/model_vqa.py +0 -109
  25. LLaVA-Med/llava/eval/run_llava.py +0 -145
  26. LLaVA-Med/llava/eval/summarize_gpt_review.py +0 -47
  27. LLaVA-Med/llava/eval/util.py +0 -9
  28. LLaVA-Med/llava/mm_utils.py +0 -110
  29. LLaVA-Med/llava/model/__init__.py +0 -1
  30. LLaVA-Med/llava/model/builder.py +0 -83
  31. LLaVA-Med/llava/model/builders.py +0 -152
  32. LLaVA-Med/llava/model/language_model/llava_mistral.py +0 -143
  33. LLaVA-Med/llava/model/llava_arch.py +0 -309
  34. LLaVA-Med/llava/model/multimodal_encoder/builder.py +0 -9
  35. LLaVA-Med/llava/model/multimodal_encoder/clip_encoder.py +0 -78
  36. LLaVA-Med/llava/model/multimodal_projector/builder.py +0 -51
  37. LLaVA-Med/llava/serve/__init__.py +0 -0
  38. LLaVA-Med/llava/serve/cli.py +0 -125
  39. LLaVA-Med/llava/serve/controller.py +0 -298
  40. LLaVA-Med/llava/serve/examples/bio_patch.png +0 -0
  41. LLaVA-Med/llava/serve/examples/extreme_ironing.jpg +0 -0
  42. LLaVA-Med/llava/serve/examples/med_img_1.png +0 -0
  43. LLaVA-Med/llava/serve/examples/synpic32933.jpg +0 -0
  44. LLaVA-Med/llava/serve/examples/synpic42202.jpg +0 -0
  45. LLaVA-Med/llava/serve/examples/waterview.jpg +0 -0
  46. LLaVA-Med/llava/serve/examples/xy_chromosome.jpg +0 -0
  47. LLaVA-Med/llava/serve/gradio_web_server.py +0 -477
  48. LLaVA-Med/llava/serve/model_worker.py +0 -285
  49. LLaVA-Med/llava/serve/register_worker.py +0 -26
  50. LLaVA-Med/llava/serve/test_message.py +0 -62
LLaVA-Med/.gitignore DELETED
@@ -1,3 +0,0 @@
1
- __pycache__
2
- *.pyc
3
- *.egg-info
LLaVA-Med/CODE_OF_CONDUCT.md DELETED
@@ -1,9 +0,0 @@
1
- # Microsoft Open Source Code of Conduct
2
-
3
- This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4
-
5
- Resources:
6
-
7
- - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8
- - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9
- - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
LLaVA-Med/LICENSE DELETED
@@ -1,62 +0,0 @@
1
- MICROSOFT RESEARCH LICENSE TERMS
2
-
3
- IF YOU LIVE IN THE UNITED STATES, PLEASE READ THE “BINDING ARBITRATION AND CLASS ACTION WAIVER” SECTION BELOW. IT AFFECTS HOW DISPUTES ARE RESOLVED.
4
-
5
- These license terms are an agreement between you and Microsoft Corporation (or one of its affiliates). They apply to the source code, object code, machine learning models, or data (collectively “Materials”) that accompany this license. IF YOU COMPLY WITH THESE LICENSE TERMS, YOU HAVE THE RIGHTS BELOW. BY USING THE MATERIALS, YOU ACCEPT THESE TERMS.
6
-
7
- 1) INSTALLATION AND USE RIGHTS TO THE MATERIALS.
8
-
9
- Subject to the terms of this agreement, you have the below rights, if applicable, to use the Materials solely for non-commercial, non-revenue generating, research purposes:
10
-
11
- a) Source Code. If source code is included, you may use and modify the source code, but you may not distribute the source code.
12
- b) Object Code. If object code is included, you may use the object code, but you may not distribute the object code.
13
- c) Models. If machine learning model(s) are included, you may use the model(s), but you may not distribute the models.
14
- d) Data. If data is included, you may use and modify the data, but your use and modification must be consistent with the consent under which the data was provided and/or gathered and you may not distribute the data or your modifications to the data.
15
-
16
- 2) SCOPE OF LICENSE. The Materials are licensed, not sold. Microsoft reserves all other rights. Unless applicable law gives you more rights despite this limitation, you will not (and have no right to):
17
-
18
- a) work around any technical limitations in the Materials that only allow you to use it in certain ways;
19
- b) reverse engineer, decompile or disassemble the Materials;
20
- c) remove, minimize, block, or modify any notices of Microsoft or its suppliers in the Materials;
21
- d) use the Materials in any way that is against the law or to create or propagate malware; or
22
- e) share, publish, distribute or lend the Materials, provide the Materials as a stand-alone hosted solution for others to use, or transfer the Materials or this agreement to any third party.
23
-
24
- 3) PERSONAL DATA. If the data (set forth in Section 1(c) above) includes or is found to include any data that enables any ability to identify an individual (“Personal Data”), you will not use such Personal Data for any purpose other than was authorized and consented to by the data subject/research participant. You will not use Personal Data to contact any person. You will keep Personal Data in strict confidence. You will not share any Personal Data that is collected or in your possession with any third party for any reason and as required under the original consent agreement. Further, you will destroy the Personal Data and any backup or copies, immediately upon the completion of your research.
25
-
26
- 4) LICENSE TO MICROSOFT. Notwithstanding the limitations in Section 1, you may distribute your modifications back to Microsoft, and if you do provide Microsoft with modifications of the Materials, you hereby grant Microsoft, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, post, distribute, make and have made, sell and transfer such modifications and derivatives for any purpose.
27
-
28
- 5) PUBLICATION. You may publish (or present papers or articles) on your results from using the Materials provided that no material or substantial portion of the Materials is included in any such publication or presentation.
29
-
30
- 6) FEEDBACK. Any feedback about the Materials provided by you to us is voluntarily given, and Microsoft shall be free to use the feedback as it sees fit without obligation or restriction of any kind, even if the feedback is designated by you as confidential. Such feedback shall be considered a contribution and licensed to Microsoft under the terms of Section 4 above.
31
-
32
- 7) COMPLIANCE WITH TRADE LAWS. You acknowledge that the Materials may be subject to applicable trade laws in one or more countries. You will comply with all relevant laws and regulations applicable to the import or export of the Materials, including but not limited to, trade laws such as the U.S. Export Administration Regulations or other end-user, end use, and destination restrictions by the U.S. and other governments, as well as sanctions regulations administered by the U.S. Office of Foreign Assets Control. Microsoft may suspend or terminate the agreement immediately to the extent that Microsoft reasonably concludes that continued performance would violate trade laws or put it at risk of becoming subject to sanctions or penalties under trade laws. For additional information, see www.microsoft.com/exporting.
33
-
34
- 8) SUPPORT SERVICES. Microsoft is not obligated under this agreement to provide any support services for the Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
35
-
36
- 9) BINDING ARBITRATION AND CLASS ACTION WAIVER. This Section applies if you live in (or, if a business, your principal place of business is in) the United States. If you and Microsoft have a dispute, you and Microsoft agree to try for 60 days to resolve it informally. If you and Microsoft can’t, you and Microsoft agree to binding individual arbitration before the American Arbitration Association under the Federal Arbitration Act (“FAA”), and not to sue in court in front of a judge or jury. Instead, a neutral arbitrator will decide. Class action lawsuits, class-wide arbitrations, private attorney-general actions, and any other proceeding where someone acts in a representative capacity are not allowed; nor is combining individual proceedings without the consent of all parties. The complete Arbitration Agreement contains more terms and is at aka.ms/arb-agreement-1. You and Microsoft agree to these terms.
37
-
38
- 10) ENTIRE AGREEMENT. This agreement, and any other terms Microsoft may provide for supplements, updates, or third-party applications, is the entire agreement for the Materials.
39
-
40
- 11) APPLICABLE LAW AND PLACE TO RESOLVE DISPUTES. If you acquired the Materials in the United States or Canada, the laws of the state or province where you live (or, if a business, where your principal place of business is located) govern the interpretation of this agreement, claims for its breach, and all other claims (including consumer protection, unfair competition, and tort claims), regardless of conflict of laws principles, except that the FAA governs everything related to arbitration. If you acquired the Materials in any other country, its laws apply, except that the FAA governs everything related to arbitration. If U.S. federal jurisdiction exists, you and Microsoft consent to exclusive jurisdiction and venue in the federal court in King County, Washington for all disputes heard in court (excluding arbitration). If not, you and Microsoft consent to exclusive jurisdiction and venue in the Superior Court of King County, Washington for all disputes heard in court (excluding arbitration).
41
-
42
- 12) CONSUMER RIGHTS; REGIONAL VARIATIONS. This agreement describes certain legal rights. You may have other rights, including consumer rights, under the laws of your state, province, or country. Separate and apart from your relationship with Microsoft, you may also have rights with respect to the party from which you acquired the Materials. This agreement does not change those other rights if the laws of your state, province, or country do not permit it to do so. For example, if you acquired the Materials in one of the below regions, or mandatory country law applies, then the following provisions apply to you:
43
-
44
- a) Australia. You have statutory guarantees under the Australian Consumer Law and nothing in this agreement is intended to affect those rights.
45
-
46
- b) Canada. If you acquired this software in Canada, you may stop receiving updates by turning off the automatic update feature, disconnecting your device from the Internet (if and when you re-connect to the Internet, however, the Materials will resume checking for and installing updates), or uninstalling the Materials. The product documentation, if any, may also specify how to turn off updates for your specific device or software.
47
-
48
- c) Germany and Austria.
49
-
50
- i. Warranty. The properly licensed software will perform substantially as described in any Microsoft materials that accompany the Materials. However, Microsoft gives no contractual guarantee in relation to the licensed software.
51
-
52
- ii. Limitation of Liability. In case of intentional conduct, gross negligence, claims based on the Product Liability Act, as well as, in case of death or personal or physical injury, Microsoft is liable according to the statutory law.
53
-
54
- Subject to the foregoing clause (ii), Microsoft will only be liable for slight negligence if Microsoft is in breach of such material contractual obligations, the fulfillment of which facilitate the due performance of this agreement, the breach of which would endanger the purpose of this agreement and the compliance with which a party may constantly trust in (so-called "cardinal obligations"). In other cases of slight negligence, Microsoft will not be liable for slight negligence.
55
-
56
- 13) DISCLAIMER OF WARRANTY. THE MATERIALS ARE LICENSED “AS IS.” YOU BEAR THE RISK OF USING THEM. MICROSOFT GIVES NO EXPRESS WARRANTIES, GUARANTEES, OR CONDITIONS. TO THE EXTENT PERMITTED UNDER APPLICABLE LAWS, MICROSOFT EXCLUDES ALL IMPLIED WARRANTIES, INCLUDING MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT.
57
-
58
- 14) LIMITATION ON AND EXCLUSION OF DAMAGES. IF YOU HAVE ANY BASIS FOR RECOVERING DAMAGES DESPITE THE PRECEDING DISCLAIMER OF WARRANTY, YOU CAN RECOVER FROM MICROSOFT AND ITS SUPPLIERS ONLY DIRECT DAMAGES UP TO U.S. $5.00. YOU CANNOT RECOVER ANY OTHER DAMAGES, INCLUDING CONSEQUENTIAL, LOST PROFITS, SPECIAL, INDIRECT OR INCIDENTAL DAMAGES.
59
-
60
- This limitation applies to (a) anything related to the Materials, services, content (including code) on third party Internet sites, or third party applications; and (b) claims for breach of contract, warranty, guarantee, or condition; strict liability, negligence, or other tort; or any other claim; in each case to the extent permitted by applicable law.
61
-
62
- It also applies even if Microsoft knew or should have known about the possibility of the damages. The above limitation or exclusion may not apply to you because your state, province, or country may not allow the exclusion or limitation of incidental, consequential, or other damages.
LLaVA-Med/README.md DELETED
@@ -1,260 +0,0 @@
1
- # LLaVA-Med: Large Language and Vision Assistant for Biomedicine
2
-
3
- *Visual instruction tuning towards building large language and vision models with GPT-4 level capabilities in the biomedicine space.*
4
-
5
- [[Paper, NeurIPS 2023 Datasets and Benchmarks Track (Spotlight)](https://arxiv.org/abs/2306.00890)]
6
-
7
- **LLaVA-Med: Training a Large Language-and-Vision Assistant for Biomedicine in One Day** <br>
8
-
9
- [Chunyuan Li*](https://chunyuan.li/), [Cliff Wong*](https://scholar.google.com/citations?user=Sl05ifcAAAAJ&hl=en), [Sheng Zhang*](https://scholar.google.com/citations?user=-LVEXQ8AAAAJ&hl=en), [Naoto Usuyama](https://www.microsoft.com/en-us/research/people/naotous/), [Haotian Liu](https://hliu.cc), [Jianwei Yang](https://jwyang.github.io/), [Tristan Naumann](https://scholar.google.com/citations?user=cjlSeqwAAAAJ&hl=en), [Hoifung Poon](https://scholar.google.com/citations?user=yqqmVbkAAAAJ&hl=en), [Jianfeng Gao](https://scholar.google.com/citations?user=CQ1cqKkAAAAJ&hl=en) (*Equal Contribution)
10
-
11
- <p align="center">
12
- <img src="images/llava_med_logo.png" width="50%"> <br>
13
-
14
- *Generated by <a href="https://gligen.github.io/">GLIGEN</a> using the grounded inpainting mode, with three boxes: ``white doctor coat``, ``stethoscope``, ``white doctor hat with a red cross sign``.*
15
-
16
- </p>
17
-
18
-
19
- ## Release
20
-
21
- - [May 13, 2024] 🔥LLaVA-Med v1.5 is out! It is not only significantly better (see the [evaluation results](docs/llava_med_performance.md#llava-med-15-performance).) but also much easier to use: no more *delta* weights! Now you can directly load our model from the [🤗 Hub](https://huggingface.co/microsoft/llava-med-v1.5-mistral-7b). The original LLaVA-Med (i.e., v1.0.0) codebase has been moved to [Archive](#archive).
22
- - [Nov 8, 2023] LLaVA-Med is open-sourced under the MSR release policy. Huge thanks to the commitment of the team and the patience of the community.
23
- - [Sept 2023] LLaVA-Med was accepted to the NeurIPS 2023 Datasets and Benchmarks Track as a spotlight presentation.
24
- - [June 1, 2023] 🔥 We released **LLaVA-Med: Large Language and Vision Assistant for Biomedicine**, a step towards building biomedical-domain large language and vision models with GPT-4 level capabilities. Check out the [paper](https://arxiv.org/abs/2306.00890).
25
-
26
- <p align="center">
27
- <img src="images/llava_med_pipeline.png" width="90%"> <br>
28
-
29
- *LLaVA-Med was initialized with the general-domain LLaVA and then continuously trained in a curriculum learning fashion (first biomedical concept alignment then full-blown instruction-tuning). We evaluated LLaVA-Med on standard visual conversation and question answering tasks.*
30
- </p>
31
-
32
- [![Code License](https://img.shields.io/badge/Code%20License-Microsoft%20Research-red)](Research%20License.docx)
33
- [![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg)](https://creativecommons.org/licenses/by-nc/4.0/deed.en)
34
- **Usage and License Notices**: The data, code, and model checkpoints are intended and licensed for research use only. They are also subject to additional restrictions dictated by the Terms of Use: LLaMA, Vicuna and GPT-4 respectively. The data is made available under CC BY NC 4.0. The data, code, and model checkpoints may be used for non-commercial purposes and any models trained using the dataset should be used only for research purposes. It is expressly prohibited for models trained on this data to be used in clinical care or for any clinical decision making purposes.
35
-
36
- ## Contents
37
-
38
- - [Install](#install)
39
- - [Model Download](#model-download)
40
- - [Serving](#serving)
41
- - [Evaluation](#evaluation)
42
- - [Data Download](#data-download)
43
- - [Archive](#archive)
44
- - [Model Description](#model-description)
45
-
46
- ## Install
47
-
48
- 1. Clone this repository and navigate to LLaVA-Med folder
49
- ```bash
50
- git clone https://github.com/microsoft/LLaVA-Med.git
51
- cd LLaVA-Med
52
- ```
53
-
54
- 2. Install Package: Create conda environment
55
-
56
- ```Shell
57
- conda create -n llava-med python=3.10 -y
58
- conda activate llava-med
59
- pip install --upgrade pip # enable PEP 660 support
60
- pip install -e .
61
- ```
62
-
63
- ## Model Download
64
-
65
-
66
- | Model Descriptions | 🤗 Huggingface Hub |
67
- | --- | ---: |
68
- | LLaVA-Med v1.5 | [microsoft/llava-med-v1.5-mistral-7b](https://huggingface.co/microsoft/llava-med-v1.5-mistral-7b) |
69
-
70
-
71
-
72
- ## Serving
73
-
74
- ### Web UI
75
-
76
- #### Launch a controller
77
- ```Shell
78
- python -m llava.serve.controller --host 0.0.0.0 --port 10000
79
- ```
80
-
81
- #### Launch a model worker
82
- ```Shell
83
- python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path microsoft/llava-med-v1.5-mistral-7b --multi-modal
84
- ```
85
- Wait until the process finishes loading the model and you see "Uvicorn running on ...".
86
-
87
- #### Launch a model worker (Multiple GPUs, when GPU VRAM <= 24GB)
88
-
89
- If your GPU has 24GB of VRAM or less (e.g., RTX 3090, RTX 4090), you can try running the model worker across multiple GPUs.
90
-
91
- ```Shell
92
- python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path microsoft/llava-med-v1.5-mistral-7b --multi-modal --num-gpus 2
93
- ```
94
- Wait until the process finishes loading the model and you see "Uvicorn running on ...".
95
-
96
-
97
- #### Send a test message
98
- ```Shell
99
- python -m llava.serve.test_message --model-name llava-med-v1.5-mistral-7b --controller http://localhost:10000
100
- ```
101
-
102
- #### Launch a gradio web server.
103
- ```Shell
104
- python -m llava.serve.gradio_web_server --controller http://localhost:10000
105
- ```
106
- #### You can open your browser and chat with a model now.
107
-
108
-
109
- ## Evaluation
110
-
111
- ### Medical Visual Chat (GPT-assisted Evaluation)
112
-
113
- Our GPT-assisted evaluation pipeline for multimodal modeling is provided for a comprehensive understanding of the capabilities of vision-language models. Please see our paper for more details.
114
-
115
- #### 1. Azure OpenAI Connection Info.
116
-
117
- Open [llava/eval/llm.py](llava/eval/llm.py?plain=1#L33) and insert your Azure OpenAI Endpoint and API KEY
118
- ```Shell
119
- openai_cxn_dict = {
120
- 'default': {
121
- 'endpoint': "INSERT YOUR AZURE OPENAI ENDPOINT HERE",
122
- 'api_key': "INSERT YOUR AZURE OPENAI API KEY HERE",
123
- },
124
- }
125
- ```
126
- * GPT-4 inference was tested only with the Azure OpenAI API. If you are using the public OpenAI API, replace AsyncAzureOpenAI with AsyncOpenAI in [llava/eval/llm.py (line 55)](llava/eval/llm.py?plain=1#L55) (see the sketch following this README diff).
127
-
128
- #### 2. Deployment ID
129
- In [llava/eval/eval_multimodal_chat_gpt_score.py (line 55)](llava/eval/eval_multimodal_chat_gpt_score.py?plain=1#L55), replace the deployment ID with your own GPT-4 model deployment ID if necessary.
130
-
131
- #### 3. Download Images
132
-
133
- ```Shell
134
- wget https://hanoverprod.z21.web.core.windows.net/med_llava/multimodal_chat_eval/llava_med_test_image_urls.jsonl -P data/
135
- python llava/data/download_images.py \
136
- --input_path data/llava_med_test_image_urls.jsonl \
137
- --pmc_output_path data/pmc \
138
- --images_output_path data/images
139
- ```
140
-
141
- #### 4. Multimodal Chat Inference
142
- In our case, [`llava_med_eval_qa50_qa.jsonl`](/data/eval/llava_med_eval_qa50_qa.jsonl) contains the questions, context (captions and inline-mentions) and responses generated by text-only GPT-4 (0314), which we treat as ground truth.
143
-
144
- ```Shell
145
- PYTHONPATH=. python llava/eval/model_vqa.py \
146
- --conv-mode mistral_instruct \
147
- --model-path microsoft/llava-med-v1.5-mistral-7b \
148
- --question-file data/eval/llava_med_eval_qa50_qa.jsonl \
149
- --image-folder data/images \
150
- --answers-file /path/to/answer-file.jsonl \
151
- --temperature 0.0
152
- ```
153
-
154
- #### 5. GPT-4 Evaluation of the Generated Answers
155
-
156
- ```Shell
157
- python llava/eval/eval_multimodal_chat_gpt_score.py \
158
- --answers-file /path/to/answer-file.jsonl \
159
- --question-file data/eval/llava_med_eval_qa50_qa.jsonl \
160
- --scores-file /path/to/scores-file.jsonl
161
- ```
162
-
163
- #### 6. Summarize the Evaluation Results
164
-
165
- ```Shell
166
- python llava/eval/summarize_gpt_review.py \
167
- --scores-file /path/to/scores-file.jsonl
168
- ```
169
-
170
- ## Data Download
171
-
172
- ### LLaVA-Med Dataset
173
-
174
- <p align="center">
175
- <img src="images/llava_med_dataset.png" width="90%"> <br>
176
-
177
- *Data statistics of the biomedical multimodal instruction-following data: (a, b) the root verb-noun pairs of instructions and responses, where the inner circle of the plot represents the root verb of the output response and the outer circle represents the direct nouns. (c) The distribution of images and QA pairs across the five domains; one image is shown per domain.*
178
- </p>
179
-
180
- ### Data Download
181
- | Alignment data files | Size |
182
- | --- | ---: |
183
- | [llava_med_alignment_500k.json](https://hanoverprod.z21.web.core.windows.net/med_llava/alignment/llava_med_alignment_500k.json) | 341.52 MiB |
184
-
185
- | Instruction-Tuning data files | Size |
186
- | --- | ---: |
187
- | [llava_med_instruct_10k.json](https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_10k.json) | 19.24 MiB |
188
- | [llava_med_instruct_60k.json](https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_60k.json) | 84.65 MiB |
189
- | [llava_med_instruct_60k_inline_mention.json](https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_60k_inline_mention.json) | 83.61 MiB |
190
- | [llava_med_instruct_fig_captions.json](https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_fig_captions.json) | 161.39 MiB |
191
-
192
- | Evaluation files | Size |
193
- | --- | ---: |
194
- | [llava_med_eval_qa50_qa.jsonl](https://hanoverprod.z21.web.core.windows.net/med_llava/eval/llava_med_eval_qa50_qa.jsonl) | 256.18 KiB |
195
- | [llava_med_eval_qa50_fig_captions.json](https://hanoverprod.z21.web.core.windows.net/med_llava/eval/llava_med_eval_qa50_fig_captions.json) | 51.82 KiB |
196
- | [llava_med_qa50_instruct_caption_in_text_cleaned-60k-3epoch.json](https://hanoverprod.z21.web.core.windows.net/med_llava/eval/llava_med_qa50_instruct_caption_in_text_cleaned-60k-3epoch.json) | 100.97 KiB |
197
-
198
- | Image URLs | Size |
199
- | --- | ---: |
200
- | [llava_med_image_urls.jsonl](https://hanoverprod.z21.web.core.windows.net/med_llava/llava_med_image_urls.jsonl) | 122.82 MiB |
201
-
202
- [download_images.py](https://github.com/microsoft/LLaVA-Med/blob/v1.0.0/llava/data/download_images.py) is used to download the PMC articles listed in the image URLs file above and to extract the images.
203
-
204
- To download our language-image multimodal instruction-following dataset, please run the following script:
205
- ```bash
206
- sh download_data.sh
207
- ```
208
-
209
-
210
- ## Archive
211
-
212
- - [LLaVA-Med v1.0](https://github.com/microsoft/LLaVA-Med/tree/v1.0.0)
213
-
214
- ## Model Description
215
-
216
- Large Language and Vision Assistant for bioMedicine (i.e., “LLaVA-Med”) is a large language and vision model trained with a curriculum learning method to adapt LLaVA to the biomedical domain. It is an open-source release intended for research use only, to facilitate reproducibility of the corresponding paper, which claims improved performance on open-ended biomedical question answering tasks, including common visual question answering (VQA) benchmark datasets such as PathVQA and VQA-RAD.
217
-
218
- ### Model Uses
219
-
220
- #### Intended Use
221
-
222
- The data, code, and model checkpoints are intended to be used solely for (I) future research on visual-language processing and (II) reproducibility of the experimental results reported in the reference paper. The data, code, and model checkpoints are not intended to be used in clinical care or for any clinical decision making purposes.
223
-
224
- #### Primary Intended Use
225
-
226
- The primary intended use is to support AI researchers reproducing and building on top of this work. LLaVA-Med and its associated models should be helpful for exploring various biomedical vision-language processing (VLP) and visual question answering (VQA) research questions.
227
-
228
- #### Out-of-Scope Use
229
-
230
- **Any** deployed use case of the model --- commercial or otherwise --- is out of scope. Although we evaluated the models using a broad set of publicly-available research benchmarks, the models and evaluations are intended *for research use only* and not intended for deployed use cases. Please refer to [the associated paper](https://aka.ms/llava-med) for more details.
231
-
232
- ### Data
233
-
234
- This model builds upon [PMC-15M dataset](https://aka.ms/biomedclip-paper), which is a large-scale parallel image-text dataset for biomedical vision-language processing. It contains 15 million figure-caption pairs extracted from biomedical research articles in PubMed Central. It covers a diverse range of biomedical image types, such as microscopy, radiography, histology, and more.
235
-
236
- ### Limitations
237
-
238
- This model was developed using English corpora, and thus may be considered English-only. This model is evaluated on a narrow set of biomedical benchmark tasks, described in [LLaVA-Med paper](https://aka.ms/llava-med). As such, it is not suitable for use in any clinical setting. Under some conditions, the model may make inaccurate predictions and display limitations, which may require additional mitigation strategies. In particular, this model is likely to carry many of the limitations of the model from which it is derived, [LLaVA](https://llava-vl.github.io/).
239
-
240
- Further, this model was developed in part using the [PMC-15M](https://aka.ms/biomedclip-paper) dataset. The figure-caption pairs that make up this dataset may contain biases reflecting the current practice of academic publication. For example, the corresponding papers may be enriched for positive findings, contain examples of extreme cases, and otherwise reflect distributions that are not representative of other sources of biomedical data.
241
-
242
- ## Acknowledgement
243
-
244
- If you find LLaVA-Med useful for your research and applications, please cite using this BibTeX:
245
-
246
- ```bibtex
247
- @article{li2023llavamed,
248
- title={Llava-med: Training a large language-and-vision assistant for biomedicine in one day},
249
- author={Li, Chunyuan and Wong, Cliff and Zhang, Sheng and Usuyama, Naoto and Liu, Haotian and Yang, Jianwei and Naumann, Tristan and Poon, Hoifung and Gao, Jianfeng},
250
- journal={arXiv preprint arXiv:2306.00890},
251
- year={2023}
252
- }
253
- ```
254
-
255
-
256
- ## Related Projects
257
-
258
- - [LLaVA](https://llava-vl.github.io/)
259
- - [BiomedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224)
260
- - [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
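The evaluation setup in the README above says to swap `AsyncAzureOpenAI` for `AsyncOpenAI` in `llava/eval/llm.py` when calling the public OpenAI API instead of Azure. Below is a minimal sketch of that swap, assuming the `openai` v1 Python client; the endpoint, key, and API-version strings are placeholders, and the exact constructor arguments used in `llm.py` may differ.

```python
from openai import AsyncAzureOpenAI, AsyncOpenAI

# Default path: an Azure OpenAI client built from the endpoint/api_key entries
# in openai_cxn_dict (all values below are placeholders).
azure_client = AsyncAzureOpenAI(
    azure_endpoint="https://YOUR-RESOURCE.openai.azure.com/",
    api_key="YOUR-AZURE-OPENAI-API-KEY",
    api_version="2024-02-01",  # placeholder; use the version your deployment supports
)

# Public OpenAI API instead: drop the Azure-specific arguments and use AsyncOpenAI.
openai_client = AsyncOpenAI(api_key="YOUR-OPENAI-API-KEY")

# Both clients expose the same async chat interface, e.g.
#   response = await client.chat.completions.create(model=..., messages=[...])
```

Note that Azure deployments are addressed by deployment ID where the public API expects a model name, so the model/deployment argument passed to `chat.completions.create` may also need adjusting.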
LLaVA-Med/SECURITY.md DELETED
@@ -1,41 +0,0 @@
1
- <!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->
2
-
3
- ## Security
4
-
5
- Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6
-
7
- If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8
-
9
- ## Reporting Security Issues
10
-
11
- **Please do not report security vulnerabilities through public GitHub issues.**
12
-
13
- Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14
-
15
- If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16
-
17
- You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18
-
19
- Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20
-
21
- * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22
- * Full paths of source file(s) related to the manifestation of the issue
23
- * The location of the affected source code (tag/branch/commit or direct URL)
24
- * Any special configuration required to reproduce the issue
25
- * Step-by-step instructions to reproduce the issue
26
- * Proof-of-concept or exploit code (if possible)
27
- * Impact of the issue, including how an attacker might exploit the issue
28
-
29
- This information will help us triage your report more quickly.
30
-
31
- If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32
-
33
- ## Preferred Languages
34
-
35
- We prefer all communications to be in English.
36
-
37
- ## Policy
38
-
39
- Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40
-
41
- <!-- END MICROSOFT SECURITY.MD BLOCK -->
LLaVA-Med/SUPPORT.md DELETED
@@ -1,25 +0,0 @@
1
- # TODO: The maintainer of this repo has not yet edited this file
2
-
3
- **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4
-
5
- - **No CSS support:** Fill out this template with information about how to file issues and get help.
6
- - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
7
- - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
8
-
9
- *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10
-
11
- # Support
12
-
13
- ## How to file issues and get help
14
-
15
- This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16
- issues before filing new issues to avoid duplicates. For new issues, file your bug or
17
- feature request as a new Issue.
18
-
19
- For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20
- FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21
- CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22
-
23
- ## Microsoft Support Policy
24
-
25
- Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
LLaVA-Med/bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0f0323de1ff1fdf8383e79bdad1283516a4c05a6fd2b44a363bf4e059422305b
3
- size 69084267
LLaVA-Med/data/eval/llava_med_eval_qa50_qa.jsonl DELETED
The diff for this file is too large to render. See raw diff
 
LLaVA-Med/docs/llava_med_performance.md DELETED
@@ -1,31 +0,0 @@
1
- ## LLaVA-Med-1.5 Performance
2
-
3
- <p align="center">
4
- <img src="https://hanoverprod.z21.web.core.windows.net/med_llava/web/llava-med_1.5_eval.png" width="90%"> <br>
5
-
6
- *Performance comparison of multimodal chat instruction-following abilities, measured by the relative score via language GPT-4 evaluation.*
7
- </p>
8
-
9
-
10
- ## LLaVA-Med-1.0 Performance
11
-
12
- <p align="center">
13
- <img src="../images/llava_med_chat_example1.png" width="90%"> <br>
14
-
15
- *Example 1: comparison of medical visual chat. The language-only GPT-4 is considered as the performance upper bound, as the golden captions and inline mentions are fed into GPT-4 as the context, without requiring the model to understand the raw image.*
16
- </p>
17
-
18
- <p align="center">
19
- <img src="../images/llava_med_chat_example2.png" width="90%"> <br>
20
-
21
- *Example 2: comparison of medical visual chat. LLaVA tends to hallucinate or refuse to provide domain-specific, knowledgeable responses.*
22
- </p>
23
-
24
-
25
- <p align="center">
26
- <img src="../images/llava_med_vqa.png" width="90%"> <br>
27
-
28
- *Performance comparison of fine-tuned LLaVA-Med on established medical VQA datasets.*
29
- </p>
30
-
31
-
LLaVA-Med/download_data.sh DELETED
@@ -1,35 +0,0 @@
1
- #!/bin/bash
2
-
3
- mkdir data/alignment
4
- cd data/alignment
5
-
6
- wget https://hanoverprod.z21.web.core.windows.net/med_llava/alignment/llava_med_alignment_500k.json
7
-
8
- cd ..
9
-
10
- mkdir instruct
11
- cd instruct
12
-
13
- wget https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_10k.json
14
- wget https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_60k.json
15
- wget https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_60k_inline_mention.json
16
- wget https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_fig_captions.json
17
- cd ..
18
-
19
- mkdir eval
20
- cd eval
21
-
22
- wget https://hanoverprod.z21.web.core.windows.net/med_llava/eval/llava_med_eval_qa50_qa.jsonl
23
- wget https://hanoverprod.z21.web.core.windows.net/med_llava/eval/llava_med_eval_qa50_fig_captions.json
24
- wget https://hanoverprod.z21.web.core.windows.net/med_llava/eval/llava_med_qa50_instruct_caption_in_text_cleaned-60k-3epoch.json
25
-
26
- cd ..
27
-
28
- wget https://hanoverprod.z21.web.core.windows.net/med_llava/llava_med_image_urls.jsonl
29
- mkdir pmc_articles
30
- mkdir images
31
-
32
- cd ..
33
-
34
- pip install tqdm
35
- python llava/data/download_images.py --input_path data/llava_med_image_urls.jsonl --pmc_output_path data/pmc_articles/ --images_output_path data/images
LLaVA-Med/images/llava_logo.png DELETED
Binary file (268 kB)
 
LLaVA-Med/images/llava_med_chat.png DELETED
Binary file (63.1 kB)
 
LLaVA-Med/images/llava_med_chat_example1.png DELETED
Binary file (315 kB)
 
LLaVA-Med/images/llava_med_chat_example2.png DELETED
Binary file (279 kB)
 
LLaVA-Med/images/llava_med_dataset.png DELETED
Binary file (542 kB)
 
LLaVA-Med/images/llava_med_logo.png DELETED
Binary file (379 kB)
 
LLaVA-Med/images/llava_med_pipeline.png DELETED
Binary file (349 kB)
 
LLaVA-Med/images/llava_med_vqa.png DELETED
Binary file (126 kB)
 
LLaVA-Med/llava/__init__.py DELETED
File without changes
LLaVA-Med/llava/constants.py DELETED
@@ -1,13 +0,0 @@
1
- CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
- WORKER_HEART_BEAT_INTERVAL = 15
3
-
4
- LOGDIR = "."
5
-
6
- # Model Constants
7
- IGNORE_INDEX = -100
8
- IMAGE_TOKEN_INDEX = -200
9
- DEFAULT_IMAGE_TOKEN = "<image>"
10
- DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
- DEFAULT_IM_START_TOKEN = "<im_start>"
12
- DEFAULT_IM_END_TOKEN = "<im_end>"
13
- IMAGE_PLACEHOLDER = "<image-placeholder>"
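These constants are sentinels consumed elsewhere in the package (e.g., by `llava/mm_utils.py` and the model code): `DEFAULT_IMAGE_TOKEN` is the textual placeholder in prompts, `IMAGE_TOKEN_INDEX` is the out-of-vocabulary id spliced into the token sequence where image features will be inserted, and `IGNORE_INDEX` masks positions out of the loss. A rough, self-contained sketch of that splicing pattern follows; `splice_image_token` and `toy_tokenize` are illustrative stand-ins, not the repo's API.

```python
from typing import Callable, List

IGNORE_INDEX = -100          # label value ignored by the cross-entropy loss
IMAGE_TOKEN_INDEX = -200     # sentinel id marking where image features go
DEFAULT_IMAGE_TOKEN = "<image>"

def splice_image_token(prompt: str, tokenize: Callable[[str], List[int]]) -> List[int]:
    """Tokenize the text around each <image> placeholder and join the chunks
    with the IMAGE_TOKEN_INDEX sentinel (illustrative only)."""
    chunks = [tokenize(part) for part in prompt.split(DEFAULT_IMAGE_TOKEN)]
    input_ids: List[int] = []
    for i, chunk in enumerate(chunks):
        if i > 0:
            input_ids.append(IMAGE_TOKEN_INDEX)  # replaced by vision features at run time
        input_ids.extend(chunk)
    return input_ids

# Toy example: map each whitespace-separated word to a fake id.
toy_tokenize = lambda text: [abs(hash(word)) % 1000 for word in text.split()]
print(splice_image_token("<image>\nWhat does this scan show?", toy_tokenize))
```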
LLaVA-Med/llava/conversation.py DELETED
@@ -1,439 +0,0 @@
1
- import dataclasses
2
- from enum import auto, Enum
3
- from typing import List, Tuple
4
- import base64
5
- from io import BytesIO
6
- from PIL import Image
7
-
8
-
9
- class SeparatorStyle(Enum):
10
- """Different separator style."""
11
- SINGLE = auto()
12
- TWO = auto()
13
- MPT = auto()
14
- PLAIN = auto()
15
- LLAMA_2 = auto()
16
- MISTRAL = auto()
17
-
18
-
19
- @dataclasses.dataclass
20
- class Conversation:
21
- """A class that keeps all conversation history."""
22
- system: str
23
- roles: List[str]
24
- messages: List[List[str]]
25
- offset: int
26
- sep_style: SeparatorStyle = SeparatorStyle.SINGLE
27
- sep: str = "###"
28
- sep2: str = None
29
- version: str = "Unknown"
30
-
31
- skip_next: bool = False
32
-
33
- def get_prompt(self):
34
- messages = self.messages
35
- if len(messages) > 0 and type(messages[0][1]) is tuple:
36
- messages = self.messages.copy()
37
- init_role, init_msg = messages[0].copy()
38
- init_msg = init_msg[0].replace("<image>", "").strip()
39
- if 'mmtag' in self.version:
40
- messages[0] = (init_role, init_msg)
41
- messages.insert(0, (self.roles[0], "<Image><image></Image>"))
42
- messages.insert(1, (self.roles[1], "Received."))
43
- else:
44
- messages[0] = (init_role, "<image>\n" + init_msg)
45
-
46
- if self.sep_style == SeparatorStyle.SINGLE:
47
- ret = self.system + self.sep
48
- for role, message in messages:
49
- if message:
50
- if type(message) is tuple:
51
- message, _, _ = message
52
- ret += role + ": " + message + self.sep
53
- else:
54
- ret += role + ":"
55
- elif self.sep_style == SeparatorStyle.TWO:
56
- seps = [self.sep, self.sep2]
57
- ret = self.system + seps[0]
58
- for i, (role, message) in enumerate(messages):
59
- if message:
60
- if type(message) is tuple:
61
- message, _, _ = message
62
- sep = seps[i % 2]
63
- sep = "{0} ".format(self.sep2) if sep == self.sep2 else self.sep
64
- ret += role + ": " + message.strip() + sep
65
- else:
66
- ret += role + ":"
67
- ret = ret.strip()
68
- elif self.sep_style == SeparatorStyle.MPT:
69
- ret = self.system + self.sep
70
- for role, message in messages:
71
- if message:
72
- if type(message) is tuple:
73
- message, _, _ = message
74
- ret += role + message + self.sep
75
- else:
76
- ret += role
77
- elif self.sep_style == SeparatorStyle.LLAMA_2:
78
- wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
79
- wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
80
- ret = ""
81
-
82
- for i, (role, message) in enumerate(messages):
83
- if i == 0:
84
- assert message, "first message should not be none"
85
- assert role == self.roles[0], "first message should come from user"
86
- if message:
87
- if type(message) is tuple:
88
- message, _, _ = message
89
- if i == 0: message = wrap_sys(self.system) + message
90
- if i % 2 == 0:
91
- message = wrap_inst(message)
92
- ret += self.sep + message
93
- else:
94
- ret += " " + message + " " + self.sep2
95
- else:
96
- ret += ""
97
- ret = ret.lstrip(self.sep)
98
- elif self.sep_style == SeparatorStyle.PLAIN:
99
- seps = [self.sep, self.sep2]
100
- ret = self.system
101
- for i, (role, message) in enumerate(messages):
102
- if message:
103
- if type(message) is tuple:
104
- message, _, _ = message
105
- ret += message + seps[i % 2]
106
- else:
107
- ret += ""
108
- elif self.sep_style == SeparatorStyle.MISTRAL:
109
- # reference: https://docs.mistral.ai/models/
110
- wrap_sys = lambda msg: f"{msg}</s>"
111
- wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
112
- ret = ""
113
- for i, (role, message) in enumerate(messages):
114
- if i == 0:
115
- assert message, "first message should not be none"
116
- assert role == self.roles[0], "first message should come from user"
117
- if message:
118
- if type(message) is tuple:
119
- message, _, _ = message
120
- if i == 0: message = self.system + " " + message.strip()
121
- if i % 2 == 0:
122
- message = wrap_inst(message)
123
- ret += message
124
- else:
125
- ret += wrap_sys(message)
126
- else:
127
- ret += ""
128
- # wrap_sys = lambda msg: f"\n{msg}\n\n"
129
- # wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
130
- # ret = ""
131
- # for i, (role, message) in enumerate(messages):
132
- # if i == 0:
133
- # assert message, "first message should not be none"
134
- # assert role == self.roles[0], "first message should come from user"
135
- # if message:
136
- # if type(message) is tuple:
137
- # message, _, _ = message
138
- # if i == 0: message = wrap_sys(self.system) + message
139
- # if i % 2 == 0:
140
- # message = wrap_inst(message)
141
- # ret += message if i != 0 else self.sep + message
142
- # else:
143
- # # NOTE-JW: we need to add " " to strictly follow Mistral Instruction Format
144
- # ret += " " + message + " " + self.sep2
145
- # # ret += " " + wrap_sys(message)
146
- # else:
147
- # ret += ""
148
- else:
149
- raise ValueError(f"Invalid style: {self.sep_style}")
150
-
151
- return ret
152
-
153
- def append_message(self, role, message):
154
- self.messages.append([role, message])
155
-
156
- def get_images(self, return_pil=False):
157
- images = []
158
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
159
- if i % 2 == 0:
160
- if type(msg) is tuple:
161
- import base64
162
- from io import BytesIO
163
- from PIL import Image
164
- msg, image, image_process_mode = msg
165
- if image_process_mode == "Pad":
166
- def expand2square(pil_img, background_color=(122, 116, 104)):
167
- width, height = pil_img.size
168
- if width == height:
169
- return pil_img
170
- elif width > height:
171
- result = Image.new(pil_img.mode, (width, width), background_color)
172
- result.paste(pil_img, (0, (width - height) // 2))
173
- return result
174
- else:
175
- result = Image.new(pil_img.mode, (height, height), background_color)
176
- result.paste(pil_img, ((height - width) // 2, 0))
177
- return result
178
- image = expand2square(image)
179
- elif image_process_mode in ["Default", "Crop"]:
180
- pass
181
- elif image_process_mode == "Resize":
182
- image = image.resize((336, 336))
183
- else:
184
- raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
185
- max_hw, min_hw = max(image.size), min(image.size)
186
- aspect_ratio = max_hw / min_hw
187
- max_len, min_len = 800, 400
188
- shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
189
- longest_edge = int(shortest_edge * aspect_ratio)
190
- W, H = image.size
191
- if longest_edge != max(image.size):
192
- if H > W:
193
- H, W = longest_edge, shortest_edge
194
- else:
195
- H, W = shortest_edge, longest_edge
196
- image = image.resize((W, H))
197
- if return_pil:
198
- images.append(image)
199
- else:
200
- buffered = BytesIO()
201
- image.save(buffered, format="PNG")
202
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
203
- images.append(img_b64_str)
204
- return images
205
-
206
- def to_gradio_chatbot(self):
207
- ret = []
208
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
209
- if i % 2 == 0:
210
- if type(msg) is tuple:
211
- import base64
212
- from io import BytesIO
213
- msg, image, image_process_mode = msg
214
- max_hw, min_hw = max(image.size), min(image.size)
215
- aspect_ratio = max_hw / min_hw
216
- max_len, min_len = 800, 400
217
- shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
218
- longest_edge = int(shortest_edge * aspect_ratio)
219
- W, H = image.size
220
- if H > W:
221
- H, W = longest_edge, shortest_edge
222
- else:
223
- H, W = shortest_edge, longest_edge
224
- image = image.resize((W, H))
225
- buffered = BytesIO()
226
- image.save(buffered, format="JPEG")
227
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
228
- img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
229
- msg = img_str + msg.replace('<image>', '').strip()
230
- ret.append([msg, None])
231
- else:
232
- ret.append([msg, None])
233
- else:
234
- ret[-1][-1] = msg
235
- return ret
236
-
237
- def copy(self):
238
- return Conversation(
239
- system=self.system,
240
- roles=self.roles,
241
- messages=[[x, y] for x, y in self.messages],
242
- offset=self.offset,
243
- sep_style=self.sep_style,
244
- sep=self.sep,
245
- sep2=self.sep2,
246
- version=self.version)
247
-
248
- def dict(self):
249
- if len(self.get_images()) > 0:
250
- return {
251
- "system": self.system,
252
- "roles": self.roles,
253
- "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
254
- "offset": self.offset,
255
- "sep": self.sep,
256
- "sep2": self.sep2,
257
- }
258
- return {
259
- "system": self.system,
260
- "roles": self.roles,
261
- "messages": self.messages,
262
- "offset": self.offset,
263
- "sep": self.sep,
264
- "sep2": self.sep2,
265
- }
266
-
267
-
268
- conv_vicuna_v0 = Conversation(
269
- system="A chat between a curious human and an artificial intelligence assistant. "
270
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
271
- roles=("Human", "Assistant"),
272
- messages=(
273
- ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
274
- ("Assistant",
275
- "Renewable energy sources are those that can be replenished naturally in a relatively "
276
- "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
277
- "Non-renewable energy sources, on the other hand, are finite and will eventually be "
278
- "depleted, such as coal, oil, and natural gas. Here are some key differences between "
279
- "renewable and non-renewable energy sources:\n"
280
- "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
281
- "energy sources are finite and will eventually run out.\n"
282
- "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
283
- "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
284
- "and other negative effects.\n"
285
- "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
286
- "have lower operational costs than non-renewable sources.\n"
287
- "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
288
- "locations than non-renewable sources.\n"
289
- "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
290
- "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
291
- "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
292
- "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
293
- ),
294
- offset=2,
295
- sep_style=SeparatorStyle.SINGLE,
296
- sep="###",
297
- )
298
-
299
- conv_vicuna_v1 = Conversation(
300
- system="A chat between a curious user and an artificial intelligence assistant. "
301
- "The assistant gives helpful, detailed, and polite answers to the user's questions.",
302
- roles=("USER", "ASSISTANT"),
303
- version="v1",
304
- messages=(),
305
- offset=0,
306
- sep_style=SeparatorStyle.TWO,
307
- sep=" ",
308
- sep2="</s>",
309
- )
310
-
311
- conv_llama_2 = Conversation(
312
- system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
313
-
314
- If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
315
- roles=("USER", "ASSISTANT"),
316
- version="llama_v2",
317
- messages=(),
318
- offset=0,
319
- sep_style=SeparatorStyle.LLAMA_2,
320
- sep="<s>",
321
- sep2="</s>",
322
- )
323
-
324
- conv_llava_llama_2 = Conversation(
325
- system="You are a helpful language and vision assistant. "
326
- "You are able to understand the visual content that the user provides, "
327
- "and assist the user with a variety of tasks using natural language.",
328
- roles=("USER", "ASSISTANT"),
329
- version="llama_v2",
330
- messages=(),
331
- offset=0,
332
- sep_style=SeparatorStyle.LLAMA_2,
333
- sep="<s>",
334
- sep2="</s>",
335
- )
336
-
337
- conv_mpt = Conversation(
338
- system="""<|im_start|>system
339
- A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
340
- roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
341
- version="mpt",
342
- messages=(),
343
- offset=0,
344
- sep_style=SeparatorStyle.MPT,
345
- sep="<|im_end|>",
346
- )
347
-
348
- conv_llava_plain = Conversation(
349
- system="",
350
- roles=("", ""),
351
- messages=(
352
- ),
353
- offset=0,
354
- sep_style=SeparatorStyle.PLAIN,
355
- sep="\n",
356
- )
357
-
358
- conv_llava_v0 = Conversation(
359
- system="A chat between a curious human and an artificial intelligence assistant. "
360
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
361
- roles=("Human", "Assistant"),
362
- messages=(
363
- ),
364
- offset=0,
365
- sep_style=SeparatorStyle.SINGLE,
366
- sep="###",
367
- )
368
-
369
- conv_llava_v0_mmtag = Conversation(
370
- system="A chat between a curious user and an artificial intelligence assistant. "
371
- "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
372
- "The visual content will be provided with the following format: <Image>visual content</Image>.",
373
- roles=("Human", "Assistant"),
374
- messages=(
375
- ),
376
- offset=0,
377
- sep_style=SeparatorStyle.SINGLE,
378
- sep="###",
379
- version="v0_mmtag",
380
- )
381
-
382
- conv_llava_v1 = Conversation(
383
- system="A chat between a curious human and an artificial intelligence assistant. "
384
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
385
- roles=("USER", "ASSISTANT"),
386
- version="v1",
387
- messages=(),
388
- offset=0,
389
- sep_style=SeparatorStyle.TWO,
390
- sep=" ",
391
- sep2="</s>",
392
- )
393
-
394
- conv_llava_v1_mmtag = Conversation(
395
- system="A chat between a curious user and an artificial intelligence assistant. "
396
- "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
397
- "The visual content will be provided with the following format: <Image>visual content</Image>.",
398
- roles=("USER", "ASSISTANT"),
399
- messages=(),
400
- offset=0,
401
- sep_style=SeparatorStyle.TWO,
402
- sep=" ",
403
- sep2="</s>",
404
- version="v1_mmtag",
405
- )
406
-
407
- conv_mistral_instruct = Conversation(
408
- system="",
409
- roles=("USER", "ASSISTANT"),
410
- version="llama_v2",
411
- messages=(),
412
- offset=0,
413
- sep_style=SeparatorStyle.LLAMA_2,
414
- sep="",
415
- sep2="</s>",
416
- )
417
-
418
- default_conversation = conv_vicuna_v1
419
- conv_templates = {
420
- "default": conv_vicuna_v0,
421
- "v0": conv_vicuna_v0,
422
- "v1": conv_vicuna_v1,
423
- "vicuna_v1": conv_vicuna_v1,
424
- "llama_2": conv_llama_2,
425
- "mistral_instruct": conv_mistral_instruct,
426
-
427
- "plain": conv_llava_plain,
428
- "v0_plain": conv_llava_plain,
429
- "llava_v0": conv_llava_v0,
430
- "v0_mmtag": conv_llava_v0_mmtag,
431
- "llava_v1": conv_llava_v1,
432
- "v1_mmtag": conv_llava_v1_mmtag,
433
- "llava_llama_2": conv_llava_llama_2,
434
- "mpt": conv_mpt,
435
- }
436
-
437
-
438
- if __name__ == "__main__":
439
- print(default_conversation.get_prompt())
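For context, prompts are assembled from these templates by copying one from `conv_templates`, appending the user turn (with the `<image>` placeholder) and an empty assistant turn, then calling `get_prompt()`. A minimal usage sketch, assuming the deleted module is importable as `llava.conversation` and using an example question:

```python
from llava.conversation import conv_templates

# Copy the Mistral-instruct template used with LLaVA-Med v1.5 (--conv-mode mistral_instruct).
conv = conv_templates["mistral_instruct"].copy()

# One user turn containing the image placeholder, plus an empty assistant slot
# that the model fills during generation.
conv.append_message(conv.roles[0], "<image>\nWhat abnormality is visible in this image?")
conv.append_message(conv.roles[1], None)

prompt = conv.get_prompt()
print(prompt)  # e.g. "[INST] <image>\nWhat abnormality is visible in this image? [/INST]"
```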
LLaVA-Med/llava/eval/eval_multimodal_chat_gpt_score.py DELETED
@@ -1,112 +0,0 @@
1
- import os
2
- import json
3
- import argparse
4
- from copy import deepcopy
5
- import itertools
6
- from typing import Any
7
- from operator import add
8
- from pprint import pprint
9
- from typing import List
10
- from pathlib import Path
11
- from tqdm import tqdm
12
-
13
- import llm
14
- import util
15
-
16
-
17
- INSTRUCT_PROMPT = """We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with caption describing the same image.
18
- Please rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.
19
- Please first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space. In the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."""
20
- ROLE = 'Assistant'
21
-
22
- # Generate instruction for GPT-4 to score the two answers.
23
- def conv_to_str(fig_label, fig_caption, fig_context, question, ans1, ans2):
24
- return (f'[Context]\n'
25
- f'Figure Caption:\n{fig_label}: {fig_caption}\n\n'
26
- f'Figure Context:\n\t- {fig_context}\n\n'
27
- f'[Question]\n{question}\n\n'
28
- f'[{ROLE} 1]\n{ans1}\n\n[End of {ROLE} 1]\n\n'
29
- f'[{ROLE} 2]\n{ans2}\n\n[End of {ROLE} 2]\n\n'
30
- f'[System]\n{INSTRUCT_PROMPT}\n\n')
31
-
32
- def compare_messages_gen(fig_label, fig_caption, fig_context, question, ans1, ans2):
33
- messages = [
34
- {"role": "system", "content": """'You are a helpful and precise assistant for checking the quality of the answer."""},
35
- ]
36
- messages.append({"role": "user", "content": conv_to_str(fig_label, fig_caption, fig_context, question, ans1, ans2)})
37
- return messages
38
-
39
-
40
- def sum_list_list(x):
41
- return sum(item for inner_list in x for item in inner_list)
42
-
43
- def chunk(lst, n):
44
- for i in range(0, len(lst), n):
45
- if i+(1.5*n)<len(lst):
46
- end = i + n
47
- else:
48
- end = len(lst)
49
- yield lst[i:end]
50
- if end==len(lst):
51
- return
52
-
53
-
54
- def infer(samples):
55
- model_inst = llm.GPT("gpt-4-0314")
56
-
57
- BATCH_SIZE = 1
58
- batch_samples = []
59
- results = []
60
- batch = []
61
-
62
- print('Starting Multimodal Chat GPT Scoring Eval')
63
-
64
- for sample in tqdm(samples):
65
- sample_copy = deepcopy(sample)
66
- input_msg = compare_messages_gen(sample_copy['fig_label'], sample_copy['fig_caption'], sample_copy['in_text_mention'], sample_copy['question'], sample_copy['ans1'], sample_copy['ans2'])
67
- batch.append(input_msg)
68
- batch_samples.append(sample_copy)
69
- if len(batch)>=BATCH_SIZE:
70
- inference_results = [x.strip() for chunk_messages in chunk([x for x in batch if x], BATCH_SIZE) for x in model_inst.infer(chunk_messages)]
71
- for item, inference_result in zip(batch_samples, inference_results):
72
- item['gpt_eval'] = inference_result
73
- results.extend(batch_samples)
74
- batch = []
75
- batch_samples = []
76
- inference_results = [x.strip() for chunk_messages in chunk([x for x in batch if x], BATCH_SIZE) for x in model_inst.infer(chunk_messages)]
77
- for item, inference_result in zip(batch_samples, inference_results):
78
- item['gpt_eval'] = inference_result
79
- results.extend(batch_samples)
80
- print(f"Result Size: {len(results)}")
81
- return results
82
-
83
-
84
- def main(args):
85
- answer_data = util.load_file_jsonl(args.answers_file)
86
- question_data = util.load_file_jsonl(args.question_file)
87
-
88
- samples = []
89
- for question, answer in zip(question_data, answer_data):
90
- question_copy = deepcopy(question)
91
- question['question'] = question_copy['text']
92
- question['ans1'] = question_copy.pop('gpt4_answer')
93
- question['ans2'] = answer['text']
94
- samples.append(question)
95
-
96
- results = infer(samples)
97
-
98
- # Create parent directory of output score files if it doesn't exist
99
- os.makedirs(Path(args.scores_file).parent, exist_ok=True)
100
-
101
- with open(args.scores_file, 'w') as f:
102
- for row in results:
103
- f.write(json.dumps(row)+'\n')
104
-
105
-
106
- if __name__ == '__main__':
107
- parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Scoring", add_help=True)
108
- parser.add_argument("--answers-file", default="", metavar="FILE", help="path to model answer file")
109
- parser.add_argument("--question-file", default="data/questions/llava_med_eval_qa50_qa.jsonl", metavar="FILE", help="path to multichat questions file")
110
- parser.add_argument("--scores-file", default="", metavar="FILE", help="path to save gpt-4 score file")
111
- args = parser.parse_args()
112
- main(args)
 
LLaVA-Med/llava/eval/llm.py DELETED
@@ -1,134 +0,0 @@
1
- import os
2
- import abc
3
- import asyncio
4
- from abc import abstractmethod
5
- import math
6
-
7
- import tiktoken
8
- import openai
9
- import backoff
10
-
11
-
12
- class LLM(abc.ABC):
13
-
14
- prompt_percent = 0.9
15
-
16
- @abstractmethod
17
- def __init__(self):
18
- raise NotImplementedError("Subclasses should implement this!")
19
-
20
- @abstractmethod
21
- def infer(self, prompts):
22
- raise NotImplementedError("Subclasses should implement this!")
23
-
24
- @abstractmethod
25
- def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
26
- raise NotImplementedError("Subclasses should implement this!")
27
-
28
-
29
- class GPT(LLM):
30
-
31
- prompt_percent = 0.8
32
-
33
- openai_cxn_dict = {
34
- 'default': {
35
- 'endpoint': "INSERT YOUR AZURE OPENAI ENDPOINT HERE",
36
- 'api_key': "INSERT YOUR AZURE OPENAI API KEY HERE",
37
- },
38
- }
39
-
40
- deployment_max_length_dict = {
41
- 'gpt-4': 8192,
42
- 'gpt-4-0314': 8192,
43
- 'gpt-4-32k': 32768,
44
- 'gpt-35-turbo': 4096,
45
- 'gpt-35-turbo-16k': 16385,
46
- }
47
-
48
- def __init__(self, model_id):
49
- self.temperature = 0.0
50
- self.top_k = 1
51
- self.encoding = tiktoken.encoding_for_model("-".join(model_id.split("-", 2)[:2]).replace('5', '.5'))
52
- self.openai_api = 'default'
53
- self.model_id = model_id
54
- self.max_length = self.deployment_max_length_dict[model_id]
55
- self.client = openai.AsyncAzureOpenAI(
56
- api_key=self.openai_cxn_dict[self.openai_api]['api_key'],
57
- api_version="2023-12-01-preview",
58
- azure_endpoint=self.openai_cxn_dict[self.openai_api]['endpoint']
59
- )
60
-
61
- def gen_messages(self, fixed_instruction, few_shot_examples, input, input_header, output_header):
62
- messages = [
63
- {
64
- "role": "system",
65
- "content": fixed_instruction,
66
- },
67
- ]
68
- for example in few_shot_examples:
69
- messages.extend(
70
- [
71
- {
72
- "role": "user",
73
- "content": input_header+'\n'+example['user']+'\n\n'+output_header,
74
- },
75
- {
76
- "role": "assistant",
77
- "content": example['assistant'],
78
- },
79
- ]
80
- )
81
- messages.extend(
82
- [
83
- {
84
- "role": "user",
85
- "content": input_header+'\n'+input+'\n\n'+output_header,
86
- },
87
- ]
88
- )
89
- return messages
90
-
91
- # Define the coroutine for making API calls to GPT
92
- @backoff.on_exception(backoff.expo, openai.RateLimitError)
93
- async def make_api_call_to_gpt(
94
- self,
95
- messages
96
- ):
97
- response = await self.client.chat.completions.create(
98
- model=self.model_id,
99
- messages=messages,
100
- temperature=self.temperature,
101
- )
102
- return response.choices[0].message.content
103
-
104
- async def dispatch_openai_requests(
105
- self,
106
- messages_list,
107
- ):
108
- # Asynchronously call the function for each prompt
109
- tasks = [self.make_api_call_to_gpt(messages) for messages in messages_list]
110
-
111
- # Gather and run the tasks concurrently
112
- results = await asyncio.gather(*tasks)
113
- return results
114
-
115
- def infer(self,
116
- messages_list,
117
- ):
118
- return asyncio.run(self.dispatch_openai_requests(messages_list))
119
-
120
- def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
121
- # Tokenize fixed_prompt
122
- fixed_token_ids = self.encoding.encode(fixed_instruction+' '.join([x['user']+' '+x['assistant'] for x in few_shot_examples]))
123
- # Calculate remaining token length
124
- remaining_token_len = math.ceil((self.prompt_percent*self.max_length)-len(fixed_token_ids))
125
-
126
- # Tokenize splittable_input
127
- split_token_ids = self.encoding.encode(splittable_input)
128
-
129
- # Split tokenized split_prompt into list of individual inputs strings. Uses tokens to calculate length
130
- split_token_ids_list = [split_token_ids[i:i+remaining_token_len+10] for i in range(0, len(split_token_ids), remaining_token_len)]
131
- split_input_list = [self.encoding.decode(split_token_ids) for split_token_ids in split_token_ids_list]
132
-
133
- # Take the fixed_prompt, few_shot_examples, splitted inputs, and input/output headers and generate list of prompt strings.
134
- return [self.gen_messages(fixed_instruction, few_shot_examples, split_input, input_header, output_header) for split_input in split_input_list]
 
LLaVA-Med/llava/eval/model_vqa.py DELETED
@@ -1,109 +0,0 @@
1
- import argparse
2
- import torch
3
- import os
4
- import json
5
- from tqdm import tqdm
6
- import shortuuid
7
-
8
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
- from llava.conversation import conv_templates, SeparatorStyle
10
- from llava.model.builder import load_pretrained_model
11
- from llava.utils import disable_torch_init
12
- from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images
13
-
14
- from PIL import Image
15
- import math
16
- from transformers import set_seed, logging
17
-
18
- logging.set_verbosity_error()
19
-
20
-
21
- def split_list(lst, n):
22
- """Split a list into n (roughly) equal-sized chunks"""
23
- chunk_size = math.ceil(len(lst) / n)  # ceiling division, so every item lands in one of the n chunks
24
- return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
25
-
26
-
27
- def get_chunk(lst, n, k):
28
- chunks = split_list(lst, n)
29
- return chunks[k]
30
-
31
-
32
- def eval_model(args):
33
- set_seed(0)
34
- # Model
35
- disable_torch_init()
36
- model_path = os.path.expanduser(args.model_path)
37
- model_name = get_model_name_from_path(model_path)
38
- tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
39
-
40
- questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
41
- questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
42
- answers_file = os.path.expanduser(args.answers_file)
43
- os.makedirs(os.path.dirname(answers_file), exist_ok=True)
44
- ans_file = open(answers_file, "w")
45
- for line in tqdm(questions):
46
- idx = line["question_id"]
47
- image_file = line["image"]
48
- qs = line["text"].replace(DEFAULT_IMAGE_TOKEN, '').strip()
49
- cur_prompt = qs
50
- if model.config.mm_use_im_start_end:
51
- qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
52
- else:
53
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
54
-
55
- conv = conv_templates[args.conv_mode].copy()
56
- conv.append_message(conv.roles[0], qs)
57
- conv.append_message(conv.roles[1], None)
58
- prompt = conv.get_prompt()
59
-
60
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
61
-
62
- image = Image.open(os.path.join(args.image_folder, image_file))
63
- image_tensor = process_images([image], image_processor, model.config)[0]
64
-
65
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
66
- keywords = [stop_str]
67
- stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
68
-
69
- with torch.inference_mode():
70
- output_ids = model.generate(
71
- input_ids,
72
- images=image_tensor.unsqueeze(0).half().cuda(),
73
- do_sample=True if args.temperature > 0 else False,
74
- temperature=args.temperature,
75
- top_p=args.top_p,
76
- num_beams=args.num_beams,
77
- # no_repeat_ngram_size=3,
78
- max_new_tokens=1024,
79
- use_cache=True)
80
-
81
- outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
82
-
83
- ans_id = shortuuid.uuid()
84
- ans_file.write(json.dumps({"question_id": idx,
85
- "prompt": cur_prompt,
86
- "text": outputs,
87
- "answer_id": ans_id,
88
- "model_id": model_name,
89
- "metadata": {}}) + "\n")
90
- ans_file.flush()
91
- ans_file.close()
92
-
93
-
94
- if __name__ == "__main__":
95
- parser = argparse.ArgumentParser()
96
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
97
- parser.add_argument("--model-base", type=str, default=None)
98
- parser.add_argument("--image-folder", type=str, default="")
99
- parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
100
- parser.add_argument("--answers-file", type=str, default="answer.jsonl")
101
- parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
102
- parser.add_argument("--num-chunks", type=int, default=1)
103
- parser.add_argument("--chunk-idx", type=int, default=0)
104
- parser.add_argument("--temperature", type=float, default=0.2)
105
- parser.add_argument("--top_p", type=float, default=None)
106
- parser.add_argument("--num_beams", type=int, default=1)
107
- args = parser.parse_args()
108
-
109
- eval_model(args)
 
LLaVA-Med/llava/eval/run_llava.py DELETED
@@ -1,145 +0,0 @@
1
- import argparse
2
- import torch
3
-
4
- from llava.constants import (
5
- IMAGE_TOKEN_INDEX,
6
- DEFAULT_IMAGE_TOKEN,
7
- DEFAULT_IM_START_TOKEN,
8
- DEFAULT_IM_END_TOKEN,
9
- IMAGE_PLACEHOLDER,
10
- )
11
- from llava.conversation import conv_templates, SeparatorStyle
12
- from llava.model.builder import load_pretrained_model
13
- from llava.utils import disable_torch_init
14
- from llava.mm_utils import (
15
- process_images,
16
- tokenizer_image_token,
17
- get_model_name_from_path,
18
- )
19
-
20
- from PIL import Image
21
-
22
- import requests
23
- from PIL import Image
24
- from io import BytesIO
25
- import re
26
-
27
-
28
- def image_parser(args):
29
- out = args.image_file.split(args.sep)
30
- return out
31
-
32
-
33
- def load_image(image_file):
34
- if image_file.startswith("http") or image_file.startswith("https"):
35
- response = requests.get(image_file)
36
- image = Image.open(BytesIO(response.content)).convert("RGB")
37
- else:
38
- image = Image.open(image_file).convert("RGB")
39
- return image
40
-
41
-
42
- def load_images(image_files):
43
- out = []
44
- for image_file in image_files:
45
- image = load_image(image_file)
46
- out.append(image)
47
- return out
48
-
49
-
50
- def eval_model(args):
51
- # Model
52
- disable_torch_init()
53
-
54
- model_name = get_model_name_from_path(args.model_path)
55
- tokenizer, model, image_processor, context_len = load_pretrained_model(
56
- args.model_path, args.model_base, model_name
57
- )
58
-
59
- qs = args.query
60
- image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
61
- if IMAGE_PLACEHOLDER in qs:
62
- if model.config.mm_use_im_start_end:
63
- qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
64
- else:
65
- qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
66
- else:
67
- if model.config.mm_use_im_start_end:
68
- qs = image_token_se + "\n" + qs
69
- else:
70
- qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
71
-
72
- if "llama-2" in model_name.lower():
73
- conv_mode = "llava_llama_2"
74
- elif "mistral" in model_name.lower():
75
- conv_mode = "mistral_instruct"
76
- elif "v1.6-34b" in model_name.lower():
77
- conv_mode = "chatml_direct"
78
- elif "v1" in model_name.lower():
79
- conv_mode = "llava_v1"
80
- elif "mpt" in model_name.lower():
81
- conv_mode = "mpt"
82
- else:
83
- conv_mode = "llava_v0"
84
-
85
- if args.conv_mode is not None and conv_mode != args.conv_mode:
86
- print(
87
- "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
88
- conv_mode, args.conv_mode, args.conv_mode
89
- )
90
- )
91
- else:
92
- args.conv_mode = conv_mode
93
-
94
- conv = conv_templates[args.conv_mode].copy()
95
- conv.append_message(conv.roles[0], qs)
96
- conv.append_message(conv.roles[1], None)
97
- prompt = conv.get_prompt()
98
-
99
- image_files = image_parser(args)
100
- images = load_images(image_files)
101
- image_sizes = [x.size for x in images]
102
- images_tensor = process_images(
103
- images,
104
- image_processor,
105
- model.config
106
- ).to(model.device, dtype=torch.float16)
107
-
108
- input_ids = (
109
- tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
110
- .unsqueeze(0)
111
- .cuda()
112
- )
113
-
114
- with torch.inference_mode():
115
- output_ids = model.generate(
116
- input_ids,
117
- images=images_tensor,
118
- image_sizes=image_sizes,
119
- do_sample=True if args.temperature > 0 else False,
120
- temperature=args.temperature,
121
- top_p=args.top_p,
122
- num_beams=args.num_beams,
123
- max_new_tokens=args.max_new_tokens,
124
- use_cache=True,
125
- )
126
-
127
- outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
128
- print(outputs)
129
-
130
-
131
- if __name__ == "__main__":
132
- parser = argparse.ArgumentParser()
133
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
134
- parser.add_argument("--model-base", type=str, default=None)
135
- parser.add_argument("--image-file", type=str, required=True)
136
- parser.add_argument("--query", type=str, required=True)
137
- parser.add_argument("--conv-mode", type=str, default=None)
138
- parser.add_argument("--sep", type=str, default=",")
139
- parser.add_argument("--temperature", type=float, default=0.2)
140
- parser.add_argument("--top_p", type=float, default=None)
141
- parser.add_argument("--num_beams", type=int, default=1)
142
- parser.add_argument("--max_new_tokens", type=int, default=512)
143
- args = parser.parse_args()
144
-
145
- eval_model(args)
 
LLaVA-Med/llava/eval/summarize_gpt_review.py DELETED
@@ -1,47 +0,0 @@
- import argparse
- from copy import deepcopy
- import util
- from pprint import pprint
- from collections import defaultdict
- import pandas as pd
- import json
-
-
- def get_domain(x):
-     for domain in ['chest_xray', 'mri', 'histology', 'gross', 'ct_scan']:
-         in_domain = x['domain'][domain]
-         if in_domain:
-             return domain
-
-
-
- def main(args):
-     scores_data = util.load_file_jsonl(args.scores_file)
-     predictions = [(x['question_id'], x['type'], get_domain(x), x['gpt_eval'].split('\n')[0].split(' ')) for x in scores_data]
-
-     score_type_dict = defaultdict(lambda: defaultdict(list))
-     for q_id, q_type, domain, (a1_score, a2_score) in predictions:
-         score_type_dict[q_type][1].append(a1_score)
-         score_type_dict[q_type][2].append(a2_score)
-         score_type_dict['overall'][1].append(a1_score)
-         score_type_dict['overall'][2].append(a2_score)
-         score_type_dict[domain][1].append(a1_score)
-         score_type_dict[domain][2].append(a2_score)
-
-     result = defaultdict(dict)
-
-     for q_type, score_dict in score_type_dict.items():
-         result[q_type]['gpt4_score'] = util.get_avg(score_dict[1])
-         result[q_type]['pred_score'] = util.get_avg(score_dict[2])
-         result[q_type]['pred_relative_score'] = util.get_avg([float(s2)/float(s1) for s1, s2 in zip(score_dict[1], score_dict[2])])*100
-         result[q_type]['data_size'] = len(score_dict[1])
-
-     df = pd.DataFrame.from_dict(result).filter(['conversation', 'detailed_description', 'chest_xray', 'mri', 'histology', 'gross', 'ct_scan', 'overall'])
-     print(df)
-
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Eval Postprocessing", add_help=True)
-     parser.add_argument("--scores-file", default="", metavar="FILE", help="input path to gpt-4 score file")
-     args = parser.parse_args()
-     main(args)
 
LLaVA-Med/llava/eval/util.py DELETED
@@ -1,9 +0,0 @@
- import json
-
-
- def load_file_jsonl(path):
-     with open(path) as f:
-         return [json.loads(row) for row in f]
-
- def get_avg(x):
-     return sum([float(y) for y in x])/len(x)
 
LLaVA-Med/llava/mm_utils.py DELETED
@@ -1,110 +0,0 @@
1
- from PIL import Image
2
- from io import BytesIO
3
- import base64
4
- import random
5
- import torch
6
- from transformers import StoppingCriteria
7
- from llava.constants import IMAGE_TOKEN_INDEX
8
-
9
-
10
- def load_image_from_base64(image):
11
- return Image.open(BytesIO(base64.b64decode(image)))
12
-
13
-
14
- def expand2square(pil_img, background_color):
15
- width, height = pil_img.size
16
- if width == height:
17
- return pil_img
18
- elif width > height:
19
- result = Image.new(pil_img.mode, (width, width), background_color)
20
- # center the image vertically: the offset is fixed at (width - height) // 2, with at most one pixel of jitter
21
- y_start = random.randint((width - height) // 2, (width - height) // 2 + 1)
22
- result.paste(pil_img, (0, y_start))
23
- return result
24
- else:
25
- result = Image.new(pil_img.mode, (height, height), background_color)
26
- # center the image horizontally: the offset is fixed at (height - width) // 2, with at most one pixel of jitter
27
- x_start = random.randint((height - width) // 2, (height - width) // 2 + 1)
28
- result.paste(pil_img, (x_start, 0))
29
- return result
30
-
31
-
32
- def process_images(images, image_processor, model_cfg):
33
- image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
34
- new_images = []
35
- for image in images:
36
- if image_aspect_ratio == 'pad':
37
- if image.mode=='L':
38
- background_color = int(255*sum(image_processor.image_mean)/len(image_processor.image_mean))
39
- else:
40
- background_color = tuple(int(x*255) for x in image_processor.image_mean)
41
- image = expand2square(image, background_color)
42
- image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
43
- new_images.append(image)
44
- if all(x.shape == new_images[0].shape for x in new_images):
45
- new_images = torch.stack(new_images, dim=0)
46
- return new_images
47
-
48
-
49
- def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
50
- prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
51
-
52
- def insert_separator(X, sep):
53
- return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
54
-
55
- input_ids = []
56
- offset = 0
57
- if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
58
- offset = 1
59
- input_ids.append(prompt_chunks[0][0])
60
-
61
- for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
62
- input_ids.extend(x[offset:])
63
-
64
- if return_tensors is not None:
65
- if return_tensors == 'pt':
66
- return torch.tensor(input_ids, dtype=torch.long)
67
- raise ValueError(f'Unsupported tensor type: {return_tensors}')
68
- return input_ids
69
-
70
-
71
- def get_model_name_from_path(model_path):
72
- model_path = model_path.strip("/")
73
- model_paths = model_path.split("/")
74
- if model_paths[-1].startswith('checkpoint-'):
75
- return model_paths[-2] + "_" + model_paths[-1]
76
- else:
77
- return model_paths[-1]
78
-
79
- class KeywordsStoppingCriteria(StoppingCriteria):
80
- def __init__(self, keywords, tokenizer, input_ids):
81
- self.keywords = keywords
82
- self.keyword_ids = []
83
- self.max_keyword_len = 0
84
- for keyword in keywords:
85
- cur_keyword_ids = tokenizer(keyword).input_ids
86
- if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
87
- cur_keyword_ids = cur_keyword_ids[1:]
88
- if len(cur_keyword_ids) > self.max_keyword_len:
89
- self.max_keyword_len = len(cur_keyword_ids)
90
- self.keyword_ids.append(torch.tensor(cur_keyword_ids))
91
- self.tokenizer = tokenizer
92
- self.start_len = input_ids.shape[1]
93
-
94
- def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
95
- offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
96
- self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
97
- for keyword_id in self.keyword_ids:
98
- if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
99
- return True
100
- outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
101
- for keyword in self.keywords:
102
- if keyword in outputs:
103
- return True
104
- return False
105
-
106
- def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
107
- outputs = []
108
- for i in range(output_ids.shape[0]):
109
- outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
110
- return all(outputs)
 
LLaVA-Med/llava/model/__init__.py DELETED
@@ -1 +0,0 @@
- from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
 
 
LLaVA-Med/llava/model/builder.py DELETED
@@ -1,83 +0,0 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
2
- import torch
3
- from llava.model import LlavaMistralForCausalLM
4
- from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5
-
6
-
7
- def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
8
-
9
- kwargs = {}
10
-
11
- if device != "cuda":
12
- kwargs['device_map'] = {"": device}
13
-
14
- if load_8bit:
15
- kwargs['load_in_8bit'] = True
16
- elif load_4bit:
17
- kwargs['load_in_4bit'] = True
18
- kwargs['quantization_config'] = BitsAndBytesConfig(
19
- load_in_4bit=True,
20
- bnb_4bit_compute_dtype=torch.float16,
21
- bnb_4bit_use_double_quant=True,
22
- bnb_4bit_quant_type='nf4'
23
- )
24
- else:
25
- kwargs['torch_dtype'] = torch.float16
26
-
27
- if 'llava' in model_name.lower():
28
- # Load LLaVA model
29
- if 'mistral' in model_name.lower():
30
- tokenizer = AutoTokenizer.from_pretrained(model_path)
31
- model = LlavaMistralForCausalLM.from_pretrained(
32
- model_path,
33
- low_cpu_mem_usage=False,
34
- use_flash_attention_2=False,
35
- **kwargs
36
- )
37
- else:
38
- # Load language model
39
- if model_base is not None:
40
- # PEFT model
41
- from peft import PeftModel
42
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
43
- model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
44
- print(f"Loading LoRA weights from {model_path}")
45
- model = PeftModel.from_pretrained(model, model_path)
46
- print(f"Merging weights")
47
- model = model.merge_and_unload()
48
- print('Convert to FP16...')
49
- model.to(torch.float16)
50
- else:
51
- use_fast = False
52
- if 'mpt' in model_name.lower():
53
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
54
- model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
55
- else:
56
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
57
- model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
58
-
59
- image_processor = None
60
-
61
- if 'llava' in model_name.lower(): # or 'mistral' in model_name.lower():
62
- mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
63
- mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
64
- if mm_use_im_patch_token:
65
- tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
66
- if mm_use_im_start_end:
67
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
68
- model.resize_token_embeddings(len(tokenizer))
69
-
70
- vision_tower = model.get_vision_tower()
71
- if not vision_tower.is_loaded:
72
- vision_tower.load_model()
73
- vision_tower.to(device=device, dtype=torch.float16)
74
- model.model.mm_projector.to(device=device, dtype=torch.float16)
75
- model.to(device=device, dtype=torch.float16)
76
- image_processor = vision_tower.image_processor
77
-
78
- if hasattr(model.config, "max_sequence_length"):
79
- context_len = model.config.max_sequence_length
80
- else:
81
- context_len = 2048
82
-
83
- return tokenizer, model, image_processor, context_len
 
LLaVA-Med/llava/model/builders.py DELETED
@@ -1,152 +0,0 @@
1
- import os
2
- import warnings
3
- import shutil
4
-
5
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
6
- import torch
7
- from llava.model import LlavaMistralForCausalLM
8
- from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
-
10
-
11
- def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
12
- kwargs = {"device_map": device_map, **kwargs}
13
-
14
- if device != "cuda":
15
- kwargs['device_map'] = {"": device}
16
-
17
- if load_8bit:
18
- kwargs['load_in_8bit'] = True
19
- elif load_4bit:
20
- kwargs['load_in_4bit'] = True
21
- kwargs['quantization_config'] = BitsAndBytesConfig(
22
- load_in_4bit=True,
23
- bnb_4bit_compute_dtype=torch.float16,
24
- bnb_4bit_use_double_quant=True,
25
- bnb_4bit_quant_type='nf4'
26
- )
27
- else:
28
- kwargs['torch_dtype'] = torch.float16
29
-
30
- if use_flash_attn:
31
- kwargs['attn_implementation'] = 'flash_attention_2'
32
-
33
- if 'llava' in model_name.lower():
34
- # Load LLaVA model
35
- if 'lora' in model_name.lower() and model_base is None:
36
- warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
37
- if 'lora' in model_name.lower() and model_base is not None:
38
- from llava.model.language_model.llava_mistral import LlavaMistralConfig
39
- lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
40
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
41
- print('Loading LLaVA from base model...')
42
- model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
43
- token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
44
- if model.lm_head.weight.shape[0] != token_num:
45
- model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
46
- model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
47
-
48
- # print('Loading additional LLaVA weights...')
49
- # if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
50
- # non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
51
- # else:
52
- # # this is probably from HF Hub
53
- # from huggingface_hub import hf_hub_download
54
- # def load_from_hf(repo_id, filename, subfolder=None):
55
- # cache_file = hf_hub_download(
56
- # repo_id=repo_id,
57
- # filename=filename,
58
- # subfolder=subfolder)
59
- # return torch.load(cache_file, map_location='cpu')
60
- # non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
61
- # non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
62
- # if any(k.startswith('model.model.') for k in non_lora_trainables):
63
- # non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
64
- # model.load_state_dict(non_lora_trainables, strict=False)
65
-
66
- from peft import PeftModel
67
- print('Loading LoRA weights...')
68
- model = PeftModel.from_pretrained(model, model_path)
69
- print('Merging LoRA weights...')
70
- model = model.merge_and_unload()
71
- print('Model is loaded...')
72
- elif model_base is not None:
73
- # this may be mm projector only
74
- print('Loading LLaVA from base model...')
75
- if 'mpt' in model_name.lower():
76
- if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
77
- shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
78
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
79
- cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
80
- model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
81
- else:
82
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
83
- cfg_pretrained = AutoConfig.from_pretrained(model_path)
84
- model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
85
-
86
- mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
87
- mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
88
- model.load_state_dict(mm_projector_weights, strict=False)
89
- else:
90
- if 'mpt' in model_name.lower():
91
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
92
- model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
93
- elif 'mistral' in model_name.lower():
94
- tokenizer = AutoTokenizer.from_pretrained(model_path)
95
- model = LlavaMistralForCausalLM.from_pretrained(
96
- model_path,
97
- low_cpu_mem_usage=True,
98
- **kwargs
99
- )
100
- else:
101
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
102
- model = LlavaLlamaForCausalLM.from_pretrained(
103
- model_path,
104
- low_cpu_mem_usage=True,
105
- **kwargs
106
- )
107
- else:
108
- # Load language model
109
- if model_base is not None:
110
- # PEFT model
111
- from peft import PeftModel
112
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
113
- model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
114
- print(f"Loading LoRA weights from {model_path}")
115
- model = PeftModel.from_pretrained(model, model_path)
116
- print(f"Merging weights")
117
- model = model.merge_and_unload()
118
- print('Convert to FP16...')
119
- model.to(torch.float16)
120
- else:
121
- use_fast = False
122
- if 'mpt' in model_name.lower():
123
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
124
- model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
125
- else:
126
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
127
- model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
128
-
129
- image_processor = None
130
-
131
- if 'mistral' in model_name.lower():
132
- mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
133
- mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
134
- if mm_use_im_patch_token:
135
- tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
136
- if mm_use_im_start_end:
137
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
138
- model.resize_token_embeddings(len(tokenizer))
139
-
140
- vision_tower = model.get_vision_tower()
141
- if not vision_tower.is_loaded:
142
- vision_tower.load_model(device_map=device_map)
143
- if device_map != 'auto':
144
- vision_tower.to(device=device_map, dtype=torch.float16)
145
- image_processor = vision_tower.image_processor
146
-
147
- if hasattr(model.config, "max_sequence_length"):
148
- context_len = model.config.max_sequence_length
149
- else:
150
- context_len = 2048
151
-
152
- return tokenizer, model, image_processor, context_len
 
LLaVA-Med/llava/model/language_model/llava_mistral.py DELETED
@@ -1,143 +0,0 @@
1
- from typing import List, Optional, Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
-
6
- from transformers import AutoConfig, AutoModelForCausalLM, \
7
- MistralConfig, MistralModel, MistralForCausalLM
8
-
9
- from transformers.modeling_outputs import CausalLMOutputWithPast
10
- from transformers.generation.utils import GenerateOutput
11
-
12
- from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
13
-
14
-
15
- class LlavaMistralConfig(MistralConfig):
16
- model_type = "llava_mistral"
17
-
18
-
19
- class LlavaMistralModel(LlavaMetaModel, MistralModel):
20
- config_class = LlavaMistralConfig
21
-
22
- def __init__(self, config: MistralConfig):
23
- super(LlavaMistralModel, self).__init__(config)
24
-
25
-
26
- class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
27
- config_class = LlavaMistralConfig
28
-
29
- def __init__(self, config):
30
- super(MistralForCausalLM, self).__init__(config)
31
- self.model = LlavaMistralModel(config)
32
-
33
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
34
-
35
- # Initialize weights and apply final processing
36
- self.post_init()
37
-
38
- def get_model(self):
39
- return self.model
40
-
41
- def forward(
42
- self,
43
- input_ids: torch.LongTensor = None,
44
- attention_mask: Optional[torch.Tensor] = None,
45
- position_ids: Optional[torch.LongTensor] = None,
46
- past_key_values: Optional[List[torch.FloatTensor]] = None,
47
- inputs_embeds: Optional[torch.FloatTensor] = None,
48
- labels: Optional[torch.LongTensor] = None,
49
- use_cache: Optional[bool] = None,
50
- output_attentions: Optional[bool] = None,
51
- output_hidden_states: Optional[bool] = None,
52
- images: Optional[torch.FloatTensor] = None,
53
- image_sizes: Optional[List[List[int]]] = None,
54
- return_dict: Optional[bool] = None,
55
- ) -> Union[Tuple, CausalLMOutputWithPast]:
56
-
57
- if inputs_embeds is None:
58
- (
59
- input_ids,
60
- position_ids,
61
- attention_mask,
62
- past_key_values,
63
- inputs_embeds,
64
- labels
65
- ) = self.prepare_inputs_labels_for_multimodal(
66
- input_ids,
67
- position_ids,
68
- attention_mask,
69
- past_key_values,
70
- labels,
71
- images,
72
- image_sizes
73
- )
74
-
75
- return super().forward(
76
- input_ids=input_ids,
77
- attention_mask=attention_mask,
78
- position_ids=position_ids,
79
- past_key_values=past_key_values,
80
- inputs_embeds=inputs_embeds,
81
- labels=labels,
82
- use_cache=use_cache,
83
- output_attentions=output_attentions,
84
- output_hidden_states=output_hidden_states,
85
- return_dict=return_dict
86
- )
87
-
88
- @torch.no_grad()
89
- def generate(
90
- self,
91
- inputs: Optional[torch.Tensor] = None,
92
- images: Optional[torch.Tensor] = None,
93
- image_sizes: Optional[torch.Tensor] = None,
94
- **kwargs,
95
- ) -> Union[GenerateOutput, torch.LongTensor]:
96
- position_ids = kwargs.pop("position_ids", None)
97
- attention_mask = kwargs.pop("attention_mask", None)
98
- if "inputs_embeds" in kwargs:
99
- raise NotImplementedError("`inputs_embeds` is not supported")
100
-
101
- if images is not None:
102
- (
103
- inputs,
104
- position_ids,
105
- attention_mask,
106
- _,
107
- inputs_embeds,
108
- _
109
- ) = self.prepare_inputs_labels_for_multimodal(
110
- inputs,
111
- position_ids,
112
- attention_mask,
113
- None,
114
- None,
115
- images,
116
- image_sizes=image_sizes
117
- )
118
- else:
119
- inputs_embeds = self.get_model().embed_tokens(inputs)
120
-
121
- return super().generate(
122
- position_ids=position_ids,
123
- attention_mask=attention_mask,
124
- inputs_embeds=inputs_embeds,
125
- **kwargs
126
- )
127
-
128
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
129
- inputs_embeds=None, **kwargs):
130
- images = kwargs.pop("images", None)
131
- image_sizes = kwargs.pop("image_sizes", None)
132
- inputs = super().prepare_inputs_for_generation(
133
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
134
- )
135
- if images is not None:
136
- inputs['images'] = images
137
- if image_sizes is not None:
138
- inputs['image_sizes'] = image_sizes
139
- return inputs
140
-
141
-
142
- AutoConfig.register("llava_mistral", LlavaMistralConfig)
143
- AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
 
LLaVA-Med/llava/model/llava_arch.py DELETED
@@ -1,309 +0,0 @@
1
- # Copyright 2023 Haotian Liu
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from abc import ABC, abstractmethod
17
- import os
18
- from glob import glob
19
-
20
- import torch
21
- import torch.nn as nn
22
-
23
- from .multimodal_encoder.builder import build_vision_tower
24
- from .multimodal_projector.builder import build_vision_projector
25
-
26
- from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
27
-
28
-
29
- class LlavaMetaModel:
30
-
31
- def __init__(self, config):
32
- super(LlavaMetaModel, self).__init__(config)
33
-
34
- if hasattr(config, "mm_vision_tower"):
35
- self.vision_tower = build_vision_tower(config, delay_load=True)
36
- self.mm_projector = build_vision_projector(config)
37
-
38
- def get_vision_tower(self):
39
- vision_tower = getattr(self, 'vision_tower', None)
40
- if type(vision_tower) is list:
41
- vision_tower = vision_tower[0]
42
- return vision_tower
43
-
44
- def initialize_vision_modules(self, model_args, fsdp=None, embed_tokens=None):
45
- vision_tower = model_args.vision_tower
46
- mm_vision_select_layer = model_args.mm_vision_select_layer
47
- mm_vision_select_feature = model_args.mm_vision_select_feature
48
- pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
49
-
50
- self.config.mm_vision_tower = vision_tower
51
-
52
- if self.get_vision_tower() is None:
53
- vision_tower = build_vision_tower(model_args)
54
-
55
- if fsdp is not None and len(fsdp) > 0:
56
- self.vision_tower = [vision_tower]
57
- else:
58
- self.vision_tower = vision_tower
59
- else:
60
- if fsdp is not None and len(fsdp) > 0:
61
- vision_tower = self.vision_tower[0]
62
- else:
63
- vision_tower = self.vision_tower
64
- vision_tower.load_model()
65
-
66
- self.config.use_mm_proj = True
67
- self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
68
- self.config.mm_hidden_size = vision_tower.hidden_size
69
- self.config.mm_vision_select_layer = mm_vision_select_layer
70
- self.config.mm_vision_select_feature = mm_vision_select_feature
71
-
72
- # add additional configs for segtok
73
- self.config.feature_outs = model_args.feature_outs
74
- self.config.img_size = model_args.img_size
75
- self.config.vision_backbone = model_args.vision_backbone
76
- self.config.segtok_posembed = model_args.segtok_posembed
77
-
78
- if getattr(self, 'mm_projector', None) is None:
79
- self.mm_projector = build_vision_projector(self.config)
80
- else:
81
- # In case it is frozen by LoRA
82
- for p in self.mm_projector.parameters():
83
- p.requires_grad = True
84
-
85
- # Initialize last layer in mm_projector with weight=0 and bias=mean(embed_tokens)
86
- if embed_tokens is not None:
87
- embed_tokens_weight = embed_tokens.weight.data
88
- self.mm_projector[-1].weight.data.zero_()
89
- self.mm_projector[-1].bias.data.copy_(embed_tokens_weight.mean(dim=0))
90
-
91
- if pretrain_mm_mlp_adapter is not None:
92
- def get_w(weights, keyword):
93
- return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
94
-
95
- mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
96
- self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
97
-
98
- # also load additional learnable parameters during feature alignment
99
- checkpoint_folder = os.path.dirname(pretrain_mm_mlp_adapter)
100
- ckpts = glob(f"{checkpoint_folder}/checkpoint-*", recursive = False)
101
- if len(ckpts) > 0:
102
- vision_module_weights = torch.load(f"{ckpts[-1]}/mm_projector.bin", map_location='cpu')
103
- model_dict = get_w(vision_module_weights, 'vision_tower')
104
- print(f"Loading vision module weights from {ckpts[-1]}/mm_projector.bin")
105
- # print keys in model_dict
106
- print(f"Loaded keys: {model_dict.keys()}")
107
- self.vision_tower.load_state_dict(model_dict, strict=False)
108
-
109
- class LlavaMetaForCausalLM(ABC):
110
-
111
- @abstractmethod
112
- def get_model(self):
113
- pass
114
-
115
- def get_vision_tower(self):
116
- return self.get_model().get_vision_tower()
117
-
118
- def encode_images(self, images):
119
- image_features = self.get_model().get_vision_tower()(images)
120
- image_features = self.get_model().mm_projector(image_features)
121
- return image_features
122
-
123
- def prepare_inputs_labels_for_multimodal(
124
- self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
125
- ):
126
- vision_tower = self.get_vision_tower()
127
- if vision_tower is None or images is None or input_ids.shape[1] == 1:
128
- if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
129
- target_shape = past_key_values[-1][-1].shape[-2] + 1
130
- attention_mask = torch.cat((attention_mask, torch.ones(
131
- (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
132
- dtype=attention_mask.dtype,
133
- device=attention_mask.device
134
- )), dim=1)
135
- position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
136
- return input_ids, position_ids, attention_mask, past_key_values, None, labels
137
-
138
- if type(images) is list or images.ndim == 5:
139
- concat_images = torch.cat([image for image in images], dim=0)
140
- image_features = self.encode_images(concat_images)
141
- split_sizes = [image.shape[0] for image in images]
142
- image_features = torch.split(image_features, split_sizes, dim=0)
143
- image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
144
- else:
145
- image_features = self.encode_images(images).to(self.device)
146
-
147
- # TODO: image start / end is not implemented here to support pretraining.
148
- if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
149
- raise NotImplementedError
150
-
151
- # Let's just add dummy tensors if they do not exist,
152
- # it is a headache to deal with None all the time.
153
- # But it is not ideal, and if you have a better idea,
154
- # please open an issue / submit a PR, thanks.
155
- _labels = labels
156
- _position_ids = position_ids
157
- _attention_mask = attention_mask
158
-
159
- if attention_mask is None:
160
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
161
- else:
162
- attention_mask = attention_mask.bool()
163
- if position_ids is None:
164
- position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
165
-
166
- if labels is None:
167
- labels = torch.full_like(input_ids, IGNORE_INDEX)
168
-
169
- input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
170
- labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
171
-
172
- new_input_embeds = []
173
- new_labels = []
174
- cur_image_idx = 0
175
- for batch_idx, cur_input_ids in enumerate(input_ids):
176
- num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
177
- if num_images == 0:
178
- cur_image_features = image_features[cur_image_idx]
179
- cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
180
- cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
181
- new_input_embeds.append(cur_input_embeds)
182
- new_labels.append(labels[batch_idx])
183
- cur_image_idx += 1
184
- continue
185
-
186
- image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
187
- cur_input_ids_noim = []
188
- cur_labels = labels[batch_idx]
189
- cur_labels_noim = []
190
- for i in range(len(image_token_indices) - 1):
191
- cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
192
- cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
193
-
194
- split_sizes = [x.shape[0] for x in cur_labels_noim]
195
- cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
196
- cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
197
- cur_new_input_embeds = []
198
- cur_new_labels = []
199
-
200
- for i in range(num_images + 1):
201
- cur_new_input_embeds.append(cur_input_embeds_no_im[i])
202
- cur_new_labels.append(cur_labels_noim[i])
203
- if i < num_images:
204
- cur_image_features = image_features[cur_image_idx]
205
- cur_image_idx += 1
206
- cur_new_input_embeds.append(cur_image_features)
207
- cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
208
-
209
- cur_new_input_embeds = torch.cat(cur_new_input_embeds)
210
- cur_new_labels = torch.cat(cur_new_labels)
211
-
212
- new_input_embeds.append(cur_new_input_embeds)
213
- new_labels.append(cur_new_labels)
214
-
215
- # Truncate sequences to max length as image embeddings can make the sequence longer
216
- tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
217
- if tokenizer_model_max_length is not None:
218
- new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
219
- new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
220
-
221
- # Combine them
222
- max_len = max(x.shape[0] for x in new_input_embeds)
223
- batch_size = len(new_input_embeds)
224
-
225
- new_input_embeds_padded = []
226
- new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
227
- attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
228
- position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
229
-
230
- for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
231
- cur_len = cur_new_embed.shape[0]
232
- if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
233
- new_input_embeds_padded.append(torch.cat((
234
- torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
235
- cur_new_embed
236
- ), dim=0))
237
- if cur_len > 0:
238
- new_labels_padded[i, -cur_len:] = cur_new_labels
239
- attention_mask[i, -cur_len:] = True
240
- position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
241
- else:
242
- new_input_embeds_padded.append(torch.cat((
243
- cur_new_embed,
244
- torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
245
- ), dim=0))
246
- if cur_len > 0:
247
- new_labels_padded[i, :cur_len] = cur_new_labels
248
- attention_mask[i, :cur_len] = True
249
- position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
250
-
251
- new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
252
-
253
- if _labels is None:
254
- new_labels = None
255
- else:
256
- new_labels = new_labels_padded
257
-
258
- if _attention_mask is None:
259
- attention_mask = None
260
- else:
261
- attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
262
-
263
- if _position_ids is None:
264
- position_ids = None
265
- return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
266
-
267
- def initialize_vision_tokenizer(self, model_args, tokenizer):
268
- if model_args.mm_use_im_patch_token:
269
- tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
270
- self.resize_token_embeddings(len(tokenizer))
271
-
272
- if model_args.mm_use_im_start_end:
273
- num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
274
- self.resize_token_embeddings(len(tokenizer))
275
-
276
- if num_new_tokens > 0:
277
- input_embeddings = self.get_input_embeddings().weight.data
278
- output_embeddings = self.get_output_embeddings().weight.data
279
-
280
- input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
281
- dim=0, keepdim=True)
282
- output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
283
- dim=0, keepdim=True)
284
-
285
- input_embeddings[-num_new_tokens:] = input_embeddings_avg
286
- output_embeddings[-num_new_tokens:] = output_embeddings_avg
287
-
288
- if model_args.tune_mm_mlp_adapter:
289
- for p in self.get_input_embeddings().parameters():
290
- p.requires_grad = True
291
- for p in self.get_output_embeddings().parameters():
292
- p.requires_grad = False
293
-
294
- if model_args.pretrain_mm_mlp_adapter:
295
- mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
296
- embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
297
- assert num_new_tokens == 2
298
- if input_embeddings.shape == embed_tokens_weight.shape:
299
- input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
300
- elif embed_tokens_weight.shape[0] == num_new_tokens:
301
- input_embeddings[-num_new_tokens:] = embed_tokens_weight
302
- else:
303
- raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
304
- elif model_args.mm_use_im_patch_token:
305
- if model_args.tune_mm_mlp_adapter:
306
- for p in self.get_input_embeddings().parameters():
307
- p.requires_grad = False
308
- for p in self.get_output_embeddings().parameters():
309
- p.requires_grad = False
 
LLaVA-Med/llava/model/multimodal_encoder/builder.py DELETED
@@ -1,9 +0,0 @@
- import os
- from .clip_encoder import CLIPVisionTower
-
- def build_vision_tower(vision_tower_cfg, **kwargs):
-     vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
-     is_absolute_path_exists = os.path.exists(vision_tower)
-     if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
-         return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
-
 
LLaVA-Med/llava/model/multimodal_encoder/clip_encoder.py DELETED
@@ -1,78 +0,0 @@
- import torch
- import torch.nn as nn
-
- from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
-
-
- class CLIPVisionTower(nn.Module):
-     def __init__(self, vision_tower, args, delay_load=False):
-         super().__init__()
-
-         self.is_loaded = False
-
-         self.vision_tower_name = vision_tower
-         self.select_layer = args.mm_vision_select_layer
-         self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
-
-         if not delay_load:
-             self.load_model()
-         else:
-             self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
-
-     def load_model(self):
-         self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
-         self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
-         self.vision_tower.requires_grad_(False)
-
-         self.is_loaded = True
-
-     def feature_select(self, image_forward_outs):
-         image_features = image_forward_outs.hidden_states[self.select_layer]
-         if self.select_feature == 'patch':
-             image_features = image_features[:, 1:]
-         elif self.select_feature == 'cls_patch':
-             image_features = image_features
-         else:
-             raise ValueError(f'Unexpected select feature: {self.select_feature}')
-         return image_features
-
-     @torch.no_grad()
-     def forward(self, images):
-         if type(images) is list:
-             image_features = []
-             for image in images:
-                 image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
-                 image_feature = self.feature_select(image_forward_out).to(image.dtype)
-                 image_features.append(image_feature)
-         else:
-             image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
-             image_features = self.feature_select(image_forward_outs).to(images.dtype)
-
-         return image_features
-
-     @property
-     def dummy_feature(self):
-         return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
-
-     @property
-     def dtype(self):
-         return self.vision_tower.dtype
-
-     @property
-     def device(self):
-         return self.vision_tower.device
-
-     @property
-     def config(self):
-         if self.is_loaded:
-             return self.vision_tower.config
-         else:
-             return self.cfg_only
-
-     @property
-     def hidden_size(self):
-         return self.config.hidden_size
-
-     @property
-     def num_patches(self):
-         return (self.config.image_size // self.config.patch_size) ** 2
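Stripped of the wrapper, the feature selection above takes one hidden layer of a CLIP vision model and drops the CLS token. A standalone sketch using transformers directly (the checkpoint name is an assumption, not taken from this diff):

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

name = "openai/clip-vit-large-patch14-336"   # assumed checkpoint
processor = CLIPImageProcessor.from_pretrained(name)
vision = CLIPVisionModel.from_pretrained(name).eval()

image = Image.new("RGB", (336, 336))         # stand-in for a real image
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
with torch.no_grad():
    outputs = vision(pixel_values, output_hidden_states=True)
patch_features = outputs.hidden_states[-2][:, 1:]   # select_layer=-2, 'patch' mode
print(patch_features.shape)                  # (1, 576, 1024) for ViT-L/14 at 336px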
LLaVA-Med/llava/model/multimodal_projector/builder.py DELETED
@@ -1,51 +0,0 @@
- import torch
- import torch.nn as nn
- import re
-
-
- class IdentityMap(nn.Module):
-     def __init__(self):
-         super().__init__()
-
-     def forward(self, x, *args, **kwargs):
-         return x
-
-     @property
-     def config(self):
-         return {"mm_projector_type": 'identity'}
-
-
- class SimpleResBlock(nn.Module):
-     def __init__(self, channels):
-         super().__init__()
-         self.pre_norm = nn.LayerNorm(channels)
-
-         self.proj = nn.Sequential(
-             nn.Linear(channels, channels),
-             nn.GELU(),
-             nn.Linear(channels, channels)
-         )
-     def forward(self, x):
-         x = self.pre_norm(x)
-         return x + self.proj(x)
-
-
- def build_vision_projector(config, delay_load=False, **kwargs):
-     projector_type = getattr(config, 'mm_projector_type', 'linear')
-
-     if projector_type == 'linear':
-         return nn.Linear(config.mm_hidden_size, config.hidden_size)
-
-     mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
-     if mlp_gelu_match:
-         mlp_depth = int(mlp_gelu_match.group(1))
-         modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
-         for _ in range(1, mlp_depth):
-             modules.append(nn.GELU())
-             modules.append(nn.Linear(config.hidden_size, config.hidden_size))
-         return nn.Sequential(*modules)
-
-     if projector_type == 'identity':
-         return IdentityMap()
-
-     raise ValueError(f'Unknown projector type: {projector_type}')
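LLaVA-style configs usually set mm_projector_type to mlp2x_gelu, which the factory above expands into a two-layer MLP. A sketch with assumed widths (1024-d vision features projected into a 4096-d language-model embedding space):

import torch
import torch.nn as nn

mm_hidden_size, hidden_size = 1024, 4096     # assumed vision / LM widths
projector = nn.Sequential(
    nn.Linear(mm_hidden_size, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, hidden_size),
)
image_tokens = projector(torch.randn(1, 576, mm_hidden_size))   # -> (1, 576, 4096)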
LLaVA-Med/llava/serve/__init__.py DELETED
File without changes
LLaVA-Med/llava/serve/cli.py DELETED
@@ -1,125 +0,0 @@
1
- import argparse
2
- import torch
3
-
4
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5
- from llava.conversation import conv_templates, SeparatorStyle
6
- from llava.model.builder import load_pretrained_model
7
- from llava.utils import disable_torch_init
8
- from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
9
-
10
- from PIL import Image
11
-
12
- import requests
13
- from PIL import Image
14
- from io import BytesIO
15
- from transformers import TextStreamer
16
-
17
-
18
- def load_image(image_file):
19
- if image_file.startswith('http://') or image_file.startswith('https://'):
20
- response = requests.get(image_file)
21
- image = Image.open(BytesIO(response.content)).convert('RGB')
22
- else:
23
- image = Image.open(image_file).convert('RGB')
24
- return image
25
-
26
-
27
- def main(args):
28
- # Model
29
- disable_torch_init()
30
-
31
- model_name = get_model_name_from_path(args.model_path)
32
- tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
33
-
34
- if 'llama-2' in model_name.lower():
35
- conv_mode = "llava_llama_2"
36
- elif "v1" in model_name.lower():
37
- conv_mode = "llava_v1"
38
- elif "mpt" in model_name.lower():
39
- conv_mode = "mpt"
40
- else:
41
- conv_mode = "llava_v0"
42
- conv_mode = "mistral_instruct"
43
-
44
- if args.conv_mode is not None and conv_mode != args.conv_mode:
45
- print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
46
- else:
47
- args.conv_mode = conv_mode
48
-
49
- conv = conv_templates[args.conv_mode].copy()
50
- if "mpt" in model_name.lower():
51
- roles = ('user', 'assistant')
52
- else:
53
- roles = conv.roles
54
-
55
- image = load_image(args.image_file)
56
- # Similar operation in model_worker.py
57
- image_tensor = process_images([image], image_processor, model.config)
58
- if type(image_tensor) is list:
59
- image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
60
- else:
61
- image_tensor = image_tensor.to(model.device, dtype=torch.float16)
62
-
63
- while True:
64
- try:
65
- inp = input(f"{roles[0]}: ")
66
- except EOFError:
67
- inp = ""
68
- if not inp:
69
- print("exit...")
70
- break
71
-
72
- print(f"{roles[1]}: ", end="")
73
-
74
- if image is not None:
75
- # first message
76
- if model.config.mm_use_im_start_end:
77
- inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
78
- else:
79
- inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
80
- conv.append_message(conv.roles[0], inp)
81
- image = None
82
- else:
83
- # later messages
84
- conv.append_message(conv.roles[0], inp)
85
- conv.append_message(conv.roles[1], None)
86
- prompt = conv.get_prompt()
87
-
88
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
89
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
90
- keywords = [stop_str]
91
- stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
92
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
93
-
94
- with torch.inference_mode():
95
- output_ids = model.generate(
96
- input_ids,
97
- images=image_tensor,
98
- do_sample=True if args.temperature > 0 else False,
99
- temperature=args.temperature,
100
- max_new_tokens=args.max_new_tokens,
101
- streamer=streamer,
102
- use_cache=True,
103
- stopping_criteria=[stopping_criteria])
104
-
105
- outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
106
- conv.messages[-1][-1] = outputs
107
-
108
- if args.debug:
109
- print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
110
-
111
-
112
- if __name__ == "__main__":
113
- parser = argparse.ArgumentParser()
114
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
115
- parser.add_argument("--model-base", type=str, default=None)
116
- parser.add_argument("--image-file", type=str, required=True)
117
- parser.add_argument("--device", type=str, default="cuda")
118
- parser.add_argument("--conv-mode", type=str, default=None)
119
- parser.add_argument("--temperature", type=float, default=0.2)
120
- parser.add_argument("--max-new-tokens", type=int, default=512)
121
- parser.add_argument("--load-8bit", action="store_true")
122
- parser.add_argument("--load-4bit", action="store_true")
123
- parser.add_argument("--debug", action="store_true")
124
- args = parser.parse_args()
125
- main(args)
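Going by the argparse block above, a single-image chat session would be started along these lines (the model path is a placeholder; the example image is one of the files bundled under llava/serve/examples/):

    python -m llava.serve.cli --model-path <path-or-hub-id-of-a-llava-med-checkpoint> --image-file llava/serve/examples/synpic42202.jpg --temperature 0.2 --max-new-tokens 512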
LLaVA-Med/llava/serve/controller.py DELETED
@@ -1,298 +0,0 @@
1
- """
2
- A controller manages distributed workers.
3
- It sends worker addresses to clients.
4
- """
5
- import argparse
6
- import asyncio
7
- import dataclasses
8
- from enum import Enum, auto
9
- import json
10
- import logging
11
- import time
12
- from typing import List, Union
13
- import threading
14
-
15
- from fastapi import FastAPI, Request
16
- from fastapi.responses import StreamingResponse
17
- import numpy as np
18
- import requests
19
- import uvicorn
20
-
21
- from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
- from llava.utils import build_logger, server_error_msg
23
-
24
-
25
- logger = build_logger("controller", "controller.log")
26
-
27
-
28
- class DispatchMethod(Enum):
29
- LOTTERY = auto()
30
- SHORTEST_QUEUE = auto()
31
-
32
- @classmethod
33
- def from_str(cls, name):
34
- if name == "lottery":
35
- return cls.LOTTERY
36
- elif name == "shortest_queue":
37
- return cls.SHORTEST_QUEUE
38
- else:
39
- raise ValueError(f"Invalid dispatch method")
40
-
41
-
42
- @dataclasses.dataclass
43
- class WorkerInfo:
44
- model_names: List[str]
45
- speed: int
46
- queue_length: int
47
- check_heart_beat: bool
48
- last_heart_beat: str
49
-
50
-
51
- def heart_beat_controller(controller):
52
- while True:
53
- time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54
- controller.remove_stable_workers_by_expiration()
55
-
56
-
57
- class Controller:
58
- def __init__(self, dispatch_method: str):
59
- # Dict[str -> WorkerInfo]
60
- self.worker_info = {}
61
- self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62
-
63
- self.heart_beat_thread = threading.Thread(
64
- target=heart_beat_controller, args=(self,))
65
- self.heart_beat_thread.start()
66
-
67
- logger.info("Init controller")
68
-
69
- def register_worker(self, worker_name: str, check_heart_beat: bool,
70
- worker_status: dict):
71
- if worker_name not in self.worker_info:
72
- logger.info(f"Register a new worker: {worker_name}")
73
- else:
74
- logger.info(f"Register an existing worker: {worker_name}")
75
-
76
- if not worker_status:
77
- worker_status = self.get_worker_status(worker_name)
78
- if not worker_status:
79
- return False
80
-
81
- self.worker_info[worker_name] = WorkerInfo(
82
- worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
83
- check_heart_beat, time.time())
84
-
85
- logger.info(f"Register done: {worker_name}, {worker_status}")
86
- return True
87
-
88
- def get_worker_status(self, worker_name: str):
89
- try:
90
- r = requests.post(worker_name + "/worker_get_status", timeout=5)
91
- except requests.exceptions.RequestException as e:
92
- logger.error(f"Get status fails: {worker_name}, {e}")
93
- return None
94
-
95
- if r.status_code != 200:
96
- logger.error(f"Get status fails: {worker_name}, {r}")
97
- return None
98
-
99
- return r.json()
100
-
101
- def remove_worker(self, worker_name: str):
102
- del self.worker_info[worker_name]
103
-
104
- def refresh_all_workers(self):
105
- old_info = dict(self.worker_info)
106
- self.worker_info = {}
107
-
108
- for w_name, w_info in old_info.items():
109
- if not self.register_worker(w_name, w_info.check_heart_beat, None):
110
- logger.info(f"Remove stale worker: {w_name}")
111
-
112
- def list_models(self):
113
- model_names = set()
114
-
115
- for w_name, w_info in self.worker_info.items():
116
- model_names.update(w_info.model_names)
117
-
118
- return list(model_names)
119
-
120
- def get_worker_address(self, model_name: str):
121
- if self.dispatch_method == DispatchMethod.LOTTERY:
122
- worker_names = []
123
- worker_speeds = []
124
- for w_name, w_info in self.worker_info.items():
125
- if model_name in w_info.model_names:
126
- worker_names.append(w_name)
127
- worker_speeds.append(w_info.speed)
128
- worker_speeds = np.array(worker_speeds, dtype=np.float32)
129
- norm = np.sum(worker_speeds)
130
- if norm < 1e-4:
131
- return ""
132
- worker_speeds = worker_speeds / norm
133
- if True: # Directly return address
134
- pt = np.random.choice(np.arange(len(worker_names)),
135
- p=worker_speeds)
136
- worker_name = worker_names[pt]
137
- return worker_name
138
-
139
- # Check status before returning
140
- while True:
141
- pt = np.random.choice(np.arange(len(worker_names)),
142
- p=worker_speeds)
143
- worker_name = worker_names[pt]
144
-
145
- if self.get_worker_status(worker_name):
146
- break
147
- else:
148
- self.remove_worker(worker_name)
149
- worker_speeds[pt] = 0
150
- norm = np.sum(worker_speeds)
151
- if norm < 1e-4:
152
- return ""
153
- worker_speeds = worker_speeds / norm
154
- continue
155
- return worker_name
156
- elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
157
- worker_names = []
158
- worker_qlen = []
159
- for w_name, w_info in self.worker_info.items():
160
- if model_name in w_info.model_names:
161
- worker_names.append(w_name)
162
- worker_qlen.append(w_info.queue_length / w_info.speed)
163
- if len(worker_names) == 0:
164
- return ""
165
- min_index = np.argmin(worker_qlen)
166
- w_name = worker_names[min_index]
167
- self.worker_info[w_name].queue_length += 1
168
- logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
169
- return w_name
170
- else:
171
- raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
172
-
173
- def receive_heart_beat(self, worker_name: str, queue_length: int):
174
- if worker_name not in self.worker_info:
175
- logger.info(f"Receive unknown heart beat. {worker_name}")
176
- return False
177
-
178
- self.worker_info[worker_name].queue_length = queue_length
179
- self.worker_info[worker_name].last_heart_beat = time.time()
180
- logger.info(f"Receive heart beat. {worker_name}")
181
- return True
182
-
183
- def remove_stable_workers_by_expiration(self):
184
- expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
185
- to_delete = []
186
- for worker_name, w_info in self.worker_info.items():
187
- if w_info.check_heart_beat and w_info.last_heart_beat < expire:
188
- to_delete.append(worker_name)
189
-
190
- for worker_name in to_delete:
191
- self.remove_worker(worker_name)
192
-
193
- def worker_api_generate_stream(self, params):
194
- worker_addr = self.get_worker_address(params["model"])
195
- if not worker_addr:
196
- logger.info(f"no worker: {params['model']}")
197
- ret = {
198
- "text": server_error_msg,
199
- "error_code": 2,
200
- }
201
- yield json.dumps(ret).encode() + b"\0"
202
-
203
- try:
204
- response = requests.post(worker_addr + "/worker_generate_stream",
205
- json=params, stream=True, timeout=5)
206
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
207
- if chunk:
208
- yield chunk + b"\0"
209
- except requests.exceptions.RequestException as e:
210
- logger.info(f"worker timeout: {worker_addr}")
211
- ret = {
212
- "text": server_error_msg,
213
- "error_code": 3,
214
- }
215
- yield json.dumps(ret).encode() + b"\0"
216
-
217
-
218
- # Let the controller act as a worker to achieve hierarchical
219
- # management. This can be used to connect isolated sub networks.
220
- def worker_api_get_status(self):
221
- model_names = set()
222
- speed = 0
223
- queue_length = 0
224
-
225
- for w_name in self.worker_info:
226
- worker_status = self.get_worker_status(w_name)
227
- if worker_status is not None:
228
- model_names.update(worker_status["model_names"])
229
- speed += worker_status["speed"]
230
- queue_length += worker_status["queue_length"]
231
-
232
- return {
233
- "model_names": list(model_names),
234
- "speed": speed,
235
- "queue_length": queue_length,
236
- }
237
-
238
-
239
- app = FastAPI()
240
-
241
-
242
- @app.post("/register_worker")
243
- async def register_worker(request: Request):
244
- data = await request.json()
245
- controller.register_worker(
246
- data["worker_name"], data["check_heart_beat"],
247
- data.get("worker_status", None))
248
-
249
-
250
- @app.post("/refresh_all_workers")
251
- async def refresh_all_workers():
252
- models = controller.refresh_all_workers()
253
-
254
-
255
- @app.post("/list_models")
256
- async def list_models():
257
- models = controller.list_models()
258
- return {"models": models}
259
-
260
-
261
- @app.post("/get_worker_address")
262
- async def get_worker_address(request: Request):
263
- data = await request.json()
264
- addr = controller.get_worker_address(data["model"])
265
- return {"address": addr}
266
-
267
-
268
- @app.post("/receive_heart_beat")
269
- async def receive_heart_beat(request: Request):
270
- data = await request.json()
271
- exist = controller.receive_heart_beat(
272
- data["worker_name"], data["queue_length"])
273
- return {"exist": exist}
274
-
275
-
276
- @app.post("/worker_generate_stream")
277
- async def worker_api_generate_stream(request: Request):
278
- params = await request.json()
279
- generator = controller.worker_api_generate_stream(params)
280
- return StreamingResponse(generator)
281
-
282
-
283
- @app.post("/worker_get_status")
284
- async def worker_api_get_status(request: Request):
285
- return controller.worker_api_get_status()
286
-
287
-
288
- if __name__ == "__main__":
289
- parser = argparse.ArgumentParser()
290
- parser.add_argument("--host", type=str, default="localhost")
291
- parser.add_argument("--port", type=int, default=21001)
292
- parser.add_argument("--dispatch-method", type=str, choices=[
293
- "lottery", "shortest_queue"], default="shortest_queue")
294
- args = parser.parse_args()
295
- logger.info(f"args: {args}")
296
-
297
- controller = Controller(args.dispatch_method)
298
- uvicorn.run(app, host=args.host, port=args.port, log_level="info")
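Once a controller is running on its default port (21001, per the argparse defaults above), its endpoints can be queried directly. A small client-side sketch, assuming at least one worker has already registered:

import requests

base = "http://localhost:21001"              # controller default host/port
models = requests.post(base + "/list_models").json()["models"]
worker = requests.post(base + "/get_worker_address",
                       json={"model": models[0]}).json()["address"]
print(models, worker)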
LLaVA-Med/llava/serve/examples/bio_patch.png DELETED
Binary file (214 kB)
 
LLaVA-Med/llava/serve/examples/extreme_ironing.jpg DELETED
Binary file (62.6 kB)
 
LLaVA-Med/llava/serve/examples/med_img_1.png DELETED
Binary file (319 kB)
 
LLaVA-Med/llava/serve/examples/synpic32933.jpg DELETED
Binary file (69.8 kB)
 
LLaVA-Med/llava/serve/examples/synpic42202.jpg DELETED
Binary file (86.1 kB)
 
LLaVA-Med/llava/serve/examples/waterview.jpg DELETED
Binary file (95.5 kB)
 
LLaVA-Med/llava/serve/examples/xy_chromosome.jpg DELETED
Binary file (53.9 kB)
 
LLaVA-Med/llava/serve/gradio_web_server.py DELETED
@@ -1,477 +0,0 @@
1
- import argparse
2
- import datetime
3
- import json
4
- import os
5
- import time
6
-
7
- import gradio as gr
8
- import requests
9
-
10
- from llava.conversation import (default_conversation, conv_templates,
11
- SeparatorStyle)
12
- from llava.constants import LOGDIR
13
- from llava.utils import (build_logger, server_error_msg,
14
- violates_moderation, moderation_msg)
15
- import hashlib
16
-
17
-
18
- logger = build_logger("gradio_web_server", "gradio_web_server.log")
19
-
20
- headers = {"User-Agent": "LLaVA-Med Client"}
21
-
22
- no_change_btn = gr.Button.update()
23
- enable_btn = gr.Button.update(interactive=True)
24
- disable_btn = gr.Button.update(interactive=False)
25
-
26
- priority = {
27
- "vicuna-13b": "aaaaaaa",
28
- "koala-13b": "aaaaaab",
29
- }
30
-
31
-
32
- def get_conv_log_filename():
33
- t = datetime.datetime.now()
34
- name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
35
- return name
36
-
37
-
38
- def get_model_list():
39
- ret = requests.post(args.controller_url + "/refresh_all_workers")
40
- assert ret.status_code == 200
41
- ret = requests.post(args.controller_url + "/list_models")
42
- models = ret.json()["models"]
43
- models.sort(key=lambda x: priority.get(x, x))
44
- logger.info(f"Models: {models}")
45
- return models
46
-
47
-
48
- get_window_url_params = """
49
- function() {
50
- const params = new URLSearchParams(window.location.search);
51
- url_params = Object.fromEntries(params);
52
- console.log(url_params);
53
- return url_params;
54
- }
55
- """
56
-
57
-
58
- def load_demo(url_params, request: gr.Request):
59
- logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
60
-
61
- dropdown_update = gr.Dropdown.update(visible=True)
62
- if "model" in url_params:
63
- model = url_params["model"]
64
- if model in models:
65
- dropdown_update = gr.Dropdown.update(
66
- value=model, visible=True)
67
-
68
- state = default_conversation.copy()
69
- return state, dropdown_update
70
-
71
-
72
- def load_demo_refresh_model_list(request: gr.Request):
73
- logger.info(f"load_demo. ip: {request.client.host}")
74
- models = get_model_list()
75
- state = default_conversation.copy()
76
- dropdown_update = gr.Dropdown.update(
77
- choices=models,
78
- value=models[0] if len(models) > 0 else ""
79
- )
80
- return state, dropdown_update
81
-
82
-
83
- def vote_last_response(state, vote_type, model_selector, request: gr.Request):
84
- with open(get_conv_log_filename(), "a") as fout:
85
- data = {
86
- "tstamp": round(time.time(), 4),
87
- "type": vote_type,
88
- "model": model_selector,
89
- "state": state.dict(),
90
- "ip": request.client.host,
91
- }
92
- fout.write(json.dumps(data) + "\n")
93
-
94
-
95
- def upvote_last_response(state, model_selector, request: gr.Request):
96
- logger.info(f"upvote. ip: {request.client.host}")
97
- vote_last_response(state, "upvote", model_selector, request)
98
- return ("",) + (disable_btn,) * 3
99
-
100
-
101
- def downvote_last_response(state, model_selector, request: gr.Request):
102
- logger.info(f"downvote. ip: {request.client.host}")
103
- vote_last_response(state, "downvote", model_selector, request)
104
- return ("",) + (disable_btn,) * 3
105
-
106
-
107
- def flag_last_response(state, model_selector, request: gr.Request):
108
- logger.info(f"flag. ip: {request.client.host}")
109
- vote_last_response(state, "flag", model_selector, request)
110
- return ("",) + (disable_btn,) * 3
111
-
112
-
113
- def regenerate(state, image_process_mode, request: gr.Request):
114
- logger.info(f"regenerate. ip: {request.client.host}")
115
- state.messages[-1][-1] = None
116
- prev_human_msg = state.messages[-2]
117
- if type(prev_human_msg[1]) in (tuple, list):
118
- prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
119
- state.skip_next = False
120
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
121
-
122
-
123
- def clear_history(request: gr.Request):
124
- logger.info(f"clear_history. ip: {request.client.host}")
125
- state = default_conversation.copy()
126
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
127
-
128
-
129
- def add_text(state, text, image, image_process_mode, request: gr.Request):
130
- logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
131
- if len(text) <= 0 and image is None:
132
- state.skip_next = True
133
- return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
134
- if args.moderate:
135
- flagged = violates_moderation(text)
136
- if flagged:
137
- state.skip_next = True
138
- return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
139
- no_change_btn,) * 5
140
-
141
- text = text[:1536] # Hard cut-off
142
- if image is not None:
143
- text = text[:1200] # Hard cut-off for images
144
- if '<image>' not in text:
145
- # text = '<Image><image></Image>' + text
146
- text = text + '\n<image>'
147
- text = (text, image, image_process_mode)
148
- if len(state.get_images(return_pil=True)) > 0:
149
- state = default_conversation.copy()
150
- state.append_message(state.roles[0], text)
151
- state.append_message(state.roles[1], None)
152
- state.skip_next = False
153
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
154
-
155
-
156
- def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
157
- logger.info(f"http_bot. ip: {request.client.host}")
158
- start_tstamp = time.time()
159
- model_name = model_selector
160
-
161
- if state.skip_next:
162
- # This generate call is skipped due to invalid inputs
163
- yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
164
- return
165
-
166
- if len(state.messages) == state.offset + 2:
167
- # First round of conversation
168
- if "llava" in model_name.lower():
169
- if 'llama-2' in model_name.lower():
170
- template_name = "llava_llama_2"
171
- elif "v1" in model_name.lower():
172
- if 'mmtag' in model_name.lower():
173
- template_name = "v1_mmtag"
174
- elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
175
- template_name = "v1_mmtag"
176
- else:
177
- template_name = "llava_v1"
178
- elif "mpt" in model_name.lower():
179
- template_name = "mpt"
180
- else:
181
- if 'mmtag' in model_name.lower():
182
- template_name = "v0_mmtag"
183
- elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
184
- template_name = "v0_mmtag"
185
- else:
186
- template_name = "llava_v0"
187
- elif "mpt" in model_name:
188
- template_name = "mpt_text"
189
- elif "llama-2" in model_name:
190
- template_name = "llama_2"
191
- else:
192
- template_name = "vicuna_v1"
193
- template_name = "mistral_instruct" # FIXME: overwrite
194
- new_state = conv_templates[template_name].copy()
195
- new_state.append_message(new_state.roles[0], state.messages[-2][1])
196
- new_state.append_message(new_state.roles[1], None)
197
- state = new_state
198
-
199
- # Query worker address
200
- controller_url = args.controller_url
201
- ret = requests.post(controller_url + "/get_worker_address",
202
- json={"model": model_name})
203
- worker_addr = ret.json()["address"]
204
- logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
205
-
206
- # No available worker
207
- if worker_addr == "":
208
- state.messages[-1][-1] = server_error_msg
209
- yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
210
- return
211
-
212
- # Construct prompt
213
- prompt = state.get_prompt()
214
-
215
- all_images = state.get_images(return_pil=True)
216
- all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
217
- for image, hash in zip(all_images, all_image_hash):
218
- t = datetime.datetime.now()
219
- filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
220
- if not os.path.isfile(filename):
221
- os.makedirs(os.path.dirname(filename), exist_ok=True)
222
- image.save(filename)
223
-
224
- # Make requests
225
- pload = {
226
- "model": model_name,
227
- "prompt": prompt,
228
- "temperature": float(temperature),
229
- "top_p": float(top_p),
230
- "max_new_tokens": min(int(max_new_tokens), 1536),
231
- "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
232
- "images": f'List of {len(state.get_images())} images: {all_image_hash}',
233
- }
234
- logger.info(f"==== request ====\n{pload}")
235
-
236
- pload['images'] = state.get_images()
237
-
238
- state.messages[-1][-1] = "▌"
239
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
240
-
241
- try:
242
- # Stream output
243
- response = requests.post(worker_addr + "/worker_generate_stream",
244
- headers=headers, json=pload, stream=True, timeout=10)
245
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
246
- if chunk:
247
- data = json.loads(chunk.decode())
248
- if data["error_code"] == 0:
249
- output = data["text"][len(prompt):].strip()
250
- state.messages[-1][-1] = output + "▌"
251
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
252
- else:
253
- output = data["text"] + f" (error_code: {data['error_code']})"
254
- state.messages[-1][-1] = output
255
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
256
- return
257
- time.sleep(0.03)
258
- except requests.exceptions.RequestException as e:
259
- state.messages[-1][-1] = server_error_msg
260
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
261
- return
262
-
263
- state.messages[-1][-1] = state.messages[-1][-1][:-1]
264
- yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
265
-
266
- finish_tstamp = time.time()
267
- logger.info(f"{output}")
268
-
269
- with open(get_conv_log_filename(), "a") as fout:
270
- data = {
271
- "tstamp": round(finish_tstamp, 4),
272
- "type": "chat",
273
- "model": model_name,
274
- "start": round(start_tstamp, 4),
275
- "finish": round(finish_tstamp, 4),
276
- "state": state.dict(),
277
- "images": all_image_hash,
278
- "ip": request.client.host,
279
- }
280
- fout.write(json.dumps(data) + "\n")
281
-
282
-
283
- title_markdown = ("""
284
- # 🌋 LLaVA-Med: Large Language and Vision Assistant for Medical Research
285
- [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0)
286
- """)
287
-
288
- tos_markdown = ("""
289
- ### Terms of use
290
- By using this service, users are required to agree to the following terms:
291
- The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
292
- Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
293
- For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
294
- """)
295
-
296
-
297
- learn_more_markdown = ("""
298
- ### License
299
- The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
300
- """)
301
-
302
- block_css = """
303
-
304
- #buttons button {
305
- min-width: min(120px,100%);
306
- }
307
-
308
- """
309
-
310
- def build_demo(embed_mode):
311
- textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
312
- with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
313
- state = gr.State()
314
-
315
- if not embed_mode:
316
- gr.Markdown(title_markdown)
317
-
318
- with gr.Row():
319
- with gr.Column(scale=3):
320
- with gr.Row(elem_id="model_selector_row"):
321
- model_selector = gr.Dropdown(
322
- choices=models,
323
- value=models[0] if len(models) > 0 else "",
324
- interactive=True,
325
- show_label=False,
326
- container=False)
327
-
328
- imagebox = gr.Image(type="pil")
329
- image_process_mode = gr.Radio(
330
- ["Crop", "Resize", "Pad", "Default"],
331
- value="Default",
332
- label="Preprocess for non-square image", visible=False)
333
-
334
- cur_dir = os.path.dirname(os.path.abspath(__file__))
335
- gr.Examples(examples=[
336
- [f"{cur_dir}/examples/bio_patch.png", "What is this image about?"],
337
- [f"{cur_dir}/examples/med_img_1.png", "Can you describe the image in details?"],
338
- [f"{cur_dir}/examples/xy_chromosome.jpg", "Can you describe the image in details?"],
339
- [f"{cur_dir}/examples/synpic42202.jpg", "Is there evidence of an aortic aneurysm? Please choose from the following two options: [yes, no]?"], # answer" yes
340
- [f"{cur_dir}/examples/synpic32933.jpg", "What is the abnormality by the right hemidiaphragm?"], # answer: free air
341
- [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
342
- [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
343
- ], inputs=[imagebox, textbox])
344
-
345
- with gr.Accordion("Parameters", open=False) as parameter_row:
346
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
347
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
348
- max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
349
-
350
- with gr.Column(scale=8):
351
- chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA-Med Chatbot", height=550)
352
- with gr.Row():
353
- with gr.Column(scale=8):
354
- textbox.render()
355
- with gr.Column(scale=1, min_width=50):
356
- submit_btn = gr.Button(value="Send", variant="primary")
357
- with gr.Row(elem_id="buttons") as button_row:
358
- upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
359
- downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
360
- flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
361
- #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
362
- regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
363
- clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
364
-
365
- if not embed_mode:
366
- gr.Markdown(tos_markdown)
367
- gr.Markdown(learn_more_markdown)
368
- url_params = gr.JSON(visible=False)
369
-
370
- # Register listeners
371
- btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
372
- upvote_btn.click(
373
- upvote_last_response,
374
- [state, model_selector],
375
- [textbox, upvote_btn, downvote_btn, flag_btn],
376
- queue=False
377
- )
378
- downvote_btn.click(
379
- downvote_last_response,
380
- [state, model_selector],
381
- [textbox, upvote_btn, downvote_btn, flag_btn],
382
- queue=False
383
- )
384
- flag_btn.click(
385
- flag_last_response,
386
- [state, model_selector],
387
- [textbox, upvote_btn, downvote_btn, flag_btn],
388
- queue=False
389
- )
390
-
391
- regenerate_btn.click(
392
- regenerate,
393
- [state, image_process_mode],
394
- [state, chatbot, textbox, imagebox] + btn_list,
395
- queue=False
396
- ).then(
397
- http_bot,
398
- [state, model_selector, temperature, top_p, max_output_tokens],
399
- [state, chatbot] + btn_list
400
- )
401
-
402
- clear_btn.click(
403
- clear_history,
404
- None,
405
- [state, chatbot, textbox, imagebox] + btn_list,
406
- queue=False
407
- )
408
-
409
- textbox.submit(
410
- add_text,
411
- [state, textbox, imagebox, image_process_mode],
412
- [state, chatbot, textbox, imagebox] + btn_list,
413
- queue=False
414
- ).then(
415
- http_bot,
416
- [state, model_selector, temperature, top_p, max_output_tokens],
417
- [state, chatbot] + btn_list
418
- )
419
-
420
- submit_btn.click(
421
- add_text,
422
- [state, textbox, imagebox, image_process_mode],
423
- [state, chatbot, textbox, imagebox] + btn_list,
424
- queue=False
425
- ).then(
426
- http_bot,
427
- [state, model_selector, temperature, top_p, max_output_tokens],
428
- [state, chatbot] + btn_list
429
- )
430
-
431
- if args.model_list_mode == "once":
432
- demo.load(
433
- load_demo,
434
- [url_params],
435
- [state, model_selector],
436
- _js=get_window_url_params,
437
- queue=False
438
- )
439
- elif args.model_list_mode == "reload":
440
- demo.load(
441
- load_demo_refresh_model_list,
442
- None,
443
- [state, model_selector],
444
- queue=False
445
- )
446
- else:
447
- raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
448
-
449
- return demo
450
-
451
-
452
- if __name__ == "__main__":
453
- parser = argparse.ArgumentParser()
454
- parser.add_argument("--host", type=str, default="0.0.0.0")
455
- parser.add_argument("--port", type=int)
456
- parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
457
- parser.add_argument("--concurrency-count", type=int, default=10)
458
- parser.add_argument("--model-list-mode", type=str, default="once",
459
- choices=["once", "reload"])
460
- parser.add_argument("--share", action="store_true")
461
- parser.add_argument("--moderate", action="store_true")
462
- parser.add_argument("--embed", action="store_true")
463
- args = parser.parse_args()
464
- logger.info(f"args: {args}")
465
-
466
- models = get_model_list()
467
-
468
- logger.info(args)
469
- demo = build_demo(args.embed)
470
- demo.queue(
471
- concurrency_count=args.concurrency_count,
472
- api_open=False
473
- ).launch(
474
- server_name=args.host,
475
- server_port=args.port,
476
- share=args.share
477
- )
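For completeness, the web UI above is pointed at a running controller using the flags it registers (the URL is simply the controller default):

    python -m llava.serve.gradio_web_server --controller-url http://localhost:21001 --model-list-mode reload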
LLaVA-Med/llava/serve/model_worker.py DELETED
@@ -1,285 +0,0 @@
1
- """
2
- A model worker executes the model.
3
- """
4
- import argparse
5
- import asyncio
6
- import json
7
- import time
8
- import threading
9
- import uuid
10
-
11
- from fastapi import FastAPI, Request, BackgroundTasks
12
- from fastapi.responses import StreamingResponse
13
- import requests
14
- import torch
15
- import uvicorn
16
- from functools import partial
17
-
18
- from llava.constants import WORKER_HEART_BEAT_INTERVAL
19
- from llava.utils import (build_logger, server_error_msg,
20
- pretty_print_semaphore)
21
- from llava.model.builder import load_pretrained_model
22
- from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
23
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
- from transformers import TextIteratorStreamer
25
- from threading import Thread
26
-
27
-
28
- GB = 1 << 30
29
-
30
- worker_id = str(uuid.uuid4())[:6]
31
- logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
32
- global_counter = 0
33
-
34
- model_semaphore = None
35
-
36
-
37
- def heart_beat_worker(controller):
38
-
39
- while True:
40
- time.sleep(WORKER_HEART_BEAT_INTERVAL)
41
- controller.send_heart_beat()
42
-
43
-
44
- class ModelWorker:
45
- def __init__(self, controller_addr, worker_addr,
46
- worker_id, no_register,
47
- model_path, model_base, model_name,
48
- load_8bit, load_4bit, device):
49
- self.controller_addr = controller_addr
50
- self.worker_addr = worker_addr
51
- self.worker_id = worker_id
52
- if model_path.endswith("/"):
53
- model_path = model_path[:-1]
54
- if model_name is None:
55
- model_paths = model_path.split("/")
56
- if model_paths[-1].startswith('checkpoint-'):
57
- self.model_name = model_paths[-2] + "_" + model_paths[-1]
58
- else:
59
- self.model_name = model_paths[-1]
60
- else:
61
- self.model_name = model_name
62
-
63
- self.device = device
64
- logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
65
- self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
66
- model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
67
- self.is_multimodal = 'llava' in self.model_name.lower()
68
-
69
- if not no_register:
70
- self.register_to_controller()
71
- self.heart_beat_thread = threading.Thread(
72
- target=heart_beat_worker, args=(self,))
73
- self.heart_beat_thread.start()
74
-
75
- def register_to_controller(self):
76
- logger.info("Register to controller")
77
-
78
- url = self.controller_addr + "/register_worker"
79
- data = {
80
- "worker_name": self.worker_addr,
81
- "check_heart_beat": True,
82
- "worker_status": self.get_status()
83
- }
84
- r = requests.post(url, json=data)
85
- assert r.status_code == 200
86
-
87
- def send_heart_beat(self):
88
- logger.info(f"Send heart beat. Models: {[self.model_name]}. "
89
- f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
90
- f"global_counter: {global_counter}")
91
-
92
- url = self.controller_addr + "/receive_heart_beat"
93
-
94
- while True:
95
- try:
96
- ret = requests.post(url, json={
97
- "worker_name": self.worker_addr,
98
- "queue_length": self.get_queue_length()}, timeout=5)
99
- exist = ret.json()["exist"]
100
- break
101
- except requests.exceptions.RequestException as e:
102
- logger.error(f"heart beat error: {e}")
103
- time.sleep(5)
104
-
105
- if not exist:
106
- self.register_to_controller()
107
-
108
- def get_queue_length(self):
109
- if model_semaphore is None:
110
- return 0
111
- else:
112
- return args.limit_model_concurrency - model_semaphore._value + (len(
113
- model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
114
-
115
- def get_status(self):
116
- return {
117
- "model_names": [self.model_name],
118
- "speed": 1,
119
- "queue_length": self.get_queue_length(),
120
- }
121
-
122
- @torch.inference_mode()
123
- def generate_stream(self, params):
124
- tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
125
-
126
- prompt = params["prompt"]
127
- ori_prompt = prompt
128
- images = params.get("images", None)
129
- num_image_tokens = 0
130
- if images is not None and len(images) > 0 and self.is_multimodal:
131
- if len(images) > 0:
132
- if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
133
- raise ValueError("Number of images does not match number of <image> tokens in prompt")
134
-
135
- images = [load_image_from_base64(image) for image in images]
136
- images = process_images(images, image_processor, model.config)
137
-
138
- if type(images) is list:
139
- images = [image.to(self.model.device, dtype=torch.float16) for image in images]
140
- else:
141
- images = images.to(self.model.device, dtype=torch.float16)
142
-
143
- replace_token = DEFAULT_IMAGE_TOKEN
144
- if getattr(self.model.config, 'mm_use_im_start_end', False):
145
- replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
146
- prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
147
-
148
- num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
149
- else:
150
- images = None
151
- image_args = {"images": images}
152
- else:
153
- images = None
154
- image_args = {}
155
-
156
- temperature = float(params.get("temperature", 1.0))
157
- top_p = float(params.get("top_p", 1.0))
158
- max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
159
- max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
160
- stop_str = params.get("stop", None)
161
- do_sample = True if temperature > 0.001 else False
162
-
163
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
164
- keywords = [stop_str]
165
- stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
166
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
167
-
168
- max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
169
-
170
- if max_new_tokens < 1:
171
- yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
172
- return
173
-
174
- thread = Thread(target=model.generate, kwargs=dict(
175
- inputs=input_ids,
176
- do_sample=do_sample,
177
- temperature=temperature,
178
- top_p=top_p,
179
- max_new_tokens=max_new_tokens,
180
- streamer=streamer,
181
- stopping_criteria=[stopping_criteria],
182
- use_cache=True,
183
- **image_args
184
- ))
185
- thread.start()
186
-
187
- generated_text = ori_prompt
188
- for new_text in streamer:
189
- generated_text += new_text
190
- if generated_text.endswith(stop_str):
191
- generated_text = generated_text[:-len(stop_str)]
192
- yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
193
-
194
- def generate_stream_gate(self, params):
195
- try:
196
- for x in self.generate_stream(params):
197
- yield x
198
- except ValueError as e:
199
- print("Caught ValueError:", e)
200
- ret = {
201
- "text": server_error_msg,
202
- "error_code": 1,
203
- }
204
- yield json.dumps(ret).encode() + b"\0"
205
- except torch.cuda.CudaError as e:
206
- print("Caught torch.cuda.CudaError:", e)
207
- ret = {
208
- "text": server_error_msg,
209
- "error_code": 1,
210
- }
211
- yield json.dumps(ret).encode() + b"\0"
212
- except Exception as e:
213
- print("Caught Unknown Error", e)
214
- ret = {
215
- "text": server_error_msg,
216
- "error_code": 1,
217
- }
218
- yield json.dumps(ret).encode() + b"\0"
219
-
220
-
221
- app = FastAPI()
222
-
223
-
224
- def release_model_semaphore(fn=None):
225
- model_semaphore.release()
226
- if fn is not None:
227
- fn()
228
-
229
-
230
- @app.post("/worker_generate_stream")
231
- async def generate_stream(request: Request):
232
- global model_semaphore, global_counter
233
- global_counter += 1
234
- params = await request.json()
235
-
236
- if model_semaphore is None:
237
- model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
238
- await model_semaphore.acquire()
239
- worker.send_heart_beat()
240
- generator = worker.generate_stream_gate(params)
241
- background_tasks = BackgroundTasks()
242
- background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
243
- return StreamingResponse(generator, background=background_tasks)
244
-
245
-
246
- @app.post("/worker_get_status")
247
- async def get_status(request: Request):
248
- return worker.get_status()
249
-
250
-
251
- if __name__ == "__main__":
252
- parser = argparse.ArgumentParser()
253
- parser.add_argument("--host", type=str, default="localhost")
254
- parser.add_argument("--port", type=int, default=21002)
255
- parser.add_argument("--worker-address", type=str,
256
- default="http://localhost:21002")
257
- parser.add_argument("--controller-address", type=str,
258
- default="http://localhost:21001")
259
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
260
- parser.add_argument("--model-base", type=str, default=None)
261
- parser.add_argument("--model-name", type=str)
262
- parser.add_argument("--device", type=str, default="cuda")
263
- parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
264
- parser.add_argument("--limit-model-concurrency", type=int, default=5)
265
- parser.add_argument("--stream-interval", type=int, default=1)
266
- parser.add_argument("--no-register", action="store_true")
267
- parser.add_argument("--load-8bit", action="store_true")
268
- parser.add_argument("--load-4bit", action="store_true")
269
- args = parser.parse_args()
270
- logger.info(f"args: {args}")
271
-
272
- if args.multi_modal:
273
- logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
274
-
275
- worker = ModelWorker(args.controller_address,
276
- args.worker_address,
277
- worker_id,
278
- args.no_register,
279
- args.model_path,
280
- args.model_base,
281
- args.model_name,
282
- args.load_8bit,
283
- args.load_4bit,
284
- args.device)
285
- uvicorn.run(app, host=args.host, port=args.port, log_level="info")
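A worker is typically launched last, pointed at both the controller and a checkpoint, using the flags registered above (the model path is a placeholder):

    python -m llava.serve.model_worker --host 0.0.0.0 --port 40000 --worker-address http://localhost:40000 --controller-address http://localhost:21001 --model-path <path-or-hub-id-of-a-llava-med-checkpoint>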
LLaVA-Med/llava/serve/register_worker.py DELETED
@@ -1,26 +0,0 @@
- """
- Manually register workers.
-
- Usage:
- python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
- """
-
- import argparse
-
- import requests
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--controller-address", type=str)
-     parser.add_argument("--worker-name", type=str)
-     parser.add_argument("--check-heart-beat", action="store_true")
-     args = parser.parse_args()
-
-     url = args.controller_address + "/register_worker"
-     data = {
-         "worker_name": args.worker_name,
-         "check_heart_beat": args.check_heart_beat,
-         "worker_status": None,
-     }
-     r = requests.post(url, json=data)
-     assert r.status_code == 200
LLaVA-Med/llava/serve/test_message.py DELETED
@@ -1,62 +0,0 @@
1
- import argparse
2
- import json
3
-
4
- import requests
5
-
6
- from llava.conversation import conv_templates
7
-
8
-
9
- def main():
10
- if args.worker_address:
11
- worker_addr = args.worker_address
12
- else:
13
- controller_addr = args.controller_address
14
- ret = requests.post(controller_addr + "/refresh_all_workers")
15
- ret = requests.post(controller_addr + "/list_models")
16
- models = ret.json()["models"]
17
- models.sort()
18
- print(f"Models: {models}")
19
-
20
- ret = requests.post(controller_addr + "/get_worker_address",
21
- json={"model": args.model_name})
22
- worker_addr = ret.json()["address"]
23
- print(f"worker_addr: {worker_addr}")
24
-
25
- if worker_addr == "":
26
- return
27
-
28
- conv = conv_templates["mistral_instruct"].copy()
29
- conv.append_message(conv.roles[0], args.message)
30
- prompt = conv.get_prompt()
31
-
32
- headers = {"User-Agent": "LLaVA Client"}
33
- pload = {
34
- "model": args.model_name,
35
- "prompt": prompt,
36
- "max_new_tokens": args.max_new_tokens,
37
- "temperature": 0.7,
38
- "stop": conv.sep2,
39
- }
40
- response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41
- json=pload, stream=True)
42
-
43
- print(prompt, end="")
44
- for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45
- if chunk:
46
- data = json.loads(chunk.decode("utf-8"))
47
- output = data["text"].split("[/INST]")[-1]
48
- print(output, end="\r")
49
- print("")
50
-
51
-
52
- if __name__ == "__main__":
53
- parser = argparse.ArgumentParser()
54
- parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55
- parser.add_argument("--worker-address", type=str)
56
- parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57
- parser.add_argument("--max-new-tokens", type=int, default=256)
58
- parser.add_argument("--message", type=str, default=
59
- "Tell me a story with more than 1000 words.")
60
- args = parser.parse_args()
61
-
62
- main()