Issue running inference after fine-tuning Idefics2 with LoRA
#70, opened by jxue005
Hi,
I am trying to run inference with the following code:
```python
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics2ForConditionalGeneration

peft_model_id = "idefics2-finetuned"
processor = AutoProcessor.from_pretrained(peft_model_id)

# Define quantization config
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)

# Load the base model with adapters on top
model = Idefics2ForConditionalGeneration.from_pretrained(
    peft_model_id,
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)

def predict(test_example):
    test_image = test_example["image"]
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Extract JSON."},
                {"type": "image"},
            ],
        },
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[test_image], return_tensors="pt").to(model.device)
    # Generate token IDs
    generated_ids = model.generate(**inputs, max_new_tokens=2048, temperature=0.1, top_p=0.95, do_sample=True)
    # Keep only the newly generated tokens (drop the echoed prompt)
    generated_ids = generated_ids[:, inputs["input_ids"].shape[1]:]
    # Decode back into text
    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # token2json: helper defined elsewhere in the script (not shown)
    generated_json = token2json(generated_texts[0])
    return generated_json
```
But inference is really slow. To reduce the latency I tried calling `model.merge_and_unload()`, and to do so I need to run the following first:
```python
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# USE_QLORA / USE_LORA: flags from the training setup (not shown)
if USE_QLORA or USE_LORA:
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules=".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$",
        use_dora=not USE_QLORA,  # DoRA for plain LoRA, disabled for QLoRA
        init_lora_weights="gaussian",
    )
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
```
But I got this error:
```
ValueError: Target module ModuleDict(
  (default): Dropout(p=0.1, inplace=False)
) is not supported. Currently, only the following modules are supported: `torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`.
```
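The error is consistent with how PEFT handles a string `target_modules`: it is treated as a regular expression and checked with `re.fullmatch` against every module name. The model loaded from `idefics2-finetuned` already contains LoRA layers, and the trailing `.*$` in the pattern also lets it match the submodules PEFT injected, such as `...q_proj.lora_dropout`, which is a `ModuleDict` holding a `Dropout`, exactly what the message shows. The module names below are an illustrative assumption of the layout, not dumped from the real model; this is a sketch to show the matching, not a fix.

```python
import re

# The target_modules pattern from the LoraConfig in the post.
TARGET_MODULES = (
    r".*(text_model|modality_projection|perceiver_resampler)"
    r".*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"
)

# Illustrative module names (assumed layout) for a model that already carries
# a LoRA adapter: q_proj is wrapped, and PEFT adds base_layer / lora_* children.
module_names = [
    "model.vision_model.encoder.layers.0.self_attn.q_proj",
    "model.text_model.layers.0.self_attn.q_proj",
    "model.text_model.layers.0.self_attn.q_proj.base_layer",
    "model.text_model.layers.0.self_attn.q_proj.lora_dropout",
    "model.text_model.layers.0.self_attn.q_proj.lora_dropout.default",
    "model.text_model.layers.0.self_attn.q_proj.lora_A",
]

# A string target_modules is compared against each full module name with re.fullmatch.
matched = [name for name in module_names if re.fullmatch(TARGET_MODULES, name)]

for name in matched:
    print(name)
```

Every `text_model` name matches, including the `lora_dropout` container, while the vision-tower name does not. A pattern ending in the projection name instead of `.*$` would skip the `lora_*` children, but re-wrapping an adapter-loaded model with `get_peft_model` is not required for merging in the first place.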
Have you run into this issue before?
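For reference, a sketch of the usual PEFT route for merging, assuming `idefics2-finetuned` contains an `adapter_config.json` whose `base_model_name_or_path` points at the base Idefics2 checkpoint. As far as I know, loading the adapter directory directly through `Idefics2ForConditionalGeneration.from_pretrained` goes through transformers' built-in PEFT integration, which injects the LoRA layers but does not return a `PeftModel`, so there is no `merge_and_unload()` on it. The sketch instead loads the base model in fp16 without 4-bit quantization, attaches the adapter with `PeftModel.from_pretrained`, merges, and saves. The function name `merge_lora_adapter`, its parameters, and the single-device placement are assumptions for illustration.

```python
def merge_lora_adapter(adapter_dir: str, output_dir: str, device: str = "cuda") -> None:
    """Merge a LoRA/DoRA adapter into its base Idefics2 model and save the result.

    Hypothetical helper; assumes the fp16 base model fits on `device`.
    """
    # Imported inside the function so defining it does not need the GPU stack.
    import torch
    from peft import PeftConfig, PeftModel
    from transformers import AutoProcessor, Idefics2ForConditionalGeneration

    # adapter_config.json records which base checkpoint the adapter was trained on.
    peft_config = PeftConfig.from_pretrained(adapter_dir)

    # Load the base model in fp16 *without* 4-bit quantization, so the adapter
    # is merged into full-precision weights rather than quantized ones.
    base_model = Idefics2ForConditionalGeneration.from_pretrained(
        peft_config.base_model_name_or_path,
        torch_dtype=torch.float16,
        device_map={"": device},
    )

    # Attach the trained adapter, then fold its weights into the base layers.
    peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
    merged_model = peft_model.merge_and_unload()

    # Save a plain transformers checkpoint that no longer needs the adapter.
    merged_model.save_pretrained(output_dir)
    AutoProcessor.from_pretrained(adapter_dir).save_pretrained(output_dir)
```

Usage would be something like `merge_lora_adapter("idefics2-finetuned", "idefics2-finetuned-merged")` (the output name is hypothetical). The merged directory can then be loaded with the same `from_pretrained(..., quantization_config=quantization_config)` call as in the inference code, pointing at the merged directory instead of the adapter directory.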
Any updates on this issue? I'm facing the same problem.