mychen76 committed
Commit 55f20f4
1 Parent(s): fe72180

Update README.md

Files changed (1)
  1. README.md +43 -4
README.md CHANGED
@@ -49,12 +49,16 @@ from transformers import AutoProcessor
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.bfloat16
+```
 
-## input
+***input***
+```
 url = "https://huggingface.co/datasets/mychen76/medtrinity_brain_30k_hf/viewer/default/train?row=4&image-viewer=image-62-2B87111BBD996B48DB4C86B0244653FF84B3B8A9"
 image = Image.open(requests.get(url, stream=True).raw)
+```
 
-## load model
+***load model***
+```
 FINETUNED_MODEL_ID="mychen76/paligemma-3b-mix-448-med_30k-ct-brain"
 
 processor = AutoProcessor.from_pretrained(FINETUNED_MODEL_ID)
@@ -64,7 +68,7 @@ model = PaliGemmaForConditionalGeneration.from_pretrained(
     device_map=device
 ).eval()
 ```
-run inference
+***run inference***
 ```
 # Instruct the model to create a caption in Spanish
 def run_inference(input_text,input_image, model, processor,max_tokens=1024):
@@ -84,11 +88,46 @@ input_text="caption"
 pred_text = run_inference(input_text,input_image,model, processor)
 print(pred_text)
 ```
-result
+***result***
 ```
 The image is a CT scan of the brain, showing various brain structures without the presence of medical devices. The region of interest, located centrally and in the middle of the image, occupies approximately 3.0% of the area and appears to have an abnormal texture or density compared to the surrounding brain tissue, which may indicate a pathological condition. This abnormal area could be related to the surrounding brain structures, potentially affecting them or being affected by a shared pathological process, such as a hemorrhage or a mass effect.
 ```
 
+***Running on CUDA***
+
+```
+from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+from PIL import Image
+import requests
+import torch
+
+FINETUNED_MODEL_ID="mychen76/paligemma-3b-mix-448-med_30k-ct-brain"
+device = "cuda:0"
+dtype = torch.bfloat16
+
+url = "https://huggingface.co/datasets/mychen76/medtrinity_brain_30k_hf/viewer/default/train?row=4&image-viewer=image-62-2B87111BBD996B48DB4C86B0244653FF84B3B8A9"
+image = Image.open(requests.get(url, stream=True).raw)
+
+model = PaliGemmaForConditionalGeneration.from_pretrained(
+    FINETUNED_MODEL_ID,
+    torch_dtype=dtype,
+    device_map=device,
+    revision="bfloat16",
+).eval()
+processor = AutoProcessor.from_pretrained(FINETUNED_MODEL_ID)
+
+# Instruct the model to create a caption in Spanish
+prompt = "caption es"
+model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
+input_len = model_inputs["input_ids"].shape[-1]
+
+with torch.inference_mode():
+    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
+    generation = generation[0][input_len:]
+    decoded = processor.decode(generation, skip_special_tokens=True)
+    print(decoded)
+```
+
 
 ### Direct Use
 
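Note: the diff hunks above cut off the body of `run_inference` at the hunk boundary; only its signature is visible. A minimal sketch of what such a helper could look like, assuming the same processor/`generate` usage as the "Running on CUDA" block (the actual body in the README may differ):

```python
import torch

# Hypothetical reconstruction -- the real body lies outside the diff hunks.
def run_inference(input_text, input_image, model, processor, max_tokens=1024):
    # Build multimodal inputs and move them to the model's device
    model_inputs = processor(text=input_text, images=input_image, return_tensors="pt").to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=max_tokens, do_sample=False)
        # Drop the prompt tokens and decode only the newly generated text
        generation = generation[0][input_len:]
        return processor.decode(generation, skip_special_tokens=True)
```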
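Also note that the `url` in the snippets points at the dataset viewer page rather than a raw image file, so `Image.open(requests.get(url, stream=True).raw)` may receive HTML instead of image bytes. A hedged alternative is to pull the sample through the `datasets` library; the split and the `"image"` column name below are assumptions, not taken from the README:

```python
from datasets import load_dataset

# Assumption: the dataset exposes a decoded "image" column in its "train" split.
ds = load_dataset("mychen76/medtrinity_brain_30k_hf", split="train")
image = ds[4]["image"]  # row=4, matching the row referenced in the viewer URL
```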