anas-awadalla committed • 0a405ca
Parent(s): b630945

3b again
app.py CHANGED

@@ -54,13 +54,14 @@ with open("bad_words.txt", "r") as f:
 model, image_processor, tokenizer = create_model_and_transforms(
     clip_vision_encoder_pretrained="openai",
     clip_vision_encoder_path="ViT-L-14",
-    lang_encoder_path="anas-awadalla/mpt-
-    tokenizer_path="anas-awadalla/mpt-
-    cross_attn_every_n_layers=
+    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b-dolly",
+    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b-dolly",
+    cross_attn_every_n_layers=1,
 )
 
-checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-
+checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b-langinstruct", "checkpoint.pt")
 model.load_state_dict(torch.load(checkpoint_path), strict=False)
+
 model.eval()
 
 def generate(

@@ -153,13 +154,13 @@ def generate(
     # with torch.cuda.amp.autocast(dtype=torch.bfloat16):
     output = model.generate(
         vision_x=vision_x,
-        lang_x=input_ids
-        attention_mask=attention_mask
+        lang_x=input_ids,
+        attention_mask=attention_mask,
         max_new_tokens=30,
         num_beams=3,
-        do_sample=True,
-        temperature=0.3,
-        top_k=0,
+        # do_sample=True,
+        # temperature=0.3,
+        # top_k=0,
     )
 
     gen_text = tokenizer.decode(
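Taken together, the app.py changes point the Space at the 3B checkpoint (MPT-1B language backbone, cross-attention in every layer) and switch generation to plain beam search with sampling commented out. For reference, a minimal end-to-end sketch of how this configuration is typically driven is below; it assumes the open_flamingo package is importable as in its README, and the image URL, prompt, and preprocessing are illustrative placeholders rather than code from the Space.

# Minimal inference sketch for the updated configuration (placeholders marked).
import torch
import requests
from PIL import Image
from huggingface_hub import hf_hub_download
from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b-dolly",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b-dolly",
    cross_attn_every_n_layers=1,
)
checkpoint_path = hf_hub_download(
    "openflamingo/OpenFlamingo-3B-vitl-mpt1b-langinstruct", "checkpoint.pt"
)
model.load_state_dict(torch.load(checkpoint_path), strict=False)
model.eval()

# Placeholder image and prompt; the Space's generate() builds these differently.
image = Image.open(requests.get("https://example.com/cat.jpg", stream=True).raw)
# vision_x shape: (batch, num_images, num_frames, channels, height, width)
vision_x = image_processor(image).unsqueeze(0).unsqueeze(0).unsqueeze(0)
lang_x = tokenizer(["<image>An image of"], return_tensors="pt")

output = model.generate(
    vision_x=vision_x,
    lang_x=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    max_new_tokens=30,
    num_beams=3,
)
print(tokenizer.decode(output[0]))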
open_flamingo/open_flamingo/src/factory.py CHANGED

@@ -1,6 +1,5 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import open_clip
-import torch
 
 from .flamingo import Flamingo
 from .flamingo_lm import FlamingoLMMixin

@@ -58,7 +57,8 @@ def create_model_and_transforms(
     lang_encoder = AutoModelForCausalLM.from_pretrained(
         lang_encoder_path,
         local_files_only=use_local_files,
-        trust_remote_code=True
+        trust_remote_code=True,
+    )
 
     # hacks for MPT-1B, which doesn't have a get_input_embeddings method
     if "mpt-1b-redpajama-200b" in lang_encoder_path:

@@ -79,7 +79,6 @@ def create_model_and_transforms(
     decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
     lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
     lang_encoder.resize_token_embeddings(len(text_tokenizer))
-    lang_encoder.to(0)
 
     model = Flamingo(
         vision_encoder,

@@ -90,7 +89,8 @@ def create_model_and_transforms(
             "width"
         ],
         cross_attn_every_n_layers=cross_attn_every_n_layers,
-        **flamingo_kwargs
+        **flamingo_kwargs,
+    )
 
     # Freeze all parameters
     model.requires_grad_(False)
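The factory.py change removes the hard-coded lang_encoder.to(0) (and the now-unused torch import), so device placement becomes the caller's responsibility instead of being forced onto GPU 0 at construction time. A small sketch of caller-side placement using the usual torch idiom is below; this helper is hypothetical and not part of the commit.

import torch

def place_model(model, prefer_gpu=True):
    # Hypothetical helper: move the model to GPU only when one is available,
    # mirroring what callers must now do since the factory no longer calls .to(0).
    device = torch.device("cuda" if prefer_gpu and torch.cuda.is_available() else "cpu")
    return model.to(device), device

Any inputs passed to generate() (vision_x, lang_x, attention_mask) then need to be moved to the same device as the model.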
open_flamingo/open_flamingo/src/flamingo.py CHANGED

@@ -212,7 +212,7 @@ class Flamingo(nn.Module):
         with torch.no_grad():
             vision_x = self.vision_encoder(vision_x)[1]
         vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
-        vision_x = self.perceiver(vision_x)
+        vision_x = self.perceiver(vision_x)
 
         for layer in self.lang_encoder._get_decoder_layers():
             layer.condition_vis_x(vision_x)
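This hunk sits in the vision-conditioning path: CLIP features for all frames are computed in one flat batch, reshaped back to (batch, images, frames, tokens, dim), passed through the perceiver resampler, and then attached to every cross-attention layer via condition_vis_x. Below is a shape walkthrough with dummy tensors; the sizes are illustrative assumptions (for example, 257 tokens at 1024 dims for ViT-L/14), not values read from the model.

# Shape walkthrough for the rearrange step shown above (dummy tensors).
import torch
from einops import rearrange

b, T, F = 2, 1, 1          # batch size, images per sample, frames per image
v, d = 257, 1024           # vision tokens per frame, feature dim (ViT-L/14 assumption)

# The vision encoder sees all frames as one flat batch of size b*T*F ...
vision_feats = torch.randn(b * T * F, v, d)

# ... and the forward pass restores the (batch, image, frame) structure
# before the perceiver resampler compresses the tokens per image.
vision_x = rearrange(vision_feats, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
print(vision_x.shape)  # torch.Size([2, 1, 1, 257, 1024])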