Commit a3e666b (verified) · hiyouga committed · 1 Parent(s): d1e95b9

Update README.md

Files changed (1): README.md (+72 -3)
---
license: gemma
datasets:
- BUAADreamer/llava-en-zh-2k
language:
- en
- zh
library_name: transformers
pipeline_tag: image-text-to-text
base_model: google/paligemma-3b-mix-448
inference: false
tags:
- paligemma
- llama-factory
- mllm
- vlm
---

# PaliGemma-3B-Chat-v0.1

This model is fine-tuned from [google/paligemma-3b-mix-448](https://huggingface.co/google/paligemma-3b-mix-448) on the [BUAADreamer/llava-en-zh-2k](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-2k) dataset using [LLaMA Factory](https://github.com/hiyouga/LLaMA-Factory).

![examples](examples_en.png)

## Usage

```python
import requests
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextStreamer

# Load the tokenizer, processor, and model from the Hub
model_id = "hiyouga/PaliGemma-3B-Chat-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Fetch an example image and preprocess it into pixel values
url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # any accessible image URL works here
image = Image.open(requests.get(url, stream=True).raw)
pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]

# Build the text prompt with the chat template
messages = [
    {"role": "user", "content": "What is in this image?"}
]
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")

# Prepend one <image> token per image patch position expected by the model
image_token_id = tokenizer.convert_tokens_to_ids("<image>")
image_prefix = torch.empty((1, processor.image_seq_length), dtype=input_ids.dtype).fill_(image_token_id)
input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)

# Generate a response, streamed to stdout as tokens are produced
generate_ids = model.generate(input_ids, pixel_values=pixel_values, streamer=streamer, max_new_tokens=50)
```
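
The streamer above already prints the reply token by token; to capture the reply as a string instead, you can decode the generated ids. A minimal sketch, relying on the fact that `generate_ids` contains the prompt followed by the new tokens:

```python
# Strip the prompt tokens and decode only the newly generated part
response = tokenizer.batch_decode(
    generate_ids[:, input_ids.shape[-1]:], skip_special_tokens=True
)[0]
print(response)
```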

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 0.00001
- num_train_epochs: 3.0
- train_batch_size: 1
- gradient_accumulation_steps: 8
- total_train_batch_size: 16
- seed: 42
- lr_scheduler_type: cosine
- mixed_precision_training: bf16
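
The total train batch size is the per-device batch size × gradient accumulation steps × device count: 1 × 8 × 2 = 16, which suggests training ran on 2 GPUs (the device count is inferred from these numbers, not stated in the card).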

### Framework versions

- PyTorch 2.3.0
- Transformers 4.41.0