Asking for guidance on batch inference

#15
by nguyen-brat

I read the README and tried something similar to the Qwen2-VL method to run the model with batched inputs, but it did not work. Can you please show me how to do it?

Anyone looking for batched inference can have a glance at my solution:

from PIL import Image
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "mistral-community/pixtral-12b"
model = LlavaForConditionalGeneration.from_pretrained(
            model_id,
            device_map="auto",
            # attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True
        ).eval()
processor = AutoProcessor.from_pretrained(model_id)
# Register a pad token so prompts of different lengths can be padded in a batch.
processor.tokenizer.add_special_tokens({'pad_token': '<pad>'})

img_path = "../data/warn_up/warmup-images/bc24654fb4fba69b41b6b4dce15295fc4acc8ebce9b9bff452ef6a8890e04e72.jpg"
img = Image.open(img_path)

chat = [
    {
      "role": "user", "content": [
        {"type": "image"},
        {"type": "text", "content": "Based on the text in this image, explain why this image contains sarcasm."},
      ]
    }
]

# A second, shorter prompt; the length difference is what makes padding necessary.
chat_2 = [
    {
      "role": "user", "content": [
        {"type": "image"},
        {"type": "text", "content": "based "},
      ]
    }
]

prompt = processor.apply_chat_template(chat)
prompt_2 = processor.apply_chat_template(chat_2)
# Left-pad so the real tokens sit next to the newly generated ones (decoder-only generation).
inputs = processor(text=[prompt, prompt_2], images=[[img], [img]], return_tensors="pt",
                   padding=True, padding_side="left").to(model.device)

with torch.no_grad():
    generate_ids = model.generate(**inputs, max_new_tokens=500)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(output)

I have tested it and it works fine.
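
Note that batch_decode above returns the prompt together with the answer. If you only want the newly generated text, you can slice off the input tokens first (a small sketch reusing inputs and generate_ids from above; left padding makes the slice line up for every row):

new_tokens = generate_ids[:, inputs["input_ids"].shape[1]:]
answers = processor.batch_decode(new_tokens, skip_special_tokens=True)
print(answers)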

@nguyen-brat Hey, it doesn't work. I tried the same code, and this is the error I get every single time:
ValueError: Image features and image tokens do not match: tokens: 645, features 1290

or just variations of it. What's happening here?

@manan05 You can try my Anaconda environment to see if that error still pops up, because it runs fine on my computer:

name: pixtral
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - asttokens=2.4.1=pyhd8ed1ab_0
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.9.24=h06a4308_0
  - comm=0.2.2=pyhd8ed1ab_0
  - debugpy=1.8.7=py310hf71b8c6_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - exceptiongroup=1.2.2=pyhd8ed1ab_0
  - executing=2.1.0=pyhd8ed1ab_0
  - importlib-metadata=8.5.0=pyha770c72_0
  - ipykernel=6.29.5=pyh3099207_0
  - ipython=8.28.0=pyh707e725_0
  - jedi=0.19.1=pyhd8ed1ab_0
  - jupyter_client=8.6.3=pyhd8ed1ab_0
  - jupyter_core=5.7.2=pyh31011fe_1
  - krb5=1.21.3=h143b758_0
  - ld_impl_linux-64=2.40=h12ee557_0
  - libedit=3.1.20230828=h5eee18b_0
  - libffi=3.4.2=h7f98852_5
  - libgcc=14.2.0=h77fa898_1
  - libgcc-ng=14.2.0=h69a702a_1
  - libgomp=14.2.0=h77fa898_1
  - libnsl=2.0.1=hd590300_0
  - libsodium=1.0.20=h4ab18f5_0
  - libsqlite=3.46.1=hadc24fc_0
  - libstdcxx=14.2.0=hc0a3c3a_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=2.38.1=h0b41bf4_0
  - libxcrypt=4.4.36=hd590300_1
  - libzlib=1.3.1=hb9d3cd8_2
  - matplotlib-inline=0.1.7=pyhd8ed1ab_0
  - ncurses=6.5=he02047a_1
  - nest-asyncio=1.6.0=pyhd8ed1ab_0
  - openssl=3.3.2=hb9d3cd8_0
  - packaging=24.1=pyhd8ed1ab_0
  - parso=0.8.4=pyhd8ed1ab_0
  - pexpect=4.9.0=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - pip=24.2=py310h06a4308_0
  - platformdirs=4.3.6=pyhd8ed1ab_0
  - prompt-toolkit=3.0.48=pyha770c72_0
  - psutil=6.0.0=py310ha75aee5_2
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.3=pyhd8ed1ab_0
  - pygments=2.18.0=pyhd8ed1ab_0
  - python=3.10.15=h4a871b0_2_cpython
  - python_abi=3.10=5_cp310
  - pyzmq=26.2.0=py310h71f11fc_3
  - readline=8.2=h5eee18b_0
  - setuptools=75.1.0=py310h06a4308_0
  - six=1.16.0=pyh6c4a22f_0
  - sqlite=3.46.1=h9eae976_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tk=8.6.13=noxft_h4845f30_101
  - tornado=6.4.1=py310ha75aee5_1
  - traitlets=5.14.3=pyhd8ed1ab_0
  - typing_extensions=4.12.2=pyha770c72_0
  - wcwidth=0.2.13=pyhd8ed1ab_0
  - wheel=0.44.0=py310h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - zeromq=4.3.5=h3b0a872_6
  - zipp=3.20.2=pyhd8ed1ab_0
  - zlib=1.3.1=hb9d3cd8_2
  - pip:
      - accelerate==1.0.1
      - aiohappyeyeballs==2.4.3
      - aiohttp==3.10.10
      - aiosignal==1.3.1
      - annotated-types==0.7.0
      - anyio==4.6.2.post1
      - async-timeout==4.0.3
      - attrs==24.2.0
      - av==13.1.0
      - certifi==2024.8.30
      - charset-normalizer==3.4.0
      - click==8.1.7
      - cloudpickle==3.1.0
      - contourpy==1.3.0
      - cycler==0.12.1
      - datasets==3.0.1
      - dill==0.3.8
      - diskcache==5.6.3
      - distro==1.9.0
      - docstring-parser==0.16
      - einops==0.8.0
      - fastapi==0.115.2
      - filelock==3.16.1
      - fire==0.7.0
      - flash-attn==2.6.3
      - fonttools==4.54.1
      - frozenlist==1.4.1
      - fsspec==2024.6.1
      - gguf==0.10.0
      - h11==0.14.0
      - httpcore==1.0.6
      - httptools==0.6.4
      - httpx==0.27.2
      - huggingface-hub==0.25.2
      - idna==3.10
      - interegular==0.3.3
      - jinja2==3.1.4
      - jiter==0.6.1
      - jsonschema==4.23.0
      - jsonschema-specifications==2024.10.1
      - kiwisolver==1.4.7
      - lark==1.2.2
      - llvmlite==0.43.0
      - lm-format-enforcer==0.10.6
      - markupsafe==3.0.1
      - matplotlib==3.9.2
      - mistral-common==1.4.4
      - mistral-inference==1.5.0
      - mpmath==1.3.0
      - msgpack==1.1.0
      - msgspec==0.18.6
      - multidict==6.1.0
      - multiprocess==0.70.16
      - networkx==3.4.1
      - numba==0.60.0
      - numpy==1.26.4
      - nvidia-cublas-cu12==12.1.3.1
      - nvidia-cuda-cupti-cu12==12.1.105
      - nvidia-cuda-nvrtc-cu12==12.1.105
      - nvidia-cuda-runtime-cu12==12.1.105
      - nvidia-cudnn-cu12==9.1.0.70
      - nvidia-cufft-cu12==11.0.2.54
      - nvidia-curand-cu12==10.3.2.106
      - nvidia-cusolver-cu12==11.4.5.107
      - nvidia-cusparse-cu12==12.1.0.106
      - nvidia-ml-py==12.560.30
      - nvidia-nccl-cu12==2.20.5
      - nvidia-nvjitlink-cu12==12.6.77
      - nvidia-nvtx-cu12==12.1.105
      - openai==1.51.2
      - opencv-python-headless==4.10.0.84
      - outlines==0.0.46
      - pandas==2.2.3
      - partial-json-parser==0.2.1.1.post4
      - pillow==10.4.0
      - prometheus-client==0.21.0
      - prometheus-fastapi-instrumentator==7.0.0
      - propcache==0.2.0
      - protobuf==5.28.2
      - py-cpuinfo==9.0.0
      - pyairports==2.1.1
      - pyarrow==17.0.0
      - pycountry==24.6.1
      - pydantic==2.9.2
      - pydantic-core==2.23.4
      - pyparsing==3.2.0
      - python-dateutil==2.9.0.post0
      - python-dotenv==1.0.1
      - pytz==2024.2
      - pyyaml==6.0.2
      - qwen-vl-utils==0.0.8
      - ray==2.37.0
      - referencing==0.35.1
      - regex==2024.9.11
      - requests==2.32.3
      - rpds-py==0.20.0
      - safetensors==0.4.5
      - sentencepiece==0.2.0
      - simple-parsing==0.1.6
      - sniffio==1.3.1
      - starlette==0.40.0
      - sympy==1.13.3
      - termcolor==2.5.0
      - tiktoken==0.7.0
      - tokenizers==0.20.1
      - torch==2.4.0
      - torchvision==0.19.0
      - tqdm==4.66.5
      - transformers==4.45.2
      - triton==3.0.0
      - tzdata==2024.2
      - urllib3==2.2.3
      - uvicorn==0.32.0
      - uvloop==0.21.0
      - vllm==0.6.3
      - watchfiles==0.24.0
      - websockets==13.1
      - xformers==0.0.27.post2
      - xxhash==3.5.0
      - yarl==1.15.4
prefix: /home/nguyen/miniconda3/envs/pixtral
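
To try it, save the listing above as environment.yml (the file name is just a convention) and recreate the environment with:

conda env create -f environment.yml
conda activate pixtral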

@nguyen-brat Hey, can you send the image you are using so we can pinpoint the problem exactly?

The error appears only when the images in the batch have different sizes, since each image then yields a different number of image tokens, and the model is not designed to process such mismatched prompts at the same time.
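
If that is the cause, a simple workaround is to resize every image in the batch to one resolution before calling the processor, so each example yields the same number of image tokens. A minimal sketch — the 512x512 target and the second image img_2 are arbitrary stand-ins, not model requirements:

def to_common_size(images, size=(512, 512)):
    # Same resolution -> same patch grid -> same number of image tokens per example.
    return [im.convert("RGB").resize(size) for im in images]

imgs = to_common_size([img, img_2])
inputs = processor(text=[prompt, prompt_2], images=[[imgs[0]], [imgs[1]]],
                   return_tensors="pt", padding=True, padding_side="left").to(model.device)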
