metadata
library_name: peft
tags:
  - language
  - vision
  - multimodal
  - chess
  - chat
  - adapters
license: mit
base_model:
  - mistral-community/pixtral-12b
base_model_relation: adapter
pipeline_tag: visual-question-answering

Model Card for Chess-pixtral-12B-Lora-v0

This repo contains LoRA adapters for pixtral-12B trained on a chess dataset. The goal of the model is to parse images of chessboards and output the positions of the pieces.

Model Details

The adapters were fine-tuned for 3 epochs on roughly 1,000 images of chessboards, with roughly 250M trainable parameters out of the base model's 12B.
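As a rough check of the adapter size quoted above, the trainable share of the model can be worked out directly. This is only arithmetic on the rounded figures from this card (about 250M trainable parameters, about 12B in total), so the result is approximate.

```python
# Approximate, rounded figures taken from this model card.
trainable_params = 250e6  # ~250M LoRA parameters
total_params = 12e9       # ~12B base-model parameters

# Share of the model's parameters that are updated during fine-tuning.
fraction = trainable_params / total_params
print(f"Trainable fraction: {fraction:.2%}")
```

Under these figures, only about 2% of the parameters are trained.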

Model Usage

Pass the model an image of a chessboard and ask it to describe the position. It won't get every position right, but it improves on the original Pixtral for this task, and fidelity will likely improve in later versions.

For the image

image/png

We get the output:

Sure! Here is the description of the chess position in the image:

1. **White Pieces:**
   - **King:** Located on e2
   - **Rook:** Located on a6
   - **Pawn:** Located on h3

2. **Black Pieces:**
   - **King:** Located on e8
   - **Pawn:** Located on e5
   - **Pawn:** Located on g5

The board is mostly empty, with pieces only on the edges. The white king is on e2, the white rook is on a6, and the white pawn is on h3. For the black pieces, the king is on e8, the pawn is on e5, and another pawn is on g5.

This is almost correct, but there is room for improvement on more complex chess positions.
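To check an answer like this against a known position, it can help to turn the text into structured data. The following is a minimal sketch, assuming the model keeps the exact line format of the sample output above (`**White Pieces:**` headers followed by `**Piece:** Located on <square>` bullets). The function name `parse_position` and the regular expressions are illustrative, not part of this repo, and the model may phrase other answers differently.

```python
import re

# Assumed line formats, based only on the sample output above:
#   "1. **White Pieces:**"  and  "   - **King:** Located on e2"
HEADER_RE = re.compile(r"\*\*(White|Black) Pieces:\*\*")
PIECE_RE = re.compile(
    r"\*\*(King|Queen|Rook|Bishop|Knight|Pawn):\*\*\s+Located on ([a-h][1-8])"
)

def parse_position(text):
    """Return a list of (color, piece, square) tuples found in the model output."""
    pieces = []
    color = None
    for line in text.splitlines():
        header = HEADER_RE.search(line)
        if header:
            # Remember which side the following bullets belong to.
            color = header.group(1).lower()
            continue
        match = PIECE_RE.search(line)
        if match and color is not None:
            pieces.append((color, match.group(1).lower(), match.group(2)))
    return pieces

sample = """\
1. **White Pieces:**
   - **King:** Located on e2
   - **Rook:** Located on a6
   - **Pawn:** Located on h3

2. **Black Pieces:**
   - **King:** Located on e8
   - **Pawn:** Located on e5
   - **Pawn:** Located on g5
"""

for entry in parse_position(sample):
    print(entry)
```

The resulting list can then be compared square by square with the true position to count how many pieces the model placed correctly.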

The inference code I use downloads Pixtral and runs it with the adapters from this repo (replace "cuda" with "mps" if you're on a MacBook):

from peft import PeftModel, PeftConfig
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import torch

# Load the model and processor
model_id = "mistral-community/pixtral-12b"
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda")
processor = AutoProcessor.from_pretrained(model_id)

# Load the LoRA configuration (not used below; handy for inspecting the adapter settings)
peft_config = PeftConfig.from_pretrained("AlexandrosChariton/Chess-pixtral-12B-Lora-v0")

# Load the LoRA adapter weights on top of the base model
lora_model = PeftModel.from_pretrained(model, "AlexandrosChariton/Chess-pixtral-12B-Lora-v0")

# Load single image using PIL
image_path = "example_position.png"
image = Image.open(image_path)

# The chat template is already written into the prompt by hand; [IMG] marks where the image goes
PROMPT = "<s>[INST]Describe the chess position in the image, piece by piece.[IMG][/INST]"

# Pass a single PIL image to the processor
inputs = processor(text=PROMPT, images=image, return_tensors="pt").to("cuda")
generate_ids = lora_model.generate(**inputs, max_new_tokens=650)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output)
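Instead of editing the "cuda" strings by hand, the device could be chosen once and reused in both `device_map=` and `.to(...)`. The helper below is a hypothetical sketch, not part of the original script. It takes plain booleans so it runs without PyTorch installed, and the CPU fallback is an assumption added here; this card only mentions "cuda" and "mps", and a 12B model on CPU would likely be very slow.

```python
def pick_device(cuda_available, mps_available):
    """Return the device string to use, preferring CUDA, then MPS, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    # Assumed fallback; not mentioned in the model card.
    return "cpu"

# With PyTorch installed, the two flags would come from
# torch.cuda.is_available() and torch.backends.mps.is_available().
print(pick_device(cuda_available=False, mps_available=True))
```

The returned string can then replace the hard-coded "cuda" in the two places where the script above uses it.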