Delete LLaVA-Med
This view is limited to 50 files because it contains too many changes.
- LLaVA-Med/.gitignore +0 -3
- LLaVA-Med/CODE_OF_CONDUCT.md +0 -9
- LLaVA-Med/LICENSE +0 -62
- LLaVA-Med/README.md +0 -260
- LLaVA-Med/SECURITY.md +0 -41
- LLaVA-Med/SUPPORT.md +0 -25
- LLaVA-Med/bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl +0 -3
- LLaVA-Med/data/eval/llava_med_eval_qa50_qa.jsonl +0 -0
- LLaVA-Med/docs/llava_med_performance.md +0 -31
- LLaVA-Med/download_data.sh +0 -35
- LLaVA-Med/images/llava_logo.png +0 -0
- LLaVA-Med/images/llava_med_chat.png +0 -0
- LLaVA-Med/images/llava_med_chat_example1.png +0 -0
- LLaVA-Med/images/llava_med_chat_example2.png +0 -0
- LLaVA-Med/images/llava_med_dataset.png +0 -0
- LLaVA-Med/images/llava_med_logo.png +0 -0
- LLaVA-Med/images/llava_med_pipeline.png +0 -0
- LLaVA-Med/images/llava_med_vqa.png +0 -0
- LLaVA-Med/llava/__init__.py +0 -0
- LLaVA-Med/llava/constants.py +0 -13
- LLaVA-Med/llava/conversation.py +0 -439
- LLaVA-Med/llava/eval/eval_multimodal_chat_gpt_score.py +0 -112
- LLaVA-Med/llava/eval/llm.py +0 -134
- LLaVA-Med/llava/eval/model_vqa.py +0 -109
- LLaVA-Med/llava/eval/run_llava.py +0 -145
- LLaVA-Med/llava/eval/summarize_gpt_review.py +0 -47
- LLaVA-Med/llava/eval/util.py +0 -9
- LLaVA-Med/llava/mm_utils.py +0 -110
- LLaVA-Med/llava/model/__init__.py +0 -1
- LLaVA-Med/llava/model/builder.py +0 -83
- LLaVA-Med/llava/model/builders.py +0 -152
- LLaVA-Med/llava/model/language_model/llava_mistral.py +0 -143
- LLaVA-Med/llava/model/llava_arch.py +0 -309
- LLaVA-Med/llava/model/multimodal_encoder/builder.py +0 -9
- LLaVA-Med/llava/model/multimodal_encoder/clip_encoder.py +0 -78
- LLaVA-Med/llava/model/multimodal_projector/builder.py +0 -51
- LLaVA-Med/llava/serve/__init__.py +0 -0
- LLaVA-Med/llava/serve/cli.py +0 -125
- LLaVA-Med/llava/serve/controller.py +0 -298
- LLaVA-Med/llava/serve/examples/bio_patch.png +0 -0
- LLaVA-Med/llava/serve/examples/extreme_ironing.jpg +0 -0
- LLaVA-Med/llava/serve/examples/med_img_1.png +0 -0
- LLaVA-Med/llava/serve/examples/synpic32933.jpg +0 -0
- LLaVA-Med/llava/serve/examples/synpic42202.jpg +0 -0
- LLaVA-Med/llava/serve/examples/waterview.jpg +0 -0
- LLaVA-Med/llava/serve/examples/xy_chromosome.jpg +0 -0
- LLaVA-Med/llava/serve/gradio_web_server.py +0 -477
- LLaVA-Med/llava/serve/model_worker.py +0 -285
- LLaVA-Med/llava/serve/register_worker.py +0 -26
- LLaVA-Med/llava/serve/test_message.py +0 -62
LLaVA-Med/.gitignore
DELETED
__pycache__
*.pyc
*.egg-info

LLaVA-Med/CODE_OF_CONDUCT.md
DELETED
# Microsoft Open Source Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Resources:

- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns

LLaVA-Med/LICENSE
DELETED
MICROSOFT RESEARCH LICENSE TERMS

IF YOU LIVE IN THE UNITED STATES, PLEASE READ THE “BINDING ARBITRATION AND CLASS ACTION WAIVER” SECTION BELOW. IT AFFECTS HOW DISPUTES ARE RESOLVED.

These license terms are an agreement between you and Microsoft Corporation (or one of its affiliates). They apply to the source code, object code, machine learning models, or data (collectively “Materials”) that accompany this license. IF YOU COMPLY WITH THESE LICENSE TERMS, YOU HAVE THE RIGHTS BELOW. BY USING THE MATERIALS, YOU ACCEPT THESE TERMS.

1) INSTALLATION AND USE RIGHTS TO THE MATERIALS.

Subject to the terms of this agreement, you have the below rights, if applicable, to use the Materials solely for non-commercial, non-revenue generating, research purposes:

a) Source Code. If source code is included, you may use and modify the source code, but you may not distribute the source code.
b) Object Code. If object code is included, you may use the object code, but you may not distribute the object code.
c) Models. If machine learning model(s) are included, you may use the model(s), but you may not distribute the models.
d) Data. If data is included, you may use and modify the data, but your use and modification must be consistent with the consent under which the data was provided and/or gathered and you may not distribute the data or your modifications to the data.

2) SCOPE OF LICENSE. The Materials are licensed, not sold. Microsoft reserves all other rights. Unless applicable law gives you more rights despite this limitation, you will not (and have no right to):

a) work around any technical limitations in the Materials that only allow you to use it in certain ways;
b) reverse engineer, decompile or disassemble the Materials;
c) remove, minimize, block, or modify any notices of Microsoft or its suppliers in the Materials;
d) use the Materials in any way that is against the law or to create or propagate malware; or
e) share, publish, distribute or lend the Materials, provide the Materials as a stand-alone hosted solution for others to use, or transfer the Materials or this agreement to any third party.

3) PERSONAL DATA. If the data (set forth in Section 1(c) above) includes or is found to include any data that enables any ability to identify an individual (“Personal Data”), you will not use such Personal Data for any purpose other than was authorized and consented to by the data subject/research participant. You will not use Personal Data to contact any person. You will keep Personal Data in strict confidence. You will not share any Personal Data that is collected or in your possession with any third party for any reason and as required under the original consent agreement. Further, you will destroy the Personal Data and any backup or copies, immediately upon the completion of your research.

4) LICENSE TO MICROSOFT. Notwithstanding the limitations in Section 1, you may distribute your modifications back to Microsoft, and if you do provide Microsoft with modifications of the Materials, you hereby grant Microsoft, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, post, distribute, make and have made, sell and transfer such modifications and derivatives for any purpose.

5) PUBLICATION. You may publish (or present papers or articles) on your results from using the Materials provided that no material or substantial portion of the Materials is included in any such publication or presentation.

6) FEEDBACK. Any feedback about the Materials provided by you to us is voluntarily given, and Microsoft shall be free to use the feedback as it sees fit without obligation or restriction of any kind, even if the feedback is designated by you as confidential. Such feedback shall be considered a contribution and licensed to Microsoft under the terms of Section 4 above.

7) COMPLIANCE WITH TRADE LAWS. You acknowledge that the Materials may be subject to applicable trade laws in one or more countries. You will comply with all relevant laws and regulations applicable to the import or export of the Materials, including but not limited to, trade laws such as the U.S. Export Administration Regulations or other end-user, end use, and destination restrictions by the U.S. and other governments, as well as sanctions regulations administered by the U.S. Office of Foreign Assets Control. Microsoft may suspend or terminate the agreement immediately to the extent that Microsoft reasonably concludes that continued performance would violate trade laws or put it at risk of becoming subject to sanctions or penalties under trade laws. For additional information, see www.microsoft.com/exporting.

8) SUPPORT SERVICES. Microsoft is not obligated under this agreement to provide any support services for the Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.

9) BINDING ARBITRATION AND CLASS ACTION WAIVER. This Section applies if you live in (or, if a business, your principal place of business is in) the United States. If you and Microsoft have a dispute, you and Microsoft agree to try for 60 days to resolve it informally. If you and Microsoft can’t, you and Microsoft agree to binding individual arbitration before the American Arbitration Association under the Federal Arbitration Act (“FAA”), and not to sue in court in front of a judge or jury. Instead, a neutral arbitrator will decide. Class action lawsuits, class-wide arbitrations, private attorney-general actions, and any other proceeding where someone acts in a representative capacity are not allowed; nor is combining individual proceedings without the consent of all parties. The complete Arbitration Agreement contains more terms and is at aka.ms/arb-agreement-1. You and Microsoft agree to these terms.

10) ENTIRE AGREEMENT. This agreement, and any other terms Microsoft may provide for supplements, updates, or third-party applications, is the entire agreement for the Materials.

11) APPLICABLE LAW AND PLACE TO RESOLVE DISPUTES. If you acquired the Materials in the United States or Canada, the laws of the state or province where you live (or, if a business, where your principal place of business is located) govern the interpretation of this agreement, claims for its breach, and all other claims (including consumer protection, unfair competition, and tort claims), regardless of conflict of laws principles, except that the FAA governs everything related to arbitration. If you acquired the Materials in any other country, its laws apply, except that the FAA governs everything related to arbitration. If U.S. federal jurisdiction exists, you and Microsoft consent to exclusive jurisdiction and venue in the federal court in King County, Washington for all disputes heard in court (excluding arbitration). If not, you and Microsoft consent to exclusive jurisdiction and venue in the Superior Court of King County, Washington for all disputes heard in court (excluding arbitration).

12) CONSUMER RIGHTS; REGIONAL VARIATIONS. This agreement describes certain legal rights. You may have other rights, including consumer rights, under the laws of your state, province, or country. Separate and apart from your relationship with Microsoft, you may also have rights with respect to the party from which you acquired the Materials. This agreement does not change those other rights if the laws of your state, province, or country do not permit it to do so. For example, if you acquired the Materials in one of the below regions, or mandatory country law applies, then the following provisions apply to you:

a) Australia. You have statutory guarantees under the Australian Consumer Law and nothing in this agreement is intended to affect those rights.

b) Canada. If you acquired this software in Canada, you may stop receiving updates by turning off the automatic update feature, disconnecting your device from the Internet (if and when you re-connect to the Internet, however, the Materials will resume checking for and installing updates), or uninstalling the Materials. The product documentation, if any, may also specify how to turn off updates for your specific device or software.

c) Germany and Austria.

i. Warranty. The properly licensed software will perform substantially as described in any Microsoft materials that accompany the Materials. However, Microsoft gives no contractual guarantee in relation to the licensed software.

ii. Limitation of Liability. In case of intentional conduct, gross negligence, claims based on the Product Liability Act, as well as, in case of death or personal or physical injury, Microsoft is liable according to the statutory law.

Subject to the foregoing clause (ii), Microsoft will only be liable for slight negligence if Microsoft is in breach of such material contractual obligations, the fulfillment of which facilitate the due performance of this agreement, the breach of which would endanger the purpose of this agreement and the compliance with which a party may constantly trust in (so-called "cardinal obligations"). In other cases of slight negligence, Microsoft will not be liable for slight negligence.

13) DISCLAIMER OF WARRANTY. THE MATERIALS ARE LICENSED “AS IS.” YOU BEAR THE RISK OF USING THEM. MICROSOFT GIVES NO EXPRESS WARRANTIES, GUARANTEES, OR CONDITIONS. TO THE EXTENT PERMITTED UNDER APPLICABLE LAWS, MICROSOFT EXCLUDES ALL IMPLIED WARRANTIES, INCLUDING MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT.

14) LIMITATION ON AND EXCLUSION OF DAMAGES. IF YOU HAVE ANY BASIS FOR RECOVERING DAMAGES DESPITE THE PRECEDING DISCLAIMER OF WARRANTY, YOU CAN RECOVER FROM MICROSOFT AND ITS SUPPLIERS ONLY DIRECT DAMAGES UP TO U.S. $5.00. YOU CANNOT RECOVER ANY OTHER DAMAGES, INCLUDING CONSEQUENTIAL, LOST PROFITS, SPECIAL, INDIRECT OR INCIDENTAL DAMAGES.

This limitation applies to (a) anything related to the Materials, services, content (including code) on third party Internet sites, or third party applications; and (b) claims for breach of contract, warranty, guarantee, or condition; strict liability, negligence, or other tort; or any other claim; in each case to the extent permitted by applicable law.

It also applies even if Microsoft knew or should have known about the possibility of the damages. The above limitation or exclusion may not apply to you because your state, province, or country may not allow the exclusion or limitation of incidental, consequential, or other damages.

LLaVA-Med/README.md
DELETED
# LLaVA-Med: Large Language and Vision Assistant for Biomedicine

*Visual instruction tuning towards building large language and vision models with GPT-4 level capabilities in the biomedicine space.*

[[Paper, NeurIPS 2023 Datasets and Benchmarks Track (Spotlight)](https://arxiv.org/abs/2306.00890)]

**LLaVA-Med: Training a Large Language-and-Vision Assistant for Biomedicine in One Day** <br>

[Chunyuan Li*](https://chunyuan.li/), [Cliff Wong*](https://scholar.google.com/citations?user=Sl05ifcAAAAJ&hl=en), [Sheng Zhang*](https://scholar.google.com/citations?user=-LVEXQ8AAAAJ&hl=en), [Naoto Usuyama](https://www.microsoft.com/en-us/research/people/naotous/), [Haotian Liu](https://hliu.cc), [Jianwei Yang](https://jwyang.github.io/), [Tristan Naumann](https://scholar.google.com/citations?user=cjlSeqwAAAAJ&hl=en), [Hoifung Poon](https://scholar.google.com/citations?user=yqqmVbkAAAAJ&hl=en), [Jianfeng Gao](https://scholar.google.com/citations?user=CQ1cqKkAAAAJ&hl=en) (*Equal Contribution)

<p align="center">
<img src="images/llava_med_logo.png" width="50%"> <br>

*Generated by <a href="https://gligen.github.io/">GLIGEN</a> using the grounded inpainting mode, with three boxes: ``white doctor coat``, ``stethoscope``, ``white doctor hat with a red cross sign``.*

</p>


## Release

- [May 13, 2024] 🔥 LLaVA-Med v1.5 is out! It is not only significantly better (see the [evaluation results](docs/llava_med_performance.md#llava-med-15-performance)) but also much easier to use: no more *delta* weights! You can now load our model directly from the [🤗 Hub](https://huggingface.co/microsoft/llava-med-v1.5-mistral-7b). The original LLaVA-Med (i.e., v1.0.0) codebase has been moved to [Archive](#archive).
- [Nov 8, 2023] LLaVA-Med is open-sourced under the MSR release policy. Huge thanks to the commitment of the team and the patience of the community.
- [Sept, 2023] LLaVA-Med was accepted to the NeurIPS 2023 Datasets and Benchmarks Track as a spotlight presentation.
- [June 1, 2023] 🔥 We released **LLaVA-Med: Large Language and Vision Assistant for Biomedicine**, a step towards building biomedical-domain large language and vision models with GPT-4 level capabilities. Check out the [paper](https://arxiv.org/abs/2306.00890).

<p align="center">
<img src="images/llava_med_pipeline.png" width="90%"> <br>

*LLaVA-Med was initialized with the general-domain LLaVA and then continuously trained in a curriculum learning fashion (first biomedical concept alignment, then full-blown instruction tuning). We evaluated LLaVA-Med on standard visual conversation and question answering tasks.*
</p>

[![Code License](https://img.shields.io/badge/Code%20License-Microsoft%20Research-red)](Research%20License.docx)
[![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg)](https://creativecommons.org/licenses/by-nc/4.0/deed.en)
**Usage and License Notices**: The data, code, and model checkpoints are intended and licensed for research use only. They are also subject to additional restrictions dictated by the Terms of Use of LLaMA, Vicuna, and GPT-4, respectively. The data is made available under CC BY NC 4.0. The data, code, and model checkpoints may be used for non-commercial purposes, and any models trained using the dataset should be used only for research purposes. It is expressly prohibited for models trained on this data to be used in clinical care or for any clinical decision-making purposes.

## Contents

- [Install](#install)
- [Model Download](#model-download)
- [Serving](#serving)
- [Evaluation](#evaluation)
- [Data Download](#data-download)
- [Archive](#archive)
- [Model Description](#model-description)

## Install

1. Clone this repository and navigate to the LLaVA-Med folder
```bash
git clone https://github.com/microsoft/LLaVA-Med.git
cd LLaVA-Med
```

2. Install the package: create a conda environment

```Shell
conda create -n llava-med python=3.10 -y
conda activate llava-med
pip install --upgrade pip  # enable PEP 660 support
pip install -e .
```

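As a quick sanity check of the editable install, the package and its conversation templates should import cleanly. A minimal sketch, relying only on names defined in the deleted `llava/conversation.py` shown later in this diff:

```python
# Post-install sanity check: the editable install should make the package importable
# and expose the conversation templates defined in llava/conversation.py.
from llava.conversation import conv_templates

print(sorted(conv_templates))  # expect entries such as "mistral_instruct" and "vicuna_v1"
```
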
## Model Download

| Model Descriptions | 🤗 Huggingface Hub |
| --- | ---: |
| LLaVA-Med v1.5 | [microsoft/llava-med-v1.5-mistral-7b](https://huggingface.co/microsoft/llava-med-v1.5-mistral-7b) |

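The checkpoint can also be loaded programmatically. The sketch below assumes the deleted `llava/model/builder.py` and `llava/mm_utils.py` keep the upstream LLaVA helper signatures (`load_pretrained_model`, `get_model_name_from_path`); treat it as an illustration, not a documented API of this repo:

```python
# Sketch: load LLaVA-Med v1.5 directly from the Hugging Face Hub.
# Assumes builder.py mirrors upstream LLaVA's load_pretrained_model signature.
from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

model_path = "microsoft/llava-med-v1.5-mistral-7b"
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,  # v1.5 ships merged weights, so no base/delta model is needed
    model_name=get_model_name_from_path(model_path),
)
print(type(model).__name__, context_len)
```
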
## Serving

### Web UI

#### Launch a controller
```Shell
python -m llava.serve.controller --host 0.0.0.0 --port 10000
```

#### Launch a model worker
```Shell
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path microsoft/llava-med-v1.5-mistral-7b --multi-modal
```
Wait until the process finishes loading the model and you see "Uvicorn running on ...".

#### Launch a model worker (Multiple GPUs, when GPU VRAM <= 24GB)

If the VRAM of your GPU is less than 24GB (e.g., RTX 3090, RTX 4090), you may try running it with multiple GPUs.

```Shell
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path microsoft/llava-med-v1.5-mistral-7b --multi-modal --num-gpus 2
```
Wait until the process finishes loading the model and you see "Uvicorn running on ...".

#### Send a test message
```Shell
python -m llava.serve.test_message --model-name llava-med-v1.5-mistral-7b --controller http://localhost:10000
```

#### Launch a Gradio web server
```Shell
python -m llava.serve.gradio_web_server --controller http://localhost:10000
```
You can now open your browser and chat with the model.

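Before opening the web UI, you can confirm the worker registered with the controller. The controller's HTTP routes are not shown in this diff, so the `/list_models` endpoint below is an assumption carried over from upstream LLaVA's `llava.serve.controller`; a minimal sketch:

```python
# Sketch: check that the model worker has registered with the controller.
# The /list_models route is assumed from upstream LLaVA's controller, not documented here.
import requests

controller_url = "http://localhost:10000"
resp = requests.post(f"{controller_url}/list_models", timeout=10)
resp.raise_for_status()
print(resp.json())  # should list "llava-med-v1.5-mistral-7b" once the worker is up
```
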
## Evaluation

### Medical Visual Chat (GPT-assisted Evaluation)

Our GPT-assisted evaluation pipeline for multimodal modeling is provided for a comprehensive understanding of the capabilities of vision-language models. Please see our paper for more details.

#### 1. Azure OpenAI Connection Info

Open [llava/eval/llm.py](llava/eval/llm.py?plain=1#L33) and insert your Azure OpenAI endpoint and API key:
```python
openai_cxn_dict = {
    'default': {
        'endpoint': "INSERT YOUR AZURE OPENAI ENDPOINT HERE",
        'api_key': "INSERT YOUR AZURE OPENAI API KEY HERE",
    },
}
```
* GPT-4 inference was only tested with the Azure OpenAI API. If you are using the OpenAI API instead, replace `AsyncAzureOpenAI` with `AsyncOpenAI` in [llava/eval/llm.py (line 55)](llava/eval/llm.py?plain=1#L55); see the sketch below.

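A minimal sketch of that swap using the official `openai` Python client; the `api_version` value is a placeholder, not a value taken from this repo:

```python
# Sketch of the Azure-vs-OpenAI client swap described above.
# The api_version string is a placeholder, not a value from llava/eval/llm.py.
from openai import AsyncAzureOpenAI, AsyncOpenAI

use_azure = True
if use_azure:
    client = AsyncAzureOpenAI(
        azure_endpoint="INSERT YOUR AZURE OPENAI ENDPOINT HERE",
        api_key="INSERT YOUR AZURE OPENAI API KEY HERE",
        api_version="2024-02-01",
    )
else:
    client = AsyncOpenAI(api_key="INSERT YOUR OPENAI API KEY HERE")

# Both clients expose the same chat.completions.create coroutine,
# so the rest of the evaluation code can stay unchanged.
```
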
#### 2. Deployment ID
In [llava/eval/eval_multimodal_chat_gpt_score.py (line 55)](llava/eval/eval_multimodal_chat_gpt_score.py?plain=1#L55), replace the default value with your GPT-4 model deployment ID if necessary.

#### 3. Download Images

```Shell
wget https://hanoverprod.z21.web.core.windows.net/med_llava/multimodal_chat_eval/llava_med_test_image_urls.jsonl -P data/
python llava/data/download_images.py \
    --input_path data/llava_med_test_image_urls.jsonl \
    --pmc_output_path data/pmc \
    --images_output_path data/images
```

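A quick way to confirm the download completed is to compare record and file counts; a minimal sketch that does not inspect any per-record fields:

```python
# Sketch: rough sanity check for step 3 -- compare URL records against downloaded files.
from pathlib import Path

url_file = Path("data/llava_med_test_image_urls.jsonl")
image_dir = Path("data/images")

n_urls = sum(1 for line in url_file.open() if line.strip())
n_files = sum(1 for p in image_dir.iterdir() if p.is_file())
print(f"{n_urls} URL records, {n_files} files downloaded")
```
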
#### 4. Multimodal Chat Inference
In our case, [`llava_med_eval_qa50_qa.jsonl`](/data/eval/llava_med_eval_qa50_qa.jsonl) contains the questions, context (captions and inline mentions) and responses generated by text-only GPT-4 (0314), which we treat as ground truth.

```Shell
PYTHONPATH=. python llava/eval/model_vqa.py \
    --conv-mode mistral_instruct \
    --model-path microsoft/llava-med-v1.5-mistral-7b \
    --question-file data/eval/llava_med_eval_qa50_qa.jsonl \
    --image-folder data/images \
    --answers-file /path/to/answer-file.jsonl \
    --temperature 0.0
```

#### 5. GPT-4 Evaluation of the Generated Answers

```Shell
python llava/eval/eval_multimodal_chat_gpt_score.py \
    --answers-file /path/to/answer-file.jsonl \
    --question-file data/eval/llava_med_eval_qa50_qa.jsonl \
    --scores-file /path/to/scores-file.jsonl
```

#### 6. Summarize the Evaluation Results

```Shell
python llava/eval/summarize_gpt_review.py \
    --scores-file /path/to/scores-file.jsonl
```

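If you want to inspect the raw scores before summarizing, the sketch below simply prints whatever fields `eval_multimodal_chat_gpt_score.py` wrote; the record schema is not documented in this view, so no field names are assumed:

```python
# Sketch: peek at the scores file from step 5 without assuming its schema.
import json

scores_path = "/path/to/scores-file.jsonl"
with open(scores_path) as f:
    records = [json.loads(line) for line in f if line.strip()]

print(f"{len(records)} scored records")
if records:
    print("fields:", sorted(records[0].keys()))
```
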
## Data Download

### LLaVA-Med Dataset

<p align="center">
<img src="images/llava_med_dataset.png" width="90%"> <br>

*The data statistics of the biomedical multimodal instruction-following data: (a,b) the root verb-noun pairs of instructions and responses, where the inner circle of the plot represents the root verb of the output response and the outer circle represents the direct nouns. (c) The distribution of images and QA pairs over the five domains; one image is shown per domain.*
</p>

### Data Download
| Alignment data files | Size |
| --- | ---: |
| [llava_med_alignment_500k.json](https://hanoverprod.z21.web.core.windows.net/med_llava/alignment/llava_med_alignment_500k.json) | 341.52 MiB |

| Instruction-Tuning data files | Size |
| --- | ---: |
| [llava_med_instruct_10k.json](https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_10k.json) | 19.24 MiB |
| [llava_med_instruct_60k.json](https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_60k.json) | 84.65 MiB |
| [llava_med_instruct_60k_inline_mention.json](https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_60k_inline_mention.json) | 83.61 MiB |
| [llava_med_instruct_fig_captions.json](https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_fig_captions.json) | 161.39 MiB |

| Evaluation files | Size |
| --- | ---: |
| [llava_med_eval_qa50_qa.jsonl](https://hanoverprod.z21.web.core.windows.net/med_llava/eval/llava_med_eval_qa50_qa.jsonl) | 256.18 KiB |
| [llava_med_eval_qa50_fig_captions.json](https://hanoverprod.z21.web.core.windows.net/med_llava/eval/llava_med_eval_qa50_fig_captions.json) | 51.82 KiB |
| [llava_med_qa50_instruct_caption_in_text_cleaned-60k-3epoch.json](https://hanoverprod.z21.web.core.windows.net/med_llava/eval/llava_med_qa50_instruct_caption_in_text_cleaned-60k-3epoch.json) | 100.97 KiB |

| Image URLs | Size |
| --- | ---: |
| [llava_med_image_urls.jsonl](https://hanoverprod.z21.web.core.windows.net/med_llava/llava_med_image_urls.jsonl) | 122.82 MiB |

[download_images.py](https://github.com/microsoft/LLaVA-Med/blob/v1.0.0/llava/data/download_images.py) is used to download the PMC articles listed in the image URLs file above and extract the images.

To download our language-image multimodal instruction-following dataset, please run the following script:
```bash
sh download_data.sh
```


## Archive

- [LLaVA-Med v1.0](https://github.com/microsoft/LLaVA-Med/tree/v1.0.0)

## Model Description

Large Language and Vision Assistant for bioMedicine (i.e., “LLaVA-Med”) is a large language and vision model trained using a curriculum learning method for adapting LLaVA to the biomedical domain. It is an open-source release intended for research use only, to facilitate reproducibility of the corresponding paper, which claims improved performance for open-ended biomedical question answering tasks, including common visual question answering (VQA) benchmark datasets such as PathVQA and VQA-RAD.

### Model Uses

#### Intended Use

The data, code, and model checkpoints are intended to be used solely for (I) future research on visual-language processing and (II) reproducibility of the experimental results reported in the reference paper. The data, code, and model checkpoints are not intended to be used in clinical care or for any clinical decision-making purposes.

#### Primary Intended Use

The primary intended use is to support AI researchers reproducing and building on top of this work. LLaVA-Med and its associated models should be helpful for exploring various biomedical vision-language processing (VLP) and visual question answering (VQA) research questions.

#### Out-of-Scope Use

**Any** deployed use case of the model --- commercial or otherwise --- is out of scope. Although we evaluated the models using a broad set of publicly-available research benchmarks, the models and evaluations are intended *for research use only* and not intended for deployed use cases. Please refer to [the associated paper](https://aka.ms/llava-med) for more details.

### Data

This model builds upon the [PMC-15M dataset](https://aka.ms/biomedclip-paper), which is a large-scale parallel image-text dataset for biomedical vision-language processing. It contains 15 million figure-caption pairs extracted from biomedical research articles in PubMed Central. It covers a diverse range of biomedical image types, such as microscopy, radiography, histology, and more.

### Limitations

This model was developed using English corpora, and thus may be considered English-only. This model is evaluated on a narrow set of biomedical benchmark tasks, described in the [LLaVA-Med paper](https://aka.ms/llava-med). As such, it is not suitable for use in any clinical setting. Under some conditions, the model may make inaccurate predictions and display limitations, which may require additional mitigation strategies. In particular, this model is likely to carry many of the limitations of the model from which it is derived, [LLaVA](https://llava-vl.github.io/).

Further, this model was developed in part using the [PMC-15M](https://aka.ms/biomedclip-paper) dataset. The figure-caption pairs that make up this dataset may contain biases reflecting the current practice of academic publication. For example, the corresponding papers may be enriched for positive findings, contain examples of extreme cases, and otherwise reflect distributions that are not representative of other sources of biomedical data.

## Acknowledgement

If you find LLaVA-Med useful for your research and applications, please cite using this BibTeX:

```bibtex
@article{li2023llavamed,
  title={Llava-med: Training a large language-and-vision assistant for biomedicine in one day},
  author={Li, Chunyuan and Wong, Cliff and Zhang, Sheng and Usuyama, Naoto and Liu, Haotian and Yang, Jianwei and Naumann, Tristan and Poon, Hoifung and Gao, Jianfeng},
  journal={arXiv preprint arXiv:2306.00890},
  year={2023}
}
```


## Related Projects

- [LLaVA](https://llava-vl.github.io/)
- [BiomedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224)
- [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)

LLaVA-Med/SECURITY.md
DELETED
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->

## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).

<!-- END MICROSOFT SECURITY.MD BLOCK -->

LLaVA-Med/SUPPORT.md
DELETED
# TODO: The maintainer of this repo has not yet edited this file

**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?

- **No CSS support:** Fill out this template with information about how to file issues and get help.
- **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
- **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.

*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*

# Support

## How to file issues and get help

This project uses GitHub Issues to track bugs and feature requests. Please search the existing
issues before filing new issues to avoid duplicates. For new issues, file your bug or
feature request as a new Issue.

For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.

## Microsoft Support Policy

Support for this **PROJECT or PRODUCT** is limited to the resources listed above.

LLaVA-Med/bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl
DELETED
version https://git-lfs.github.com/spec/v1
oid sha256:0f0323de1ff1fdf8383e79bdad1283516a4c05a6fd2b44a363bf4e059422305b
size 69084267

LLaVA-Med/data/eval/llava_med_eval_qa50_qa.jsonl
DELETED
The diff for this file is too large to render.
LLaVA-Med/docs/llava_med_performance.md
DELETED
## LLaVA-Med-1.5 Performance

<p align="center">
<img src="https://hanoverprod.z21.web.core.windows.net/med_llava/web/llava-med_1.5_eval.png" width="90%"> <br>

*Performance comparison of multimodal chat instruction-following abilities, measured by the relative score via language GPT-4 evaluation.*
</p>


## LLaVA-Med-1.0 Performance

<p align="center">
<img src="../images/llava_med_chat_example1.png" width="90%"> <br>

*Example 1: comparison of medical visual chat. The language-only GPT-4 is considered the performance upper bound, as the golden captions and inline mentions are fed into GPT-4 as the context, without requiring the model to understand the raw image.*
</p>

<p align="center">
<img src="../images/llava_med_chat_example2.png" width="90%"> <br>

*Example 2: comparison of medical visual chat. LLaVA tends to hallucinate or refuse to provide domain-specific, knowledgeable responses.*
</p>


<p align="center">
<img src="../images/llava_med_vqa.png" width="90%"> <br>

*Performance comparison of fine-tuned LLaVA-Med on established medical VQA datasets.*
</p>

LLaVA-Med/download_data.sh
DELETED
#!/bin/bash

mkdir data/alignment
cd data/alignment

wget https://hanoverprod.z21.web.core.windows.net/med_llava/alignment/llava_med_alignment_500k.json

cd ..

mkdir instruct
cd instruct

wget https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_10k.json
wget https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_60k.json
wget https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_60k_inline_mention.json
wget https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_fig_captions.json
cd ..

mkdir eval
cd eval

wget https://hanoverprod.z21.web.core.windows.net/med_llava/eval/llava_med_eval_qa50_qa.jsonl
wget https://hanoverprod.z21.web.core.windows.net/med_llava/eval/llava_med_eval_qa50_fig_captions.json
wget https://hanoverprod.z21.web.core.windows.net/med_llava/eval/llava_med_qa50_instruct_caption_in_text_cleaned-60k-3epoch.json

cd ..

wget https://hanoverprod.z21.web.core.windows.net/med_llava/llava_med_image_urls.jsonl
mkdir pmc_articles
mkdir images

cd ..

pip install tqdm
python llava/data/download_images.py --input_path data/llava_med_image_urls.jsonl --pmc_output_path data/pmc_articles/ --images_output_path data/images

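After the script above finishes, the sketch below checks that each fetched JSON/JSONL file parses and reports record counts. The file list simply mirrors the `wget` calls above; the helper itself is illustrative and not part of the repo:

```python
# Sketch: verify the files fetched by download_data.sh parse and report record counts.
import json
from pathlib import Path

files = [
    "data/alignment/llava_med_alignment_500k.json",
    "data/instruct/llava_med_instruct_10k.json",
    "data/instruct/llava_med_instruct_60k.json",
    "data/instruct/llava_med_instruct_60k_inline_mention.json",
    "data/instruct/llava_med_instruct_fig_captions.json",
    "data/eval/llava_med_eval_qa50_qa.jsonl",
    "data/eval/llava_med_eval_qa50_fig_captions.json",
    "data/eval/llava_med_qa50_instruct_caption_in_text_cleaned-60k-3epoch.json",
    "data/llava_med_image_urls.jsonl",
]

for name in files:
    path = Path(name)
    if not path.exists():
        print(f"MISSING  {name}")
        continue
    if path.suffix == ".jsonl":
        count = sum(1 for line in path.open() if line.strip())
    else:
        count = len(json.load(path.open()))
    print(f"{count:>8} records  {name}")
```
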
LLaVA-Med/images/llava_logo.png
DELETED
Binary file (268 kB)

LLaVA-Med/images/llava_med_chat.png
DELETED
Binary file (63.1 kB)

LLaVA-Med/images/llava_med_chat_example1.png
DELETED
Binary file (315 kB)

LLaVA-Med/images/llava_med_chat_example2.png
DELETED
Binary file (279 kB)

LLaVA-Med/images/llava_med_dataset.png
DELETED
Binary file (542 kB)

LLaVA-Med/images/llava_med_logo.png
DELETED
Binary file (379 kB)

LLaVA-Med/images/llava_med_pipeline.png
DELETED
Binary file (349 kB)

LLaVA-Med/images/llava_med_vqa.png
DELETED
Binary file (126 kB)

LLaVA-Med/llava/__init__.py
DELETED
File without changes

LLaVA-Med/llava/constants.py
DELETED
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"

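These constants are consumed throughout the code deleted in this commit: `DEFAULT_IMAGE_TOKEN` marks where an image belongs in the prompt text, and `IMAGE_TOKEN_INDEX` is the sentinel id spliced into the token sequence at that position. A minimal sketch of that splicing, written as an illustration of the idea behind the deleted `mm_utils.py` helper rather than a copy of it:

```python
# Illustration: tokenize the text around "<image>" and splice the -200 sentinel in between.
# The helper below is a sketch, not the deleted mm_utils.py implementation.
from transformers import AutoTokenizer

IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"

def encode_prompt_with_image(prompt: str, tokenizer) -> list[int]:
    chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
    input_ids = tokenizer(chunks[0]).input_ids           # keeps BOS from the first chunk
    for chunk in chunks[1:]:
        input_ids.append(IMAGE_TOKEN_INDEX)              # placeholder later replaced by image features
        input_ids.extend(tokenizer(chunk, add_special_tokens=False).input_ids)
    return input_ids

tokenizer = AutoTokenizer.from_pretrained("microsoft/llava-med-v1.5-mistral-7b")
ids = encode_prompt_with_image(DEFAULT_IMAGE_TOKEN + "\nWhat does this scan show?", tokenizer)
print(ids[:8])
```
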
LLaVA-Med/llava/conversation.py
DELETED
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import base64
from io import BytesIO
from PIL import Image


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    MISTRAL = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    sep = seps[i % 2]
                    sep = "{0} ".format(self.sep2) if sep == self.sep2 else self.sep
                    ret += role + ": " + message.strip() + sep
                else:
                    ret += role + ":"
            ret = ret.strip()
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0: message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        elif self.sep_style == SeparatorStyle.MISTRAL:
            # reference: https://docs.mistral.ai/models/
            wrap_sys = lambda msg: f"{msg}</s>"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""
            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0: message = self.system + " " + message.strip()
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += message
                    else:
                        ret += wrap_sys(message)
                else:
                    ret += ""
            # wrap_sys = lambda msg: f"\n{msg}\n\n"
            # wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            # ret = ""
            # for i, (role, message) in enumerate(messages):
            #     if i == 0:
            #         assert message, "first message should not be none"
            #         assert role == self.roles[0], "first message should come from user"
            #     if message:
            #         if type(message) is tuple:
            #             message, _, _ = message
            #         if i == 0: message = wrap_sys(self.system) + message
            #         if i % 2 == 0:
            #             message = wrap_inst(message)
            #             ret += message if i != 0 else self.sep + message
            #         else:
            #             # NOTE-JW: we need to add " " to strictly follow Mistral Instruction Format
            #             ret += " " + message + " " + self.sep2
            #             # ret += " " + wrap_sys(message)
            #     else:
            #         ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)

conv_mistral_instruct = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="</s>",
)

default_conversation = conv_vicuna_v1
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "llama_2": conv_llama_2,
    "mistral_instruct": conv_mistral_instruct,

    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,
    "mpt": conv_mpt,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())

|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LLaVA-Med/llava/eval/eval_multimodal_chat_gpt_score.py
DELETED
@@ -1,112 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import json
|
3 |
-
import argparse
|
4 |
-
from copy import deepcopy
|
5 |
-
import itertools
|
6 |
-
from typing import Any
|
7 |
-
from operator import add
|
8 |
-
from pprint import pprint
|
9 |
-
from typing import List
|
10 |
-
from pathlib import Path
|
11 |
-
from tqdm import tqdm
|
12 |
-
|
13 |
-
import llm
|
14 |
-
import util
|
15 |
-
|
16 |
-
|
17 |
-
INSTRUCT_PROMPT = """We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with caption describing the same image.
|
18 |
-
Please rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.
|
19 |
-
Please first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space. In the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."""
|
20 |
-
ROLE = 'Assistant'
|
21 |
-
|
22 |
-
# Generate instruction for GPT-4 to score the two answers.
|
23 |
-
def conv_to_str(fig_label, fig_caption, fig_context, question, ans1, ans2):
|
24 |
-
return (f'[Context]\n'
|
25 |
-
f'Figure Caption:\n{fig_label}: {fig_caption}\n\n'
|
26 |
-
f'Figure Context:\n\t- {fig_context}\n\n'
|
27 |
-
f'[Question]\n{question}\n\n'
|
28 |
-
f'[{ROLE} 1]\n{ans1}\n\n[End of {ROLE} 1]\n\n'
|
29 |
-
f'[{ROLE} 2]\n{ans2}\n\n[End of {ROLE} 2]\n\n'
|
30 |
-
f'[System]\n{INSTRUCT_PROMPT}\n\n')
|
31 |
-
|
32 |
-
def compare_messages_gen(fig_label, fig_caption, fig_context, question, ans1, ans2):
|
33 |
-
messages = [
|
34 |
-
{"role": "system", "content": """'You are a helpful and precise assistant for checking the quality of the answer."""},
|
35 |
-
]
|
36 |
-
messages.append({"role": "user", "content": conv_to_str(fig_label, fig_caption, fig_context, question, ans1, ans2)})
|
37 |
-
return messages
|
38 |
-
|
39 |
-
|
40 |
-
def sum_list_list(x):
|
41 |
-
return sum(item for inner_list in x for item in inner_list)
|
42 |
-
|
43 |
-
def chunk(lst, n):
|
44 |
-
for i in range(0, len(lst), n):
|
45 |
-
if i+(1.5*n)<len(lst):
|
46 |
-
end = i + n
|
47 |
-
else:
|
48 |
-
end = len(lst)
|
49 |
-
yield lst[i:end]
|
50 |
-
if end==len(lst):
|
51 |
-
return
|
52 |
-
|
53 |
-
|
54 |
-
def infer(samples):
|
55 |
-
model_inst = llm.GPT("gpt-4-0314")
|
56 |
-
|
57 |
-
BATCH_SIZE = 1
|
58 |
-
batch_samples = []
|
59 |
-
results = []
|
60 |
-
batch = []
|
61 |
-
|
62 |
-
print('Starting Multimodal Chat GPT Scoring Eval')
|
63 |
-
|
64 |
-
for sample in tqdm(samples):
|
65 |
-
sample_copy = deepcopy(sample)
|
66 |
-
input_msg = compare_messages_gen(sample_copy['fig_label'], sample_copy['fig_caption'], sample_copy['in_text_mention'], sample_copy['question'], sample_copy['ans1'], sample_copy['ans2'])
|
67 |
-
batch.append(input_msg)
|
68 |
-
batch_samples.append(sample_copy)
|
69 |
-
if len(batch)>=BATCH_SIZE:
|
70 |
-
inference_results = [x.strip() for chunk_messages in chunk([x for x in batch if x], BATCH_SIZE) for x in model_inst.infer(chunk_messages)]
|
71 |
-
for item, inference_result in zip(batch_samples, inference_results):
|
72 |
-
item['gpt_eval'] = inference_result
|
73 |
-
results.extend(batch_samples)
|
74 |
-
batch = []
|
75 |
-
batch_samples = []
|
76 |
-
inference_results = [x.strip() for chunk_messages in chunk([x for x in batch if x], BATCH_SIZE) for x in model_inst.infer(chunk_messages)]
|
77 |
-
for item, inference_result in zip(batch_samples, inference_results):
|
78 |
-
item['gpt_eval'] = inference_result
|
79 |
-
results.extend(batch_samples)
|
80 |
-
print(f"Result Size: {len(results)}")
|
81 |
-
return results
|
82 |
-
|
83 |
-
|
84 |
-
def main(args):
|
85 |
-
answer_data = util.load_file_jsonl(args.answers_file)
|
86 |
-
question_data = util.load_file_jsonl(args.question_file)
|
87 |
-
|
88 |
-
samples = []
|
89 |
-
for question, answer in zip(question_data, answer_data):
|
90 |
-
question_copy = deepcopy(question)
|
91 |
-
question['question'] = question_copy['text']
|
92 |
-
question['ans1'] = question_copy.pop('gpt4_answer')
|
93 |
-
question['ans2'] = answer['text']
|
94 |
-
samples.append(question)
|
95 |
-
|
96 |
-
results = infer(samples)
|
97 |
-
|
98 |
-
# Create parent directory of output score files if it doesn't exist
|
99 |
-
os.makedirs(Path(args.scores_file).parent, exist_ok=True)
|
100 |
-
|
101 |
-
with open(args.scores_file, 'w') as f:
|
102 |
-
for row in results:
|
103 |
-
f.write(json.dumps(row)+'\n')
|
104 |
-
|
105 |
-
|
106 |
-
if __name__ == '__main__':
|
107 |
-
parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Scoring", add_help=True)
|
108 |
-
parser.add_argument("--answers-file", default="", metavar="FILE", help="path to model answer file")
|
109 |
-
parser.add_argument("--question-file", default="data/questions/llava_med_eval_qa50_qa.jsonl", metavar="FILE", help="path to multichat questions file")
|
110 |
-
parser.add_argument("--scores-file", default="", metavar="FILE", help="path to save gpt-4 score file")
|
111 |
-
args = parser.parse_args()
|
112 |
-
main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LLaVA-Med/llava/eval/llm.py
DELETED
@@ -1,134 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import abc
|
3 |
-
import asyncio
|
4 |
-
from abc import abstractmethod
|
5 |
-
import math
|
6 |
-
|
7 |
-
import tiktoken
|
8 |
-
import openai
|
9 |
-
import backoff
|
10 |
-
|
11 |
-
|
12 |
-
class LLM(abc.ABC):
|
13 |
-
|
14 |
-
prompt_percent = 0.9
|
15 |
-
|
16 |
-
@abstractmethod
|
17 |
-
def __init__(self):
|
18 |
-
raise NotImplementedError("Subclasses should implement this!")
|
19 |
-
|
20 |
-
@abstractmethod
|
21 |
-
def infer(self, prompts):
|
22 |
-
raise NotImplementedError("Subclasses should implement this!")
|
23 |
-
|
24 |
-
@abstractmethod
|
25 |
-
def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
|
26 |
-
raise NotImplementedError("Subclasses should implement this!")
|
27 |
-
|
28 |
-
|
29 |
-
class GPT(LLM):
|
30 |
-
|
31 |
-
prompt_percent = 0.8
|
32 |
-
|
33 |
-
openai_cxn_dict = {
|
34 |
-
'default': {
|
35 |
-
'endpoint': "INSERT YOUR AZURE OPENAI ENDPOINT HERE",
|
36 |
-
'api_key': "INSERT YOUR AZURE OPENAI API KEY HERE",
|
37 |
-
},
|
38 |
-
}
|
39 |
-
|
40 |
-
deployment_max_length_dict = {
|
41 |
-
'gpt-4': 8192,
|
42 |
-
'gpt-4-0314': 8192,
|
43 |
-
'gpt-4-32k': 32768,
|
44 |
-
'gpt-35-turbo': 4096,
|
45 |
-
'gpt-35-turbo-16k': 16385,
|
46 |
-
}
|
47 |
-
|
48 |
-
def __init__(self, model_id):
|
49 |
-
self.temperature = 0.0
|
50 |
-
self.top_k = 1
|
51 |
-
self.encoding = tiktoken.encoding_for_model("-".join(model_id.split("-", 2)[:2]).replace('5', '.5'))
|
52 |
-
self.openai_api = 'default'
|
53 |
-
self.model_id = model_id
|
54 |
-
self.max_length = self.deployment_max_length_dict[model_id]
|
55 |
-
self.client = openai.AsyncAzureOpenAI(
|
56 |
-
api_key=self.openai_cxn_dict[self.openai_api]['api_key'],
|
57 |
-
api_version="2023-12-01-preview",
|
58 |
-
azure_endpoint=self.openai_cxn_dict[self.openai_api]['endpoint']
|
59 |
-
)
|
60 |
-
|
61 |
-
def gen_messages(self, fixed_instruction, few_shot_examples, input, input_header, output_header):
|
62 |
-
messages = [
|
63 |
-
{
|
64 |
-
"role": "system",
|
65 |
-
"content": fixed_instruction,
|
66 |
-
},
|
67 |
-
]
|
68 |
-
for example in few_shot_examples:
|
69 |
-
messages.extend(
|
70 |
-
[
|
71 |
-
{
|
72 |
-
"role": "user",
|
73 |
-
"content": input_header+'\n'+example['user']+'\n\n'+output_header,
|
74 |
-
},
|
75 |
-
{
|
76 |
-
"role": "assistant",
|
77 |
-
"content": example['assistant'],
|
78 |
-
},
|
79 |
-
]
|
80 |
-
)
|
81 |
-
messages.extend(
|
82 |
-
[
|
83 |
-
{
|
84 |
-
"role": "user",
|
85 |
-
"content": input_header+'\n'+input+'\n\n'+output_header,
|
86 |
-
},
|
87 |
-
]
|
88 |
-
)
|
89 |
-
return messages
|
90 |
-
|
91 |
-
# Define the coroutine for making API calls to GPT
|
92 |
-
@backoff.on_exception(backoff.expo, openai.RateLimitError)
|
93 |
-
async def make_api_call_to_gpt(
|
94 |
-
self,
|
95 |
-
messages
|
96 |
-
):
|
97 |
-
response = await self.client.chat.completions.create(
|
98 |
-
model=self.model_id,
|
99 |
-
messages=messages,
|
100 |
-
temperature=self.temperature,
|
101 |
-
)
|
102 |
-
return response.choices[0].message.content
|
103 |
-
|
104 |
-
async def dispatch_openai_requests(
|
105 |
-
self,
|
106 |
-
messages_list,
|
107 |
-
):
|
108 |
-
# Asynchronously call the function for each prompt
|
109 |
-
tasks = [self.make_api_call_to_gpt(messages) for messages in messages_list]
|
110 |
-
|
111 |
-
# Gather and run the tasks concurrently
|
112 |
-
results = await asyncio.gather(*tasks)
|
113 |
-
return results
|
114 |
-
|
115 |
-
def infer(self,
|
116 |
-
messages_list,
|
117 |
-
):
|
118 |
-
return asyncio.run(self.dispatch_openai_requests(messages_list))
|
119 |
-
|
120 |
-
def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
|
121 |
-
# Tokenize fixed_prompt
|
122 |
-
fixed_token_ids = self.encoding.encode(fixed_instruction+' '.join([x['user']+' '+x['assistant'] for x in few_shot_examples]))
|
123 |
-
# Calculate remaining token length
|
124 |
-
remaining_token_len = math.ceil((self.prompt_percent*self.max_length)-len(fixed_token_ids))
|
125 |
-
|
126 |
-
# Tokenize splittable_input
|
127 |
-
split_token_ids = self.encoding.encode(splittable_input)
|
128 |
-
|
129 |
-
# Split tokenized split_prompt into list of individual inputs strings. Uses tokens to calculate length
|
130 |
-
split_token_ids_list = [split_token_ids[i:i+remaining_token_len+10] for i in range(0, len(split_token_ids), remaining_token_len)]
|
131 |
-
split_input_list = [self.encoding.decode(split_token_ids) for split_token_ids in split_token_ids_list]
|
132 |
-
|
133 |
-
# Take the fixed_prompt, few_shot_examples, splitted inputs, and input/output headers and generate list of prompt strings.
|
134 |
-
return [self.gen_messages(fixed_instruction, few_shot_examples, split_input, input_header, output_header) for split_input in split_input_list]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LLaVA-Med/llava/eval/model_vqa.py
DELETED
@@ -1,109 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import torch
|
3 |
-
import os
|
4 |
-
import json
|
5 |
-
from tqdm import tqdm
|
6 |
-
import shortuuid
|
7 |
-
|
8 |
-
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
9 |
-
from llava.conversation import conv_templates, SeparatorStyle
|
10 |
-
from llava.model.builder import load_pretrained_model
|
11 |
-
from llava.utils import disable_torch_init
|
12 |
-
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images
|
13 |
-
|
14 |
-
from PIL import Image
|
15 |
-
import math
|
16 |
-
from transformers import set_seed, logging
|
17 |
-
|
18 |
-
logging.set_verbosity_error()
|
19 |
-
|
20 |
-
|
21 |
-
def split_list(lst, n):
|
22 |
-
"""Split a list into n (roughly) equal-sized chunks"""
|
23 |
-
chunk_size = math.ceil(len(lst) / n) # integer division
|
24 |
-
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
25 |
-
|
26 |
-
|
27 |
-
def get_chunk(lst, n, k):
|
28 |
-
chunks = split_list(lst, n)
|
29 |
-
return chunks[k]
|
30 |
-
|
31 |
-
|
32 |
-
def eval_model(args):
|
33 |
-
set_seed(0)
|
34 |
-
# Model
|
35 |
-
disable_torch_init()
|
36 |
-
model_path = os.path.expanduser(args.model_path)
|
37 |
-
model_name = get_model_name_from_path(model_path)
|
38 |
-
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
|
39 |
-
|
40 |
-
questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
|
41 |
-
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
|
42 |
-
answers_file = os.path.expanduser(args.answers_file)
|
43 |
-
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
|
44 |
-
ans_file = open(answers_file, "w")
|
45 |
-
for line in tqdm(questions):
|
46 |
-
idx = line["question_id"]
|
47 |
-
image_file = line["image"]
|
48 |
-
qs = line["text"].replace(DEFAULT_IMAGE_TOKEN, '').strip()
|
49 |
-
cur_prompt = qs
|
50 |
-
if model.config.mm_use_im_start_end:
|
51 |
-
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
|
52 |
-
else:
|
53 |
-
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
54 |
-
|
55 |
-
conv = conv_templates[args.conv_mode].copy()
|
56 |
-
conv.append_message(conv.roles[0], qs)
|
57 |
-
conv.append_message(conv.roles[1], None)
|
58 |
-
prompt = conv.get_prompt()
|
59 |
-
|
60 |
-
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
61 |
-
|
62 |
-
image = Image.open(os.path.join(args.image_folder, image_file))
|
63 |
-
image_tensor = process_images([image], image_processor, model.config)[0]
|
64 |
-
|
65 |
-
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
66 |
-
keywords = [stop_str]
|
67 |
-
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
68 |
-
|
69 |
-
with torch.inference_mode():
|
70 |
-
output_ids = model.generate(
|
71 |
-
input_ids,
|
72 |
-
images=image_tensor.unsqueeze(0).half().cuda(),
|
73 |
-
do_sample=True if args.temperature > 0 else False,
|
74 |
-
temperature=args.temperature,
|
75 |
-
top_p=args.top_p,
|
76 |
-
num_beams=args.num_beams,
|
77 |
-
# no_repeat_ngram_size=3,
|
78 |
-
max_new_tokens=1024,
|
79 |
-
use_cache=True)
|
80 |
-
|
81 |
-
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
82 |
-
|
83 |
-
ans_id = shortuuid.uuid()
|
84 |
-
ans_file.write(json.dumps({"question_id": idx,
|
85 |
-
"prompt": cur_prompt,
|
86 |
-
"text": outputs,
|
87 |
-
"answer_id": ans_id,
|
88 |
-
"model_id": model_name,
|
89 |
-
"metadata": {}}) + "\n")
|
90 |
-
ans_file.flush()
|
91 |
-
ans_file.close()
|
92 |
-
|
93 |
-
|
94 |
-
if __name__ == "__main__":
|
95 |
-
parser = argparse.ArgumentParser()
|
96 |
-
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
97 |
-
parser.add_argument("--model-base", type=str, default=None)
|
98 |
-
parser.add_argument("--image-folder", type=str, default="")
|
99 |
-
parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
|
100 |
-
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
|
101 |
-
parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
|
102 |
-
parser.add_argument("--num-chunks", type=int, default=1)
|
103 |
-
parser.add_argument("--chunk-idx", type=int, default=0)
|
104 |
-
parser.add_argument("--temperature", type=float, default=0.2)
|
105 |
-
parser.add_argument("--top_p", type=float, default=None)
|
106 |
-
parser.add_argument("--num_beams", type=int, default=1)
|
107 |
-
args = parser.parse_args()
|
108 |
-
|
109 |
-
eval_model(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LLaVA-Med/llava/eval/run_llava.py
DELETED
@@ -1,145 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import torch
|
3 |
-
|
4 |
-
from llava.constants import (
|
5 |
-
IMAGE_TOKEN_INDEX,
|
6 |
-
DEFAULT_IMAGE_TOKEN,
|
7 |
-
DEFAULT_IM_START_TOKEN,
|
8 |
-
DEFAULT_IM_END_TOKEN,
|
9 |
-
IMAGE_PLACEHOLDER,
|
10 |
-
)
|
11 |
-
from llava.conversation import conv_templates, SeparatorStyle
|
12 |
-
from llava.model.builder import load_pretrained_model
|
13 |
-
from llava.utils import disable_torch_init
|
14 |
-
from llava.mm_utils import (
|
15 |
-
process_images,
|
16 |
-
tokenizer_image_token,
|
17 |
-
get_model_name_from_path,
|
18 |
-
)
|
19 |
-
|
20 |
-
from PIL import Image
|
21 |
-
|
22 |
-
import requests
|
23 |
-
from PIL import Image
|
24 |
-
from io import BytesIO
|
25 |
-
import re
|
26 |
-
|
27 |
-
|
28 |
-
def image_parser(args):
|
29 |
-
out = args.image_file.split(args.sep)
|
30 |
-
return out
|
31 |
-
|
32 |
-
|
33 |
-
def load_image(image_file):
|
34 |
-
if image_file.startswith("http") or image_file.startswith("https"):
|
35 |
-
response = requests.get(image_file)
|
36 |
-
image = Image.open(BytesIO(response.content)).convert("RGB")
|
37 |
-
else:
|
38 |
-
image = Image.open(image_file).convert("RGB")
|
39 |
-
return image
|
40 |
-
|
41 |
-
|
42 |
-
def load_images(image_files):
|
43 |
-
out = []
|
44 |
-
for image_file in image_files:
|
45 |
-
image = load_image(image_file)
|
46 |
-
out.append(image)
|
47 |
-
return out
|
48 |
-
|
49 |
-
|
50 |
-
def eval_model(args):
|
51 |
-
# Model
|
52 |
-
disable_torch_init()
|
53 |
-
|
54 |
-
model_name = get_model_name_from_path(args.model_path)
|
55 |
-
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
56 |
-
args.model_path, args.model_base, model_name
|
57 |
-
)
|
58 |
-
|
59 |
-
qs = args.query
|
60 |
-
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
|
61 |
-
if IMAGE_PLACEHOLDER in qs:
|
62 |
-
if model.config.mm_use_im_start_end:
|
63 |
-
qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
|
64 |
-
else:
|
65 |
-
qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
|
66 |
-
else:
|
67 |
-
if model.config.mm_use_im_start_end:
|
68 |
-
qs = image_token_se + "\n" + qs
|
69 |
-
else:
|
70 |
-
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
|
71 |
-
|
72 |
-
if "llama-2" in model_name.lower():
|
73 |
-
conv_mode = "llava_llama_2"
|
74 |
-
elif "mistral" in model_name.lower():
|
75 |
-
conv_mode = "mistral_instruct"
|
76 |
-
elif "v1.6-34b" in model_name.lower():
|
77 |
-
conv_mode = "chatml_direct"
|
78 |
-
elif "v1" in model_name.lower():
|
79 |
-
conv_mode = "llava_v1"
|
80 |
-
elif "mpt" in model_name.lower():
|
81 |
-
conv_mode = "mpt"
|
82 |
-
else:
|
83 |
-
conv_mode = "llava_v0"
|
84 |
-
|
85 |
-
if args.conv_mode is not None and conv_mode != args.conv_mode:
|
86 |
-
print(
|
87 |
-
"[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
|
88 |
-
conv_mode, args.conv_mode, args.conv_mode
|
89 |
-
)
|
90 |
-
)
|
91 |
-
else:
|
92 |
-
args.conv_mode = conv_mode
|
93 |
-
|
94 |
-
conv = conv_templates[args.conv_mode].copy()
|
95 |
-
conv.append_message(conv.roles[0], qs)
|
96 |
-
conv.append_message(conv.roles[1], None)
|
97 |
-
prompt = conv.get_prompt()
|
98 |
-
|
99 |
-
image_files = image_parser(args)
|
100 |
-
images = load_images(image_files)
|
101 |
-
image_sizes = [x.size for x in images]
|
102 |
-
images_tensor = process_images(
|
103 |
-
images,
|
104 |
-
image_processor,
|
105 |
-
model.config
|
106 |
-
).to(model.device, dtype=torch.float16)
|
107 |
-
|
108 |
-
input_ids = (
|
109 |
-
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
110 |
-
.unsqueeze(0)
|
111 |
-
.cuda()
|
112 |
-
)
|
113 |
-
|
114 |
-
with torch.inference_mode():
|
115 |
-
output_ids = model.generate(
|
116 |
-
input_ids,
|
117 |
-
images=images_tensor,
|
118 |
-
image_sizes=image_sizes,
|
119 |
-
do_sample=True if args.temperature > 0 else False,
|
120 |
-
temperature=args.temperature,
|
121 |
-
top_p=args.top_p,
|
122 |
-
num_beams=args.num_beams,
|
123 |
-
max_new_tokens=args.max_new_tokens,
|
124 |
-
use_cache=True,
|
125 |
-
)
|
126 |
-
|
127 |
-
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
128 |
-
print(outputs)
|
129 |
-
|
130 |
-
|
131 |
-
if __name__ == "__main__":
|
132 |
-
parser = argparse.ArgumentParser()
|
133 |
-
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
134 |
-
parser.add_argument("--model-base", type=str, default=None)
|
135 |
-
parser.add_argument("--image-file", type=str, required=True)
|
136 |
-
parser.add_argument("--query", type=str, required=True)
|
137 |
-
parser.add_argument("--conv-mode", type=str, default=None)
|
138 |
-
parser.add_argument("--sep", type=str, default=",")
|
139 |
-
parser.add_argument("--temperature", type=float, default=0.2)
|
140 |
-
parser.add_argument("--top_p", type=float, default=None)
|
141 |
-
parser.add_argument("--num_beams", type=int, default=1)
|
142 |
-
parser.add_argument("--max_new_tokens", type=int, default=512)
|
143 |
-
args = parser.parse_args()
|
144 |
-
|
145 |
-
eval_model(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LLaVA-Med/llava/eval/summarize_gpt_review.py
DELETED
@@ -1,47 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
from copy import deepcopy
|
3 |
-
import util
|
4 |
-
from pprint import pprint
|
5 |
-
from collections import defaultdict
|
6 |
-
import pandas as pd
|
7 |
-
import json
|
8 |
-
|
9 |
-
|
10 |
-
def get_domain(x):
|
11 |
-
for domain in ['chest_xray', 'mri', 'histology', 'gross', 'ct_scan']:
|
12 |
-
in_domain = x['domain'][domain]
|
13 |
-
if in_domain:
|
14 |
-
return domain
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
def main(args):
|
19 |
-
scores_data = util.load_file_jsonl(args.scores_file)
|
20 |
-
predictions = [(x['question_id'], x['type'], get_domain(x), x['gpt_eval'].split('\n')[0].split(' ')) for x in scores_data]
|
21 |
-
|
22 |
-
score_type_dict = defaultdict(lambda: defaultdict(list))
|
23 |
-
for q_id, q_type, domain, (a1_score, a2_score) in predictions:
|
24 |
-
score_type_dict[q_type][1].append(a1_score)
|
25 |
-
score_type_dict[q_type][2].append(a2_score)
|
26 |
-
score_type_dict['overall'][1].append(a1_score)
|
27 |
-
score_type_dict['overall'][2].append(a2_score)
|
28 |
-
score_type_dict[domain][1].append(a1_score)
|
29 |
-
score_type_dict[domain][2].append(a2_score)
|
30 |
-
|
31 |
-
result = defaultdict(dict)
|
32 |
-
|
33 |
-
for q_type, score_dict in score_type_dict.items():
|
34 |
-
result[q_type]['gpt4_score'] = util.get_avg(score_dict[1])
|
35 |
-
result[q_type]['pred_score'] = util.get_avg(score_dict[2])
|
36 |
-
result[q_type]['pred_relative_score'] = util.get_avg([float(s2)/float(s1) for s1, s2 in zip(score_dict[1], score_dict[2])])*100
|
37 |
-
result[q_type]['data_size'] = len(score_dict[1])
|
38 |
-
|
39 |
-
df = pd.DataFrame.from_dict(result).filter(['conversation', 'detailed_description', 'chest_xray', 'mri', 'histology', 'gross', 'ct_scan', 'overall'])
|
40 |
-
print(df)
|
41 |
-
|
42 |
-
|
43 |
-
if __name__ == '__main__':
|
44 |
-
parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Eval Postprocessing", add_help=True)
|
45 |
-
parser.add_argument("--scores-file", default="", metavar="FILE", help="input path to gpt-4 score file")
|
46 |
-
args = parser.parse_args()
|
47 |
-
main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LLaVA-Med/llava/eval/util.py
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
import json
|
2 |
-
|
3 |
-
|
4 |
-
def load_file_jsonl(path):
|
5 |
-
with open(path) as f:
|
6 |
-
return [json.loads(row) for row in f]
|
7 |
-
|
8 |
-
def get_avg(x):
|
9 |
-
return sum([float(y) for y in x])/len(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LLaVA-Med/llava/mm_utils.py
DELETED
@@ -1,110 +0,0 @@
|
|
1 |
-
from PIL import Image
|
2 |
-
from io import BytesIO
|
3 |
-
import base64
|
4 |
-
import random
|
5 |
-
import torch
|
6 |
-
from transformers import StoppingCriteria
|
7 |
-
from llava.constants import IMAGE_TOKEN_INDEX
|
8 |
-
|
9 |
-
|
10 |
-
def load_image_from_base64(image):
|
11 |
-
return Image.open(BytesIO(base64.b64decode(image)))
|
12 |
-
|
13 |
-
|
14 |
-
def expand2square(pil_img, background_color):
|
15 |
-
width, height = pil_img.size
|
16 |
-
if width == height:
|
17 |
-
return pil_img
|
18 |
-
elif width > height:
|
19 |
-
result = Image.new(pil_img.mode, (width, width), background_color)
|
20 |
-
# sample a random between 0 and (width - height) // 2
|
21 |
-
y_start = random.randint((width - height) // 2, (width - height) // 2 + 1)
|
22 |
-
result.paste(pil_img, (0, y_start))
|
23 |
-
return result
|
24 |
-
else:
|
25 |
-
result = Image.new(pil_img.mode, (height, height), background_color)
|
26 |
-
# sample a random between 0 and (height - width) // 2
|
27 |
-
x_start = random.randint((height - width) // 2, (height - width) // 2 + 1)
|
28 |
-
result.paste(pil_img, (x_start, 0))
|
29 |
-
return result
|
30 |
-
|
31 |
-
|
32 |
-
def process_images(images, image_processor, model_cfg):
|
33 |
-
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
34 |
-
new_images = []
|
35 |
-
for image in images:
|
36 |
-
if image_aspect_ratio == 'pad':
|
37 |
-
if image.mode=='L':
|
38 |
-
background_color = int(255*sum(image_processor.image_mean)/len(image_processor.image_mean))
|
39 |
-
else:
|
40 |
-
background_color = tuple(int(x*255) for x in image_processor.image_mean)
|
41 |
-
image = expand2square(image, background_color)
|
42 |
-
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
43 |
-
new_images.append(image)
|
44 |
-
if all(x.shape == new_images[0].shape for x in new_images):
|
45 |
-
new_images = torch.stack(new_images, dim=0)
|
46 |
-
return new_images
|
47 |
-
|
48 |
-
|
49 |
-
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
50 |
-
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
51 |
-
|
52 |
-
def insert_separator(X, sep):
|
53 |
-
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
54 |
-
|
55 |
-
input_ids = []
|
56 |
-
offset = 0
|
57 |
-
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
58 |
-
offset = 1
|
59 |
-
input_ids.append(prompt_chunks[0][0])
|
60 |
-
|
61 |
-
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
62 |
-
input_ids.extend(x[offset:])
|
63 |
-
|
64 |
-
if return_tensors is not None:
|
65 |
-
if return_tensors == 'pt':
|
66 |
-
return torch.tensor(input_ids, dtype=torch.long)
|
67 |
-
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
68 |
-
return input_ids
|
69 |
-
|
70 |
-
|
71 |
-
def get_model_name_from_path(model_path):
|
72 |
-
model_path = model_path.strip("/")
|
73 |
-
model_paths = model_path.split("/")
|
74 |
-
if model_paths[-1].startswith('checkpoint-'):
|
75 |
-
return model_paths[-2] + "_" + model_paths[-1]
|
76 |
-
else:
|
77 |
-
return model_paths[-1]
|
78 |
-
|
79 |
-
class KeywordsStoppingCriteria(StoppingCriteria):
|
80 |
-
def __init__(self, keywords, tokenizer, input_ids):
|
81 |
-
self.keywords = keywords
|
82 |
-
self.keyword_ids = []
|
83 |
-
self.max_keyword_len = 0
|
84 |
-
for keyword in keywords:
|
85 |
-
cur_keyword_ids = tokenizer(keyword).input_ids
|
86 |
-
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
87 |
-
cur_keyword_ids = cur_keyword_ids[1:]
|
88 |
-
if len(cur_keyword_ids) > self.max_keyword_len:
|
89 |
-
self.max_keyword_len = len(cur_keyword_ids)
|
90 |
-
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
91 |
-
self.tokenizer = tokenizer
|
92 |
-
self.start_len = input_ids.shape[1]
|
93 |
-
|
94 |
-
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
95 |
-
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
96 |
-
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
97 |
-
for keyword_id in self.keyword_ids:
|
98 |
-
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
|
99 |
-
return True
|
100 |
-
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
101 |
-
for keyword in self.keywords:
|
102 |
-
if keyword in outputs:
|
103 |
-
return True
|
104 |
-
return False
|
105 |
-
|
106 |
-
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
107 |
-
outputs = []
|
108 |
-
for i in range(output_ids.shape[0]):
|
109 |
-
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
110 |
-
return all(outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LLaVA-Med/llava/model/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
|
|
|
|
LLaVA-Med/llava/model/builder.py
DELETED
@@ -1,83 +0,0 @@
|
|
1 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
2 |
-
import torch
|
3 |
-
from llava.model import LlavaMistralForCausalLM
|
4 |
-
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
5 |
-
|
6 |
-
|
7 |
-
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
|
8 |
-
|
9 |
-
kwargs = {}
|
10 |
-
|
11 |
-
if device != "cuda":
|
12 |
-
kwargs['device_map'] = {"": device}
|
13 |
-
|
14 |
-
if load_8bit:
|
15 |
-
kwargs['load_in_8bit'] = True
|
16 |
-
elif load_4bit:
|
17 |
-
kwargs['load_in_4bit'] = True
|
18 |
-
kwargs['quantization_config'] = BitsAndBytesConfig(
|
19 |
-
load_in_4bit=True,
|
20 |
-
bnb_4bit_compute_dtype=torch.float16,
|
21 |
-
bnb_4bit_use_double_quant=True,
|
22 |
-
bnb_4bit_quant_type='nf4'
|
23 |
-
)
|
24 |
-
else:
|
25 |
-
kwargs['torch_dtype'] = torch.float16
|
26 |
-
|
27 |
-
if 'llava' in model_name.lower():
|
28 |
-
# Load LLaVA model
|
29 |
-
if 'mistral' in model_name.lower():
|
30 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
31 |
-
model = LlavaMistralForCausalLM.from_pretrained(
|
32 |
-
model_path,
|
33 |
-
low_cpu_mem_usage=False,
|
34 |
-
use_flash_attention_2=False,
|
35 |
-
**kwargs
|
36 |
-
)
|
37 |
-
else:
|
38 |
-
# Load language model
|
39 |
-
if model_base is not None:
|
40 |
-
# PEFT model
|
41 |
-
from peft import PeftModel
|
42 |
-
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
43 |
-
model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
|
44 |
-
print(f"Loading LoRA weights from {model_path}")
|
45 |
-
model = PeftModel.from_pretrained(model, model_path)
|
46 |
-
print(f"Merging weights")
|
47 |
-
model = model.merge_and_unload()
|
48 |
-
print('Convert to FP16...')
|
49 |
-
model.to(torch.float16)
|
50 |
-
else:
|
51 |
-
use_fast = False
|
52 |
-
if 'mpt' in model_name.lower():
|
53 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
54 |
-
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
|
55 |
-
else:
|
56 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
57 |
-
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
58 |
-
|
59 |
-
image_processor = None
|
60 |
-
|
61 |
-
if 'llava' in model_name.lower(): # or 'mistral' in model_name.lower():
|
62 |
-
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
63 |
-
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
64 |
-
if mm_use_im_patch_token:
|
65 |
-
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
66 |
-
if mm_use_im_start_end:
|
67 |
-
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
68 |
-
model.resize_token_embeddings(len(tokenizer))
|
69 |
-
|
70 |
-
vision_tower = model.get_vision_tower()
|
71 |
-
if not vision_tower.is_loaded:
|
72 |
-
vision_tower.load_model()
|
73 |
-
vision_tower.to(device=device, dtype=torch.float16)
|
74 |
-
model.model.mm_projector.to(device=device, dtype=torch.float16)
|
75 |
-
model.to(device=device, dtype=torch.float16)
|
76 |
-
image_processor = vision_tower.image_processor
|
77 |
-
|
78 |
-
if hasattr(model.config, "max_sequence_length"):
|
79 |
-
context_len = model.config.max_sequence_length
|
80 |
-
else:
|
81 |
-
context_len = 2048
|
82 |
-
|
83 |
-
return tokenizer, model, image_processor, context_len
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LLaVA-Med/llava/model/builders.py
DELETED
@@ -1,152 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import warnings
|
3 |
-
import shutil
|
4 |
-
|
5 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
6 |
-
import torch
|
7 |
-
from llava.model import LLavaMistralForCausalLM
|
8 |
-
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
9 |
-
|
10 |
-
|
11 |
-
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
|
12 |
-
kwargs = {"device_map": device_map, **kwargs}
|
13 |
-
|
14 |
-
if device != "cuda":
|
15 |
-
kwargs['device_map'] = {"": device}
|
16 |
-
|
17 |
-
if load_8bit:
|
18 |
-
kwargs['load_in_8bit'] = True
|
19 |
-
elif load_4bit:
|
20 |
-
kwargs['load_in_4bit'] = True
|
21 |
-
kwargs['quantization_config'] = BitsAndBytesConfig(
|
22 |
-
load_in_4bit=True,
|
23 |
-
bnb_4bit_compute_dtype=torch.float16,
|
24 |
-
bnb_4bit_use_double_quant=True,
|
25 |
-
bnb_4bit_quant_type='nf4'
|
26 |
-
)
|
27 |
-
else:
|
28 |
-
kwargs['torch_dtype'] = torch.float16
|
29 |
-
|
30 |
-
if use_flash_attn:
|
31 |
-
kwargs['attn_implementation'] = 'flash_attention_2'
|
32 |
-
|
33 |
-
if 'llava' in model_name.lower():
|
34 |
-
# Load LLaVA model
|
35 |
-
if 'lora' in model_name.lower() and model_base is None:
|
36 |
-
warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
|
37 |
-
if 'lora' in model_name.lower() and model_base is not None:
|
38 |
-
from llava.model.language_model.llava_mistral import LlavaMistralConfig
|
39 |
-
lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
|
40 |
-
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
41 |
-
print('Loading LLaVA from base model...')
|
42 |
-
model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
|
43 |
-
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
44 |
-
if model.lm_head.weight.shape[0] != token_num:
|
45 |
-
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
46 |
-
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
47 |
-
|
48 |
-
# print('Loading additional LLaVA weights...')
|
49 |
-
# if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
50 |
-
# non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
51 |
-
# else:
|
52 |
-
# # this is probably from HF Hub
|
53 |
-
# from huggingface_hub import hf_hub_download
|
54 |
-
# def load_from_hf(repo_id, filename, subfolder=None):
|
55 |
-
# cache_file = hf_hub_download(
|
56 |
-
# repo_id=repo_id,
|
57 |
-
# filename=filename,
|
58 |
-
# subfolder=subfolder)
|
59 |
-
# return torch.load(cache_file, map_location='cpu')
|
60 |
-
# non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
|
61 |
-
# non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
62 |
-
# if any(k.startswith('model.model.') for k in non_lora_trainables):
|
63 |
-
# non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
64 |
-
# model.load_state_dict(non_lora_trainables, strict=False)
|
65 |
-
|
66 |
-
from peft import PeftModel
|
67 |
-
print('Loading LoRA weights...')
|
68 |
-
model = PeftModel.from_pretrained(model, model_path)
|
69 |
-
print('Merging LoRA weights...')
|
70 |
-
model = model.merge_and_unload()
|
71 |
-
print('Model is loaded...')
|
72 |
-
elif model_base is not None:
|
73 |
-
# this may be mm projector only
|
74 |
-
print('Loading LLaVA from base model...')
|
75 |
-
if 'mpt' in model_name.lower():
|
76 |
-
if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
|
77 |
-
shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
|
78 |
-
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
|
79 |
-
cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
80 |
-
model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
81 |
-
else:
|
82 |
-
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
83 |
-
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
84 |
-
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
85 |
-
|
86 |
-
mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
87 |
-
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
88 |
-
model.load_state_dict(mm_projector_weights, strict=False)
|
89 |
-
else:
|
90 |
-
if 'mpt' in model_name.lower():
|
91 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
92 |
-
model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
93 |
-
elif 'mistral' in model_name.lower():
|
94 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
95 |
-
model = LlavaMistralForCausalLM.from_pretrained(
|
96 |
-
model_path,
|
97 |
-
low_cpu_mem_usage=True,
|
98 |
-
**kwargs
|
99 |
-
)
|
100 |
-
else:
|
101 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
102 |
-
model = LlavaLlamaForCausalLM.from_pretrained(
|
103 |
-
model_path,
|
104 |
-
low_cpu_mem_usage=True,
|
105 |
-
**kwargs
|
106 |
-
)
|
107 |
-
else:
|
108 |
-
# Load language model
|
109 |
-
if model_base is not None:
|
110 |
-
# PEFT model
|
111 |
-
from peft import PeftModel
|
112 |
-
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
113 |
-
model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
|
114 |
-
print(f"Loading LoRA weights from {model_path}")
|
115 |
-
model = PeftModel.from_pretrained(model, model_path)
|
116 |
-
print(f"Merging weights")
|
117 |
-
model = model.merge_and_unload()
|
118 |
-
print('Convert to FP16...')
|
119 |
-
model.to(torch.float16)
|
120 |
-
else:
|
121 |
-
use_fast = False
|
122 |
-
if 'mpt' in model_name.lower():
|
123 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
124 |
-
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
|
125 |
-
else:
|
126 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
127 |
-
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
128 |
-
|
129 |
-
image_processor = None
|
130 |
-
|
131 |
-
if 'mistral' in model_name.lower():
|
132 |
-
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
133 |
-
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
134 |
-
if mm_use_im_patch_token:
|
135 |
-
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
136 |
-
if mm_use_im_start_end:
|
137 |
-
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
138 |
-
model.resize_token_embeddings(len(tokenizer))
|
139 |
-
|
140 |
-
vision_tower = model.get_vision_tower()
|
141 |
-
if not vision_tower.is_loaded:
|
142 |
-
vision_tower.load_model(device_map=device_map)
|
143 |
-
if device_map != 'auto':
|
144 |
-
vision_tower.to(device=device_map, dtype=torch.float16)
|
145 |
-
image_processor = vision_tower.image_processor
|
146 |
-
|
147 |
-
if hasattr(model.config, "max_sequence_length"):
|
148 |
-
context_len = model.config.max_sequence_length
|
149 |
-
else:
|
150 |
-
context_len = 2048
|
151 |
-
|
152 |
-
return tokenizer, model, image_processor, context_len
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LLaVA-Med/llava/model/language_model/llava_mistral.py
DELETED
@@ -1,143 +0,0 @@
|
|
1 |
-
from typing import List, Optional, Tuple, Union
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
|
6 |
-
from transformers import AutoConfig, AutoModelForCausalLM, \
|
7 |
-
MistralConfig, MistralModel, MistralForCausalLM
|
8 |
-
|
9 |
-
from transformers.modeling_outputs import CausalLMOutputWithPast
|
10 |
-
from transformers.generation.utils import GenerateOutput
|
11 |
-
|
12 |
-
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
13 |
-
|
14 |
-
|
15 |
-
class LlavaMistralConfig(MistralConfig):
|
16 |
-
model_type = "llava_mistral"
|
17 |
-
|
18 |
-
|
19 |
-
class LlavaMistralModel(LlavaMetaModel, MistralModel):
|
20 |
-
config_class = LlavaMistralConfig
|
21 |
-
|
22 |
-
def __init__(self, config: MistralConfig):
|
23 |
-
super(LlavaMistralModel, self).__init__(config)
|
24 |
-
|
25 |
-
|
26 |
-
class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
|
27 |
-
config_class = LlavaMistralConfig
|
28 |
-
|
29 |
-
def __init__(self, config):
|
30 |
-
super(MistralForCausalLM, self).__init__(config)
|
31 |
-
self.model = LlavaMistralModel(config)
|
32 |
-
|
33 |
-
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
34 |
-
|
35 |
-
# Initialize weights and apply final processing
|
36 |
-
self.post_init()
|
37 |
-
|
38 |
-
def get_model(self):
|
39 |
-
return self.model
|
40 |
-
|
41 |
-
def forward(
|
42 |
-
self,
|
43 |
-
input_ids: torch.LongTensor = None,
|
44 |
-
attention_mask: Optional[torch.Tensor] = None,
|
45 |
-
position_ids: Optional[torch.LongTensor] = None,
|
46 |
-
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
47 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
48 |
-
labels: Optional[torch.LongTensor] = None,
|
49 |
-
use_cache: Optional[bool] = None,
|
50 |
-
output_attentions: Optional[bool] = None,
|
51 |
-
output_hidden_states: Optional[bool] = None,
|
52 |
-
images: Optional[torch.FloatTensor] = None,
|
53 |
-
image_sizes: Optional[List[List[int]]] = None,
|
54 |
-
return_dict: Optional[bool] = None,
|
55 |
-
) -> Union[Tuple, CausalLMOutputWithPast]:
|
56 |
-
|
57 |
-
if inputs_embeds is None:
|
58 |
-
(
|
59 |
-
input_ids,
|
60 |
-
position_ids,
|
61 |
-
attention_mask,
|
62 |
-
past_key_values,
|
63 |
-
inputs_embeds,
|
64 |
-
labels
|
65 |
-
) = self.prepare_inputs_labels_for_multimodal(
|
66 |
-
input_ids,
|
67 |
-
position_ids,
|
68 |
-
attention_mask,
|
69 |
-
past_key_values,
|
70 |
-
labels,
|
71 |
-
images,
|
72 |
-
image_sizes
|
73 |
-
)
|
74 |
-
|
75 |
-
return super().forward(
|
76 |
-
input_ids=input_ids,
|
77 |
-
attention_mask=attention_mask,
|
78 |
-
position_ids=position_ids,
|
79 |
-
past_key_values=past_key_values,
|
80 |
-
inputs_embeds=inputs_embeds,
|
81 |
-
labels=labels,
|
82 |
-
use_cache=use_cache,
|
83 |
-
output_attentions=output_attentions,
|
84 |
-
output_hidden_states=output_hidden_states,
|
85 |
-
return_dict=return_dict
|
86 |
-
)
|
87 |
-
|
88 |
-
@torch.no_grad()
|
89 |
-
def generate(
|
90 |
-
self,
|
91 |
-
inputs: Optional[torch.Tensor] = None,
|
92 |
-
images: Optional[torch.Tensor] = None,
|
93 |
-
image_sizes: Optional[torch.Tensor] = None,
|
94 |
-
**kwargs,
|
95 |
-
) -> Union[GenerateOutput, torch.LongTensor]:
|
96 |
-
position_ids = kwargs.pop("position_ids", None)
|
97 |
-
attention_mask = kwargs.pop("attention_mask", None)
|
98 |
-
if "inputs_embeds" in kwargs:
|
99 |
-
raise NotImplementedError("`inputs_embeds` is not supported")
|
100 |
-
|
101 |
-
if images is not None:
|
102 |
-
(
|
103 |
-
inputs,
|
104 |
-
position_ids,
|
105 |
-
attention_mask,
|
106 |
-
_,
|
107 |
-
inputs_embeds,
|
108 |
-
_
|
109 |
-
) = self.prepare_inputs_labels_for_multimodal(
|
110 |
-
inputs,
|
111 |
-
position_ids,
|
112 |
-
attention_mask,
|
113 |
-
None,
|
114 |
-
None,
|
115 |
-
images,
|
116 |
-
image_sizes=image_sizes
|
117 |
-
)
|
118 |
-
else:
|
119 |
-
inputs_embeds = self.get_model().embed_tokens(inputs)
|
120 |
-
|
121 |
-
return super().generate(
|
122 |
-
position_ids=position_ids,
|
123 |
-
attention_mask=attention_mask,
|
124 |
-
inputs_embeds=inputs_embeds,
|
125 |
-
**kwargs
|
126 |
-
)
|
127 |
-
|
128 |
-
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
129 |
-
inputs_embeds=None, **kwargs):
|
130 |
-
images = kwargs.pop("images", None)
|
131 |
-
image_sizes = kwargs.pop("image_sizes", None)
|
132 |
-
inputs = super().prepare_inputs_for_generation(
|
133 |
-
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
134 |
-
)
|
135 |
-
if images is not None:
|
136 |
-
inputs['images'] = images
|
137 |
-
if image_sizes is not None:
|
138 |
-
inputs['image_sizes'] = image_sizes
|
139 |
-
return inputs
|
140 |
-
|
141 |
-
|
142 |
-
AutoConfig.register("llava_mistral", LlavaMistralConfig)
|
143 |
-
AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LLaVA-Med/llava/model/llava_arch.py
DELETED
@@ -1,309 +0,0 @@
|
|
1 |
-
# Copyright 2023 Haotian Liu
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import ABC, abstractmethod
import os
from glob import glob

import torch
import torch.nn as nn

from .multimodal_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector

from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


class LlavaMetaModel:

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None, embed_tokens=None):
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

        self.config.mm_vision_tower = vision_tower

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            vision_tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        # add additional configs for segtok
        self.config.feature_outs = model_args.feature_outs
        self.config.img_size = model_args.img_size
        self.config.vision_backbone = model_args.vision_backbone
        self.config.segtok_posembed = model_args.segtok_posembed

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        # Initialize last layer in mm_projector with weight=0 and bias=mean(embed_tokens)
        if embed_tokens is not None:
            embed_tokens_weight = embed_tokens.weight.data
            self.mm_projector[-1].weight.data.zero_()
            self.mm_projector[-1].bias.data.copy_(embed_tokens_weight.mean(dim=0))

        if pretrain_mm_mlp_adapter is not None:
            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))

            # also load additional learnable parameters during feature alignment
            checkpoint_folder = os.path.dirname(pretrain_mm_mlp_adapter)
            ckpts = glob(f"{checkpoint_folder}/checkpoint-*", recursive = False)
            if len(ckpts) > 0:
                vision_module_weights = torch.load(f"{ckpts[-1]}/mm_projector.bin", map_location='cpu')
                model_dict = get_w(vision_module_weights, 'vision_tower')
                print(f"Loading vision module weights from {ckpts[-1]}/mm_projector.bin")
                # print keys in model_dict
                print(f"Loaded keys: {model_dict.keys()}")
                self.vision_tower.load_state_dict(model_dict, strict=False)

class LlavaMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
    ):
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                target_shape = past_key_values[-1][-1].shape[-2] + 1
                attention_mask = torch.cat((attention_mask, torch.ones(
                    (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device
                )), dim=1)
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
        else:
            image_features = self.encode_images(images).to(self.device)

        # TODO: image start / end is not implemented here to support pretraining.
        if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
            raise NotImplementedError

        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])

            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                new_input_embeds_padded.append(torch.cat((
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                    cur_new_embed
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((
                    cur_new_embed,
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None
        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
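The core of `prepare_inputs_labels_for_multimodal` above is the splice: every `IMAGE_TOKEN_INDEX` placeholder in the token ids is replaced by that image's projected features, and the inserted positions are masked with `IGNORE_INDEX` so they never contribute to the loss. A minimal self-contained sketch of just that step, using the conventional LLaVA values `IGNORE_INDEX = -100` and `IMAGE_TOKEN_INDEX = -200`; the toy sizes and the tiny embedding table are illustrative only:

```python
import torch
import torch.nn as nn

IGNORE_INDEX = -100        # label value the loss ignores (LLaVA convention)
IMAGE_TOKEN_INDEX = -200   # placeholder id marking where an image goes (LLaVA convention)

def splice_image_features(input_ids, labels, image_features, embed_tokens):
    """Replace each IMAGE_TOKEN_INDEX in a 1-D id tensor with that image's
    feature block, masking the inserted positions in the labels."""
    bounds = [-1] + torch.where(input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [input_ids.shape[0]]
    embeds, new_labels, img_idx = [], [], 0
    for i in range(len(bounds) - 1):
        chunk = input_ids[bounds[i] + 1:bounds[i + 1]]
        embeds.append(embed_tokens(chunk))
        new_labels.append(labels[bounds[i] + 1:bounds[i + 1]])
        if i < len(bounds) - 2:   # an image sits between this chunk and the next
            feats = image_features[img_idx]
            img_idx += 1
            embeds.append(feats)
            new_labels.append(torch.full((feats.shape[0],), IGNORE_INDEX, dtype=labels.dtype))
    return torch.cat(embeds), torch.cat(new_labels)

# Toy check: 5 text tokens around one image placeholder, 4 "patch" features of width 8.
embed_tokens = nn.Embedding(32, 8)
ids = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, 3, 4, 5])
labels = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, 3, 4, 5])
image_features = [torch.randn(4, 8)]
emb, lab = splice_image_features(ids, labels, image_features, embed_tokens)
print(emb.shape, lab.shape)   # torch.Size([9, 8]) torch.Size([9])
```

The real method additionally pads the spliced sequences in a batch to a common length and rebuilds `attention_mask` and `position_ids`, as in the listing above.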
LLaVA-Med/llava/model/multimodal_encoder/builder.py
DELETED
@@ -1,9 +0,0 @@
import os
from .clip_encoder import CLIPVisionTower

def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    is_absolute_path_exists = os.path.exists(vision_tower)
    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
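`build_vision_tower` simply routes any config whose `mm_vision_tower` (or `vision_tower`) is a local path or an `openai/`/`laion/` checkpoint name to `CLIPVisionTower`; anything else falls through and returns `None`. A sketch of the kind of config object it expects; the field values, including the checkpoint name, are typical LLaVA settings rather than anything this file pins down:

```python
import os
from types import SimpleNamespace

# Typical LLaVA-style settings; the checkpoint name is illustrative.
cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
)

name = getattr(cfg, "mm_vision_tower", getattr(cfg, "vision_tower", None))
routes_to_clip = os.path.exists(name) or name.startswith("openai") or name.startswith("laion")
print(routes_to_clip)   # True -> build_vision_tower(cfg) would return a CLIPVisionTower
```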
LLaVA-Med/llava/model/multimodal_encoder/clip_encoder.py
DELETED
@@ -1,78 +0,0 @@
import torch
import torch.nn as nn

from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig


class CLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
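Assuming the `llava` package from this repository (as it existed before this deletion) is installed and network access is available to pull the CLIP weights, the tower can be exercised on its own. The `openai/clip-vit-large-patch14-336` checkpoint is what LLaVA variants typically pair with this class, not something the file mandates. With `select_layer=-2` and the `'patch'` feature mode, a 336x336 input yields (336/14)^2 = 576 patch tokens of width 1024:

```python
import torch
from types import SimpleNamespace
from llava.model.multimodal_encoder.clip_encoder import CLIPVisionTower

args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")
tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", args=args)  # downloads CLIP weights

pixel_values = torch.randn(2, 3, 336, 336)   # two already-preprocessed images
features = tower(pixel_values)               # hidden states from layer -2, CLS token dropped
print(features.shape)                        # torch.Size([2, 576, 1024])
print(tower.num_patches)                     # 576
```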
LLaVA-Med/llava/model/multimodal_projector/builder.py
DELETED
@@ -1,51 +0,0 @@
import torch
import torch.nn as nn
import re


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}


class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )
    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
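`build_vision_projector` is what turns the `mm_projector_type` string in the model config into an `nn.Module`. A small sketch, again assuming the pre-deletion `llava` package is importable; the widths (1024 for CLIP-L patch features, 4096 for a 7B language model) are illustrative:

```python
import torch
from types import SimpleNamespace
from llava.model.multimodal_projector.builder import build_vision_projector

config = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096)
projector = build_vision_projector(config)
print(projector)   # Linear(1024->4096) -> GELU -> Linear(4096->4096)

image_features = torch.randn(2, 576, 1024)
print(projector(image_features).shape)   # torch.Size([2, 576, 4096])
```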
LLaVA-Med/llava/serve/__init__.py
DELETED
File without changes
LLaVA-Med/llava/serve/cli.py
DELETED
@@ -1,125 +0,0 @@
import argparse
import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image

import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer


def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def main(args):
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    if 'llama-2' in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"
    conv_mode = "mistral_instruct"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    if "mpt" in model_name.lower():
        roles = ('user', 'assistant')
    else:
        roles = conv.roles

    image = load_image(args.image_file)
    # Similar operation in model_worker.py
    image_tensor = process_images([image], image_processor, model.config)
    if type(image_tensor) is list:
        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    while True:
        try:
            inp = input(f"{roles[0]}: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        print(f"{roles[1]}: ", end="")

        if image is not None:
            # first message
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            conv.append_message(conv.roles[0], inp)
            image = None
        else:
            # later messages
            conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True if args.temperature > 0 else False,
                temperature=args.temperature,
                max_new_tokens=args.max_new_tokens,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        conv.messages[-1][-1] = outputs

        if args.debug:
            print("\n", {"prompt": prompt, "outputs": outputs}, "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    main(args)
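The chat loop above boils down to: build a conversation from a template, prepend the image token to the first user turn, render it with `get_prompt()`, and tokenize the result with `tokenizer_image_token`. A sketch of just the prompt-assembly part, using the `mistral_instruct` template that `main()` hard-codes; it relies on the pre-deletion `llava.conversation` and `llava.constants` modules:

```python
from llava.constants import DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates

conv = conv_templates["mistral_instruct"].copy()   # the mode main() forces above
question = "What does this scan show?"
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
conv.append_message(conv.roles[1], None)           # empty slot for the model's reply
prompt = conv.get_prompt()
print(prompt)
```

The script itself is typically launched as `python -m llava.serve.cli --model-path <checkpoint> --image-file <image>`, with the sampling options defined in the argparse block above.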
LLaVA-Med/llava/serve/controller.py
DELETED
@@ -1,298 +0,0 @@
"""
A controller manages distributed workers.
It sends worker addresses to clients.
"""
import argparse
import asyncio
import dataclasses
from enum import Enum, auto
import json
import logging
import time
from typing import List, Union
import threading

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import numpy as np
import requests
import uvicorn

from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
from llava.utils import build_logger, server_error_msg


logger = build_logger("controller", "controller.log")


class DispatchMethod(Enum):
    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name):
        if name == "lottery":
            return cls.LOTTERY
        elif name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        else:
            raise ValueError(f"Invalid dispatch method")


@dataclasses.dataclass
class WorkerInfo:
    model_names: List[str]
    speed: int
    queue_length: int
    check_heart_beat: bool
    last_heart_beat: str


def heart_beat_controller(controller):
    while True:
        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
        controller.remove_stable_workers_by_expiration()


class Controller:
    def __init__(self, dispatch_method: str):
        # Dict[str -> WorkerInfo]
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,))
        self.heart_beat_thread.start()

        logger.info("Init controller")

    def register_worker(self, worker_name: str, check_heart_beat: bool,
                        worker_status: dict):
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
            if not worker_status:
                return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
            check_heart_beat, time.time())

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        return r.json()

    def remove_worker(self, worker_name: str):
        del self.worker_info[worker_name]

    def refresh_all_workers(self):
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(w_name, w_info.check_heart_beat, None):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
        model_names = set()

        for w_name, w_info in self.worker_info.items():
            model_names.update(w_info.model_names)

        return list(model_names)

    def get_worker_address(self, model_name: str):
        if self.dispatch_method == DispatchMethod.LOTTERY:
            worker_names = []
            worker_speeds = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_speeds.append(w_info.speed)
            worker_speeds = np.array(worker_speeds, dtype=np.float32)
            norm = np.sum(worker_speeds)
            if norm < 1e-4:
                return ""
            worker_speeds = worker_speeds / norm
            if True:  # Directly return address
                pt = np.random.choice(np.arange(len(worker_names)),
                                      p=worker_speeds)
                worker_name = worker_names[pt]
                return worker_name

            # Check status before returning
            while True:
                pt = np.random.choice(np.arange(len(worker_names)),
                                      p=worker_speeds)
                worker_name = worker_names[pt]

                if self.get_worker_status(worker_name):
                    break
                else:
                    self.remove_worker(worker_name)
                    worker_speeds[pt] = 0
                    norm = np.sum(worker_speeds)
                    if norm < 1e-4:
                        return ""
                    worker_speeds = worker_speeds / norm
                    continue
            return worker_name
        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            worker_names = []
            worker_qlen = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_qlen.append(w_info.queue_length / w_info.speed)
            if len(worker_names) == 0:
                return ""
            min_index = np.argmin(worker_qlen)
            w_name = worker_names[min_index]
            self.worker_info[w_name].queue_length += 1
            logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
            return w_name
        else:
            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")

    def receive_heart_beat(self, worker_name: str, queue_length: int):
        if worker_name not in self.worker_info:
            logger.info(f"Receive unknown heart beat. {worker_name}")
            return False

        self.worker_info[worker_name].queue_length = queue_length
        self.worker_info[worker_name].last_heart_beat = time.time()
        logger.info(f"Receive heart beat. {worker_name}")
        return True

    def remove_stable_workers_by_expiration(self):
        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
        to_delete = []
        for worker_name, w_info in self.worker_info.items():
            if w_info.check_heart_beat and w_info.last_heart_beat < expire:
                to_delete.append(worker_name)

        for worker_name in to_delete:
            self.remove_worker(worker_name)

    def worker_api_generate_stream(self, params):
        worker_addr = self.get_worker_address(params["model"])
        if not worker_addr:
            logger.info(f"no worker: {params['model']}")
            ret = {
                "text": server_error_msg,
                "error_code": 2,
            }
            yield json.dumps(ret).encode() + b"\0"

        try:
            response = requests.post(worker_addr + "/worker_generate_stream",
                                     json=params, stream=True, timeout=5)
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    yield chunk + b"\0"
        except requests.exceptions.RequestException as e:
            logger.info(f"worker timeout: {worker_addr}")
            ret = {
                "text": server_error_msg,
                "error_code": 3,
            }
            yield json.dumps(ret).encode() + b"\0"


    # Let the controller act as a worker to achieve hierarchical
    # management. This can be used to connect isolated sub networks.
    def worker_api_get_status(self):
        model_names = set()
        speed = 0
        queue_length = 0

        for w_name in self.worker_info:
            worker_status = self.get_worker_status(w_name)
            if worker_status is not None:
                model_names.update(worker_status["model_names"])
                speed += worker_status["speed"]
                queue_length += worker_status["queue_length"]

        return {
            "model_names": list(model_names),
            "speed": speed,
            "queue_length": queue_length,
        }


app = FastAPI()


@app.post("/register_worker")
async def register_worker(request: Request):
    data = await request.json()
    controller.register_worker(
        data["worker_name"], data["check_heart_beat"],
        data.get("worker_status", None))


@app.post("/refresh_all_workers")
async def refresh_all_workers():
    models = controller.refresh_all_workers()


@app.post("/list_models")
async def list_models():
    models = controller.list_models()
    return {"models": models}


@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    data = await request.json()
    addr = controller.get_worker_address(data["model"])
    return {"address": addr}


@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
    data = await request.json()
    exist = controller.receive_heart_beat(
        data["worker_name"], data["queue_length"])
    return {"exist": exist}


@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
    params = await request.json()
    generator = controller.worker_api_generate_stream(params)
    return StreamingResponse(generator)


@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
    return controller.worker_api_get_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21001)
    parser.add_argument("--dispatch-method", type=str, choices=[
        "lottery", "shortest_queue"], default="shortest_queue")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    controller = Controller(args.dispatch_method)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
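The controller is plain HTTP: workers POST their status to `/register_worker`, and clients ask `/get_worker_address` for a worker serving a given model. A minimal sketch against a locally running controller; the worker URL and model name below are placeholders:

```python
import requests

CONTROLLER = "http://localhost:21001"   # the default --port above

# What a model worker would send when registering itself (fields mirror WorkerInfo).
requests.post(CONTROLLER + "/register_worker", json={
    "worker_name": "http://localhost:40000",                        # placeholder worker URL
    "check_heart_beat": True,
    "worker_status": {"model_names": ["llava-med-v1.5-mistral-7b"],  # placeholder model name
                      "speed": 1, "queue_length": 0},
})

# What a client (e.g. the Gradio server) does to find a worker for that model.
addr = requests.post(CONTROLLER + "/get_worker_address",
                     json={"model": "llava-med-v1.5-mistral-7b"}).json()["address"]
print(addr or "no worker available")
```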
LLaVA-Med/llava/serve/examples/bio_patch.png
DELETED
Binary file (214 kB)

LLaVA-Med/llava/serve/examples/extreme_ironing.jpg
DELETED
Binary file (62.6 kB)

LLaVA-Med/llava/serve/examples/med_img_1.png
DELETED
Binary file (319 kB)

LLaVA-Med/llava/serve/examples/synpic32933.jpg
DELETED
Binary file (69.8 kB)

LLaVA-Med/llava/serve/examples/synpic42202.jpg
DELETED
Binary file (86.1 kB)

LLaVA-Med/llava/serve/examples/waterview.jpg
DELETED
Binary file (95.5 kB)

LLaVA-Med/llava/serve/examples/xy_chromosome.jpg
DELETED
Binary file (53.9 kB)
LLaVA-Med/llava/serve/gradio_web_server.py
DELETED
@@ -1,477 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import datetime
|
3 |
-
import json
|
4 |
-
import os
|
5 |
-
import time
|
6 |
-
|
7 |
-
import gradio as gr
|
8 |
-
import requests
|
9 |
-
|
10 |
-
from llava.conversation import (default_conversation, conv_templates,
|
11 |
-
SeparatorStyle)
|
12 |
-
from llava.constants import LOGDIR
|
13 |
-
from llava.utils import (build_logger, server_error_msg,
|
14 |
-
violates_moderation, moderation_msg)
|
15 |
-
import hashlib
|
16 |
-
|
17 |
-
|
18 |
-
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
19 |
-
|
20 |
-
headers = {"User-Agent": "LLaVA-Med Client"}
|
21 |
-
|
22 |
-
no_change_btn = gr.Button.update()
|
23 |
-
enable_btn = gr.Button.update(interactive=True)
|
24 |
-
disable_btn = gr.Button.update(interactive=False)
|
25 |
-
|
26 |
-
priority = {
|
27 |
-
"vicuna-13b": "aaaaaaa",
|
28 |
-
"koala-13b": "aaaaaab",
|
29 |
-
}
|
30 |
-
|
31 |
-
|
32 |
-
def get_conv_log_filename():
|
33 |
-
t = datetime.datetime.now()
|
34 |
-
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
35 |
-
return name
|
36 |
-
|
37 |
-
|
38 |
-
def get_model_list():
|
39 |
-
ret = requests.post(args.controller_url + "/refresh_all_workers")
|
40 |
-
assert ret.status_code == 200
|
41 |
-
ret = requests.post(args.controller_url + "/list_models")
|
42 |
-
models = ret.json()["models"]
|
43 |
-
models.sort(key=lambda x: priority.get(x, x))
|
44 |
-
logger.info(f"Models: {models}")
|
45 |
-
return models
|
46 |
-
|
47 |
-
|
48 |
-
get_window_url_params = """
|
49 |
-
function() {
|
50 |
-
const params = new URLSearchParams(window.location.search);
|
51 |
-
url_params = Object.fromEntries(params);
|
52 |
-
console.log(url_params);
|
53 |
-
return url_params;
|
54 |
-
}
|
55 |
-
"""
|
56 |
-
|
57 |
-
|
58 |
-
def load_demo(url_params, request: gr.Request):
|
59 |
-
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
60 |
-
|
61 |
-
dropdown_update = gr.Dropdown.update(visible=True)
|
62 |
-
if "model" in url_params:
|
63 |
-
model = url_params["model"]
|
64 |
-
if model in models:
|
65 |
-
dropdown_update = gr.Dropdown.update(
|
66 |
-
value=model, visible=True)
|
67 |
-
|
68 |
-
state = default_conversation.copy()
|
69 |
-
return state, dropdown_update
|
70 |
-
|
71 |
-
|
72 |
-
def load_demo_refresh_model_list(request: gr.Request):
|
73 |
-
logger.info(f"load_demo. ip: {request.client.host}")
|
74 |
-
models = get_model_list()
|
75 |
-
state = default_conversation.copy()
|
76 |
-
dropdown_update = gr.Dropdown.update(
|
77 |
-
choices=models,
|
78 |
-
value=models[0] if len(models) > 0 else ""
|
79 |
-
)
|
80 |
-
return state, dropdown_update
|
81 |
-
|
82 |
-
|
83 |
-
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
84 |
-
with open(get_conv_log_filename(), "a") as fout:
|
85 |
-
data = {
|
86 |
-
"tstamp": round(time.time(), 4),
|
87 |
-
"type": vote_type,
|
88 |
-
"model": model_selector,
|
89 |
-
"state": state.dict(),
|
90 |
-
"ip": request.client.host,
|
91 |
-
}
|
92 |
-
fout.write(json.dumps(data) + "\n")
|
93 |
-
|
94 |
-
|
95 |
-
def upvote_last_response(state, model_selector, request: gr.Request):
|
96 |
-
logger.info(f"upvote. ip: {request.client.host}")
|
97 |
-
vote_last_response(state, "upvote", model_selector, request)
|
98 |
-
return ("",) + (disable_btn,) * 3
|
99 |
-
|
100 |
-
|
101 |
-
def downvote_last_response(state, model_selector, request: gr.Request):
|
102 |
-
logger.info(f"downvote. ip: {request.client.host}")
|
103 |
-
vote_last_response(state, "downvote", model_selector, request)
|
104 |
-
return ("",) + (disable_btn,) * 3
|
105 |
-
|
106 |
-
|
107 |
-
def flag_last_response(state, model_selector, request: gr.Request):
|
108 |
-
logger.info(f"flag. ip: {request.client.host}")
|
109 |
-
vote_last_response(state, "flag", model_selector, request)
|
110 |
-
return ("",) + (disable_btn,) * 3
|
111 |
-
|
112 |
-
|
113 |
-
def regenerate(state, image_process_mode, request: gr.Request):
|
114 |
-
logger.info(f"regenerate. ip: {request.client.host}")
|
115 |
-
state.messages[-1][-1] = None
|
116 |
-
prev_human_msg = state.messages[-2]
|
117 |
-
if type(prev_human_msg[1]) in (tuple, list):
|
118 |
-
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
|
119 |
-
state.skip_next = False
|
120 |
-
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
121 |
-
|
122 |
-
|
123 |
-
def clear_history(request: gr.Request):
|
124 |
-
logger.info(f"clear_history. ip: {request.client.host}")
|
125 |
-
state = default_conversation.copy()
|
126 |
-
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
127 |
-
|
128 |
-
|
129 |
-
def add_text(state, text, image, image_process_mode, request: gr.Request):
|
130 |
-
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
|
131 |
-
if len(text) <= 0 and image is None:
|
132 |
-
state.skip_next = True
|
133 |
-
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
|
134 |
-
if args.moderate:
|
135 |
-
flagged = violates_moderation(text)
|
136 |
-
if flagged:
|
137 |
-
state.skip_next = True
|
138 |
-
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
|
139 |
-
no_change_btn,) * 5
|
140 |
-
|
141 |
-
text = text[:1536] # Hard cut-off
|
142 |
-
if image is not None:
|
143 |
-
text = text[:1200] # Hard cut-off for images
|
144 |
-
if '<image>' not in text:
|
145 |
-
# text = '<Image><image></Image>' + text
|
146 |
-
text = text + '\n<image>'
|
147 |
-
text = (text, image, image_process_mode)
|
148 |
-
if len(state.get_images(return_pil=True)) > 0:
|
149 |
-
state = default_conversation.copy()
|
150 |
-
state.append_message(state.roles[0], text)
|
151 |
-
state.append_message(state.roles[1], None)
|
152 |
-
state.skip_next = False
|
153 |
-
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
154 |
-
|
155 |
-
|
156 |
-
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
|
157 |
-
logger.info(f"http_bot. ip: {request.client.host}")
|
158 |
-
start_tstamp = time.time()
|
159 |
-
model_name = model_selector
|
160 |
-
|
161 |
-
if state.skip_next:
|
162 |
-
# This generate call is skipped due to invalid inputs
|
163 |
-
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
164 |
-
return
|
165 |
-
|
166 |
-
if len(state.messages) == state.offset + 2:
|
167 |
-
# First round of conversation
|
168 |
-
if "llava" in model_name.lower():
|
169 |
-
if 'llama-2' in model_name.lower():
|
170 |
-
template_name = "llava_llama_2"
|
171 |
-
elif "v1" in model_name.lower():
|
172 |
-
if 'mmtag' in model_name.lower():
|
173 |
-
template_name = "v1_mmtag"
|
174 |
-
elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
|
175 |
-
template_name = "v1_mmtag"
|
176 |
-
else:
|
177 |
-
template_name = "llava_v1"
|
178 |
-
elif "mpt" in model_name.lower():
|
179 |
-
template_name = "mpt"
|
180 |
-
else:
|
181 |
-
if 'mmtag' in model_name.lower():
|
182 |
-
template_name = "v0_mmtag"
|
183 |
-
elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
|
184 |
-
template_name = "v0_mmtag"
|
185 |
-
else:
|
186 |
-
template_name = "llava_v0"
|
187 |
-
elif "mpt" in model_name:
|
188 |
-
template_name = "mpt_text"
|
189 |
-
elif "llama-2" in model_name:
|
190 |
-
template_name = "llama_2"
|
191 |
-
else:
|
192 |
-
template_name = "vicuna_v1"
|
193 |
-
template_name = "mistral_instruct" # FIXME: overwrite
|
194 |
-
new_state = conv_templates[template_name].copy()
|
195 |
-
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
196 |
-
new_state.append_message(new_state.roles[1], None)
|
197 |
-
state = new_state
|
198 |
-
|
199 |
-
# Query worker address
|
200 |
-
controller_url = args.controller_url
|
201 |
-
ret = requests.post(controller_url + "/get_worker_address",
|
202 |
-
json={"model": model_name})
|
203 |
-
worker_addr = ret.json()["address"]
|
204 |
-
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
205 |
-
|
206 |
-
# No available worker
|
207 |
-
if worker_addr == "":
|
208 |
-
state.messages[-1][-1] = server_error_msg
|
209 |
-
yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
210 |
-
return
|
211 |
-
|
212 |
-
# Construct prompt
|
213 |
-
prompt = state.get_prompt()
|
214 |
-
|
215 |
-
all_images = state.get_images(return_pil=True)
|
216 |
-
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
|
217 |
-
for image, hash in zip(all_images, all_image_hash):
|
218 |
-
t = datetime.datetime.now()
|
219 |
-
filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
|
220 |
-
if not os.path.isfile(filename):
|
221 |
-
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
222 |
-
image.save(filename)
|
223 |
-
|
224 |
-
# Make requests
|
225 |
-
pload = {
|
226 |
-
"model": model_name,
|
227 |
-
"prompt": prompt,
|
228 |
-
"temperature": float(temperature),
|
229 |
-
"top_p": float(top_p),
|
230 |
-
"max_new_tokens": min(int(max_new_tokens), 1536),
|
231 |
-
"stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
|
232 |
-
"images": f'List of {len(state.get_images())} images: {all_image_hash}',
|
233 |
-
}
|
234 |
-
logger.info(f"==== request ====\n{pload}")
|
235 |
-
|
236 |
-
pload['images'] = state.get_images()
|
237 |
-
|
238 |
-
state.messages[-1][-1] = "▌"
|
239 |
-
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
240 |
-
|
241 |
-
try:
|
242 |
-
# Stream output
|
243 |
-
response = requests.post(worker_addr + "/worker_generate_stream",
|
244 |
-
headers=headers, json=pload, stream=True, timeout=10)
|
245 |
-
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
246 |
-
if chunk:
|
247 |
-
data = json.loads(chunk.decode())
|
248 |
-
if data["error_code"] == 0:
|
249 |
-
output = data["text"][len(prompt):].strip()
|
250 |
-
state.messages[-1][-1] = output + "▌"
|
251 |
-
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
252 |
-
else:
|
253 |
-
output = data["text"] + f" (error_code: {data['error_code']})"
|
254 |
-
state.messages[-1][-1] = output
|
255 |
-
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
256 |
-
return
|
257 |
-
time.sleep(0.03)
|
258 |
-
except requests.exceptions.RequestException as e:
|
259 |
-
state.messages[-1][-1] = server_error_msg
|
260 |
-
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
261 |
-
return
|
262 |
-
|
263 |
-
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
264 |
-
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
265 |
-
|
266 |
-
finish_tstamp = time.time()
|
267 |
-
logger.info(f"{output}")
|
268 |
-
|
269 |
-
with open(get_conv_log_filename(), "a") as fout:
|
270 |
-
data = {
|
271 |
-
"tstamp": round(finish_tstamp, 4),
|
272 |
-
"type": "chat",
|
273 |
-
"model": model_name,
|
274 |
-
"start": round(start_tstamp, 4),
|
275 |
-
"finish": round(finish_tstamp, 4),
|
276 |
-
"state": state.dict(),
|
277 |
-
"images": all_image_hash,
|
278 |
-
"ip": request.client.host,
|
279 |
-
}
|
280 |
-
fout.write(json.dumps(data) + "\n")
|
281 |
-
|
282 |
-
|
283 |
-
title_markdown = ("""
|
284 |
-
# 🌋 LLaVA-Med: Large Language and Vision Assistant for Medical Research
|
285 |
-
[[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0)
|
286 |
-
""")
|
287 |
-
|
288 |
-
tos_markdown = ("""
|
289 |
-
### Terms of use
|
290 |
-
By using this service, users are required to agree to the following terms:
|
291 |
-
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
292 |
-
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
293 |
-
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
294 |
-
""")
|
295 |
-
|
296 |
-
|
297 |
-
learn_more_markdown = ("""
|
298 |
-
### License
|
299 |
-
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
|
300 |
-
""")
|
301 |
-
|
302 |
-
block_css = """
|
303 |
-
|
304 |
-
#buttons button {
|
305 |
-
min-width: min(120px,100%);
|
306 |
-
}
|
307 |
-
|
308 |
-
"""
|
309 |
-
|
310 |
-
def build_demo(embed_mode):
|
311 |
-
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
|
312 |
-
with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
|
313 |
-
state = gr.State()
|
314 |
-
|
315 |
-
if not embed_mode:
|
316 |
-
gr.Markdown(title_markdown)
|
317 |
-
|
318 |
-
with gr.Row():
|
319 |
-
with gr.Column(scale=3):
|
320 |
-
with gr.Row(elem_id="model_selector_row"):
|
321 |
-
model_selector = gr.Dropdown(
|
322 |
-
choices=models,
|
323 |
-
value=models[0] if len(models) > 0 else "",
|
324 |
-
interactive=True,
|
325 |
-
show_label=False,
|
326 |
-
container=False)
|
327 |
-
|
328 |
-
imagebox = gr.Image(type="pil")
|
329 |
-
image_process_mode = gr.Radio(
|
330 |
-
["Crop", "Resize", "Pad", "Default"],
|
331 |
-
value="Default",
|
332 |
-
label="Preprocess for non-square image", visible=False)
|
333 |
-
|
334 |
-
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
335 |
-
gr.Examples(examples=[
|
336 |
-
[f"{cur_dir}/examples/bio_patch.png", "What is this image about?"],
|
337 |
-
[f"{cur_dir}/examples/med_img_1.png", "Can you describe the image in details?"],
|
338 |
-
[f"{cur_dir}/examples/xy_chromosome.jpg", "Can you describe the image in details?"],
|
339 |
-
[f"{cur_dir}/examples/synpic42202.jpg", "Is there evidence of an aortic aneurysm? Please choose from the following two options: [yes, no]?"], # answer" yes
|
340 |
-
[f"{cur_dir}/examples/synpic32933.jpg", "What is the abnormality by the right hemidiaphragm?"], # answer: free air
|
341 |
-
[f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
|
342 |
-
[f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
|
343 |
-
], inputs=[imagebox, textbox])
|
344 |
-
|
345 |
-
with gr.Accordion("Parameters", open=False) as parameter_row:
|
346 |
-
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
|
347 |
-
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
348 |
-
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
349 |
-
|
350 |
-
with gr.Column(scale=8):
|
351 |
-
chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA-Med Chatbot", height=550)
|
352 |
-
with gr.Row():
|
353 |
-
with gr.Column(scale=8):
|
354 |
-
textbox.render()
|
355 |
-
with gr.Column(scale=1, min_width=50):
|
356 |
-
submit_btn = gr.Button(value="Send", variant="primary")
|
357 |
-
with gr.Row(elem_id="buttons") as button_row:
|
358 |
-
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
359 |
-
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
360 |
-
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
|
361 |
-
#stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
|
362 |
-
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
363 |
-
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
|
364 |
-
|
365 |
-
if not embed_mode:
-            gr.Markdown(tos_markdown)
-            gr.Markdown(learn_more_markdown)
-        url_params = gr.JSON(visible=False)
-
-        # Register listeners
-        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
-        upvote_btn.click(
-            upvote_last_response,
-            [state, model_selector],
-            [textbox, upvote_btn, downvote_btn, flag_btn],
-            queue=False
-        )
-        downvote_btn.click(
-            downvote_last_response,
-            [state, model_selector],
-            [textbox, upvote_btn, downvote_btn, flag_btn],
-            queue=False
-        )
-        flag_btn.click(
-            flag_last_response,
-            [state, model_selector],
-            [textbox, upvote_btn, downvote_btn, flag_btn],
-            queue=False
-        )
-
-        regenerate_btn.click(
-            regenerate,
-            [state, image_process_mode],
-            [state, chatbot, textbox, imagebox] + btn_list,
-            queue=False
-        ).then(
-            http_bot,
-            [state, model_selector, temperature, top_p, max_output_tokens],
-            [state, chatbot] + btn_list
-        )
-
-        clear_btn.click(
-            clear_history,
-            None,
-            [state, chatbot, textbox, imagebox] + btn_list,
-            queue=False
-        )
-
-        textbox.submit(
-            add_text,
-            [state, textbox, imagebox, image_process_mode],
-            [state, chatbot, textbox, imagebox] + btn_list,
-            queue=False
-        ).then(
-            http_bot,
-            [state, model_selector, temperature, top_p, max_output_tokens],
-            [state, chatbot] + btn_list
-        )
-
-        submit_btn.click(
-            add_text,
-            [state, textbox, imagebox, image_process_mode],
-            [state, chatbot, textbox, imagebox] + btn_list,
-            queue=False
-        ).then(
-            http_bot,
-            [state, model_selector, temperature, top_p, max_output_tokens],
-            [state, chatbot] + btn_list
-        )
-
-        if args.model_list_mode == "once":
-            demo.load(
-                load_demo,
-                [url_params],
-                [state, model_selector],
-                _js=get_window_url_params,
-                queue=False
-            )
-        elif args.model_list_mode == "reload":
-            demo.load(
-                load_demo_refresh_model_list,
-                None,
-                [state, model_selector],
-                queue=False
-            )
-        else:
-            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
-
-    return demo
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="0.0.0.0")
-    parser.add_argument("--port", type=int)
-    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
-    parser.add_argument("--concurrency-count", type=int, default=10)
-    parser.add_argument("--model-list-mode", type=str, default="once",
-                        choices=["once", "reload"])
-    parser.add_argument("--share", action="store_true")
-    parser.add_argument("--moderate", action="store_true")
-    parser.add_argument("--embed", action="store_true")
-    args = parser.parse_args()
-    logger.info(f"args: {args}")
-
-    models = get_model_list()
-
-    logger.info(args)
-    demo = build_demo(args.embed)
-    demo.queue(
-        concurrency_count=args.concurrency_count,
-        api_open=False
-    ).launch(
-        server_name=args.host,
-        server_port=args.port,
-        share=args.share
-    )
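
The listener wiring deleted above uses Gradio's chained-event pattern: a fast callback registered with queue=False updates the UI immediately, and .then() attaches the slower http_bot call that streams the model reply into the chatbot. A minimal, self-contained sketch of that pattern (hypothetical names, not taken from the deleted repo):

import gradio as gr

def add_text(history, text):
    # Fast step: echo the user turn into the chat right away.
    return history + [(text, None)], ""

def bot(history):
    # Slow step: stand-in for http_bot, which would stream tokens from the model worker.
    history[-1] = (history[-1][0], "(streamed model reply would appear here)")
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox()
    # queue=False keeps the first update instant; .then() runs the generation afterwards.
    textbox.submit(add_text, [chatbot, textbox], [chatbot, textbox], queue=False).then(
        bot, chatbot, chatbot
    )

if __name__ == "__main__":
    demo.queue().launch()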
LLaVA-Med/llava/serve/model_worker.py
DELETED
@@ -1,285 +0,0 @@
-"""
-A model worker executes the model.
-"""
-import argparse
-import asyncio
-import json
-import time
-import threading
-import uuid
-
-from fastapi import FastAPI, Request, BackgroundTasks
-from fastapi.responses import StreamingResponse
-import requests
-import torch
-import uvicorn
-from functools import partial
-
-from llava.constants import WORKER_HEART_BEAT_INTERVAL
-from llava.utils import (build_logger, server_error_msg,
-    pretty_print_semaphore)
-from llava.model.builder import load_pretrained_model
-from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
-from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
-from transformers import TextIteratorStreamer
-from threading import Thread
-
-
-GB = 1 << 30
-
-worker_id = str(uuid.uuid4())[:6]
-logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
-global_counter = 0
-
-model_semaphore = None
-
-
-def heart_beat_worker(controller):
-
-    while True:
-        time.sleep(WORKER_HEART_BEAT_INTERVAL)
-        controller.send_heart_beat()
-
-
-class ModelWorker:
-    def __init__(self, controller_addr, worker_addr,
-                 worker_id, no_register,
-                 model_path, model_base, model_name,
-                 load_8bit, load_4bit, device):
-        self.controller_addr = controller_addr
-        self.worker_addr = worker_addr
-        self.worker_id = worker_id
-        if model_path.endswith("/"):
-            model_path = model_path[:-1]
-        if model_name is None:
-            model_paths = model_path.split("/")
-            if model_paths[-1].startswith('checkpoint-'):
-                self.model_name = model_paths[-2] + "_" + model_paths[-1]
-            else:
-                self.model_name = model_paths[-1]
-        else:
-            self.model_name = model_name
-
-        self.device = device
-        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
-        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
-            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
-        self.is_multimodal = 'llava' in self.model_name.lower()
-
-        if not no_register:
-            self.register_to_controller()
-            self.heart_beat_thread = threading.Thread(
-                target=heart_beat_worker, args=(self,))
-            self.heart_beat_thread.start()
-
-    def register_to_controller(self):
-        logger.info("Register to controller")
-
-        url = self.controller_addr + "/register_worker"
-        data = {
-            "worker_name": self.worker_addr,
-            "check_heart_beat": True,
-            "worker_status": self.get_status()
-        }
-        r = requests.post(url, json=data)
-        assert r.status_code == 200
-
-    def send_heart_beat(self):
-        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
-                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
-                    f"global_counter: {global_counter}")
-
-        url = self.controller_addr + "/receive_heart_beat"
-
-        while True:
-            try:
-                ret = requests.post(url, json={
-                    "worker_name": self.worker_addr,
-                    "queue_length": self.get_queue_length()}, timeout=5)
-                exist = ret.json()["exist"]
-                break
-            except requests.exceptions.RequestException as e:
-                logger.error(f"heart beat error: {e}")
-            time.sleep(5)
-
-        if not exist:
-            self.register_to_controller()
-
-    def get_queue_length(self):
-        if model_semaphore is None:
-            return 0
-        else:
-            return args.limit_model_concurrency - model_semaphore._value + (len(
-                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
-
-    def get_status(self):
-        return {
-            "model_names": [self.model_name],
-            "speed": 1,
-            "queue_length": self.get_queue_length(),
-        }
-
-    @torch.inference_mode()
-    def generate_stream(self, params):
-        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
-
-        prompt = params["prompt"]
-        ori_prompt = prompt
-        images = params.get("images", None)
-        num_image_tokens = 0
-        if images is not None and len(images) > 0 and self.is_multimodal:
-            if len(images) > 0:
-                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                    raise ValueError("Number of images does not match number of <image> tokens in prompt")
-
-                images = [load_image_from_base64(image) for image in images]
-                images = process_images(images, image_processor, model.config)
-
-                if type(images) is list:
-                    images = [image.to(self.model.device, dtype=torch.float16) for image in images]
-                else:
-                    images = images.to(self.model.device, dtype=torch.float16)
-
-                replace_token = DEFAULT_IMAGE_TOKEN
-                if getattr(self.model.config, 'mm_use_im_start_end', False):
-                    replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
-                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
-
-                num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
-            else:
-                images = None
-            image_args = {"images": images}
-        else:
-            images = None
-            image_args = {}
-
-        temperature = float(params.get("temperature", 1.0))
-        top_p = float(params.get("top_p", 1.0))
-        max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
-        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
-        stop_str = params.get("stop", None)
-        do_sample = True if temperature > 0.001 else False
-
-        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
-        keywords = [stop_str]
-        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
-
-        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
-
-        if max_new_tokens < 1:
-            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
-            return
-
-        thread = Thread(target=model.generate, kwargs=dict(
-            inputs=input_ids,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_p=top_p,
-            max_new_tokens=max_new_tokens,
-            streamer=streamer,
-            stopping_criteria=[stopping_criteria],
-            use_cache=True,
-            **image_args
-        ))
-        thread.start()
-
-        generated_text = ori_prompt
-        for new_text in streamer:
-            generated_text += new_text
-            if generated_text.endswith(stop_str):
-                generated_text = generated_text[:-len(stop_str)]
-            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
-
-    def generate_stream_gate(self, params):
-        try:
-            for x in self.generate_stream(params):
-                yield x
-        except ValueError as e:
-            print("Caught ValueError:", e)
-            ret = {
-                "text": server_error_msg,
-                "error_code": 1,
-            }
-            yield json.dumps(ret).encode() + b"\0"
-        except torch.cuda.CudaError as e:
-            print("Caught torch.cuda.CudaError:", e)
-            ret = {
-                "text": server_error_msg,
-                "error_code": 1,
-            }
-            yield json.dumps(ret).encode() + b"\0"
-        except Exception as e:
-            print("Caught Unknown Error", e)
-            ret = {
-                "text": server_error_msg,
-                "error_code": 1,
-            }
-            yield json.dumps(ret).encode() + b"\0"
-
-
-app = FastAPI()
-
-
-def release_model_semaphore(fn=None):
-    model_semaphore.release()
-    if fn is not None:
-        fn()
-
-
-@app.post("/worker_generate_stream")
-async def generate_stream(request: Request):
-    global model_semaphore, global_counter
-    global_counter += 1
-    params = await request.json()
-
-    if model_semaphore is None:
-        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
-    await model_semaphore.acquire()
-    worker.send_heart_beat()
-    generator = worker.generate_stream_gate(params)
-    background_tasks = BackgroundTasks()
-    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
-    return StreamingResponse(generator, background=background_tasks)
-
-
-@app.post("/worker_get_status")
-async def get_status(request: Request):
-    return worker.get_status()
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=21002)
-    parser.add_argument("--worker-address", type=str,
-        default="http://localhost:21002")
-    parser.add_argument("--controller-address", type=str,
-        default="http://localhost:21001")
-    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
-    parser.add_argument("--model-base", type=str, default=None)
-    parser.add_argument("--model-name", type=str)
-    parser.add_argument("--device", type=str, default="cuda")
-    parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
-    parser.add_argument("--limit-model-concurrency", type=int, default=5)
-    parser.add_argument("--stream-interval", type=int, default=1)
-    parser.add_argument("--no-register", action="store_true")
-    parser.add_argument("--load-8bit", action="store_true")
-    parser.add_argument("--load-4bit", action="store_true")
-    args = parser.parse_args()
-    logger.info(f"args: {args}")
-
-    if args.multi_modal:
-        logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
-
-    worker = ModelWorker(args.controller_address,
-                         args.worker_address,
-                         worker_id,
-                         args.no_register,
-                         args.model_path,
-                         args.model_base,
-                         args.model_name,
-                         args.load_8bit,
-                         args.load_4bit,
-                         args.device)
-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
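
The worker deleted above bounds concurrent generations with a global asyncio.Semaphore and releases it through a background task once the StreamingResponse has been fully sent, emitting results as null-byte-delimited JSON chunks. A stripped-down sketch of just that pattern (illustrative only; the dummy generator stands in for generate_stream_gate):

import asyncio
import json
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse

app = FastAPI()
limit_model_concurrency = 5  # mirrors the --limit-model-concurrency flag
model_semaphore = asyncio.Semaphore(limit_model_concurrency)

async def fake_generate_stream_gate(params):
    # Stand-in for the real generator: emit a few null-delimited JSON chunks.
    for i in range(3):
        yield json.dumps({"text": f"chunk {i}", "error_code": 0}).encode() + b"\0"
        await asyncio.sleep(0.1)

@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    params = await request.json()
    await model_semaphore.acquire()               # block when too many requests are in flight
    background = BackgroundTasks()
    background.add_task(model_semaphore.release)  # runs after the stream finishes
    return StreamingResponse(fake_generate_stream_gate(params), background=background)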
LLaVA-Med/llava/serve/register_worker.py
DELETED
@@ -1,26 +0,0 @@
-"""
-Manually register workers.
-
-Usage:
-python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
-"""
-
-import argparse
-
-import requests
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--controller-address", type=str)
-    parser.add_argument("--worker-name", type=str)
-    parser.add_argument("--check-heart-beat", action="store_true")
-    args = parser.parse_args()
-
-    url = args.controller_address + "/register_worker"
-    data = {
-        "worker_name": args.worker_name,
-        "check_heart_beat": args.check_heart_beat,
-        "worker_status": None,
-    }
-    r = requests.post(url, json=data)
-    assert r.status_code == 200
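
The payload posted by this script ("worker_name", "check_heart_beat", "worker_status") is what the controller's /register_worker endpoint (deleted elsewhere in this commit) consumes. A hypothetical minimal receiving endpoint, assuming an in-memory registry rather than the real controller logic, could look like this:

from fastapi import FastAPI, Request

app = FastAPI()
workers = {}  # worker_name -> registration info (in-memory, illustrative only)

@app.post("/register_worker")
async def register_worker(request: Request):
    data = await request.json()
    workers[data["worker_name"]] = {
        "check_heart_beat": data["check_heart_beat"],
        "worker_status": data["worker_status"],
    }
    return {"registered": True}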
LLaVA-Med/llava/serve/test_message.py
DELETED
@@ -1,62 +0,0 @@
-import argparse
-import json
-
-import requests
-
-from llava.conversation import conv_templates
-
-
-def main():
-    if args.worker_address:
-        worker_addr = args.worker_address
-    else:
-        controller_addr = args.controller_address
-        ret = requests.post(controller_addr + "/refresh_all_workers")
-        ret = requests.post(controller_addr + "/list_models")
-        models = ret.json()["models"]
-        models.sort()
-        print(f"Models: {models}")
-
-        ret = requests.post(controller_addr + "/get_worker_address",
-            json={"model": args.model_name})
-        worker_addr = ret.json()["address"]
-        print(f"worker_addr: {worker_addr}")
-
-    if worker_addr == "":
-        return
-
-    conv = conv_templates["mistral_instruct"].copy()
-    conv.append_message(conv.roles[0], args.message)
-    prompt = conv.get_prompt()
-
-    headers = {"User-Agent": "LLaVA Client"}
-    pload = {
-        "model": args.model_name,
-        "prompt": prompt,
-        "max_new_tokens": args.max_new_tokens,
-        "temperature": 0.7,
-        "stop": conv.sep2,
-    }
-    response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
-        json=pload, stream=True)
-
-    print(prompt, end="")
-    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            data = json.loads(chunk.decode("utf-8"))
-            output = data["text"].split("[/INST]")[-1]
-            print(output, end="\r")
-    print("")
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
-    parser.add_argument("--worker-address", type=str)
-    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
-    parser.add_argument("--max-new-tokens", type=int, default=256)
-    parser.add_argument("--message", type=str, default=
-        "Tell me a story with more than 1000 words.")
-    args = parser.parse_args()
-
-    main()