|
---
library_name: peft
tags:
- language
- vision
- multimodal
- chess
- chat
- adapters
license: mit
base_model:
- mistral-community/pixtral-12b
base_model_relation: adapter
pipeline_tag: visual-question-answering
---
|
|
|
# Model Card for Chess-pixtral-12B-Lora-v0
|
|
|
This repo contains LoRA adapters for [pixtral-12B](https://huggingface.co/mistral-community/pixtral-12b) trained on a chess dataset. The goal of the model is to parse images of chessboards and output the positions of the pieces. |
|
|
|
|
|
## Model Details |
|
|
|
The adapters were fine-tuned for 3 epochs on roughly 1,000 images of chessboards, with roughly 250M of the 12B parameters trainable.
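
To sanity-check the trainable-parameter figure, PEFT can report the counts after the adapters are attached. A minimal sketch, assuming the adapters are loaded with `is_trainable=True` so they are not frozen for inference:

```python
from peft import PeftModel
from transformers import LlavaForConditionalGeneration
import torch

# Load the base model and attach the adapters from this repo in trainable mode
base = LlavaForConditionalGeneration.from_pretrained(
    "mistral-community/pixtral-12b", torch_dtype=torch.float16
)
model = PeftModel.from_pretrained(
    base, "AlexandrosChariton/Chess-pixtral-12B-Lora-v0", is_trainable=True
)

# Prints "trainable params: ... || all params: ... || trainable%: ..."
model.print_trainable_parameters()
```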
|
|
|
### Model Usage |
|
|
|
Pass the model an image of a chessboard and ask it to describe the position. It won't get it right every time, but it is an improvement over the original Pixtral for this task, and fidelity will likely improve in future versions.
|
|
|
For the image:
|
|
|
![image/png](https://cdn-uploads.huggingface.co/production/uploads/63da89db9f2687298a0acbfe/5Zyvn3XLsQi4ytUESZgpn.png) |
|
|
|
We get the output: |
|
|
|
```
Sure! Here is the description of the chess position in the image:

1. **White Pieces:**
   - **King:** Located on e2
   - **Rook:** Located on a6
   - **Pawn:** Located on h3

2. **Black Pieces:**
   - **King:** Located on e8
   - **Pawn:** Located on e5
   - **Pawn:** Located on g5

The board is mostly empty, with pieces only on the edges. The white king is on e2, the white rook is on a6, and the white pawn is on h3. For the black pieces, the king is on e8, the pawn is on e5, and another pawn is on g5.
```
|
|
|
This is almost correct, but there is room for improvement on more complex chess positions.
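
Since the output is free-form text, one way to evaluate it programmatically is to rebuild the described position with the `python-chess` library and compare FEN strings against the ground truth. A minimal sketch using the pieces from the example above; parsing the model's prose into `(square, piece)` pairs is left out:

```python
import chess

# Pieces as read from the model's description of the example image
described = {
    "e2": "K", "a6": "R", "h3": "P",  # white king, rook, pawn
    "e8": "k", "e5": "p", "g5": "p",  # black king and two pawns
}

board = chess.Board(None)  # start from an empty board
for square, symbol in described.items():
    board.set_piece_at(chess.parse_square(square), chess.Piece.from_symbol(symbol))

# Compare against the ground-truth FEN for the image
print(board.fen())
```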
|
|
|
The code I use for inference downloads Pixtral and runs the model with the adapters in this repo. Here is the code (replace "cuda" with "mps" if you're on a Mac with Apple silicon):
|
```python
from peft import PeftModel
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import torch

# Load the base model and processor
model_id = "mistral-community/pixtral-12b"
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda")
processor = AutoProcessor.from_pretrained(model_id)

# Attach the LoRA adapters from this repo to the base model
lora_model = PeftModel.from_pretrained(model, "AlexandrosChariton/Chess-pixtral-12B-Lora-v0")

# Load a single image with PIL
image_path = "example_position.png"
image = Image.open(image_path)

# Prompt with the chat template already applied
PROMPT = "<s>[INST]Describe the chess position in the image, piece by piece.[IMG][/INST]"

# Pass the single image instead of a list of URLs
inputs = processor(text=PROMPT, images=image, return_tensors="pt").to("cuda")
generate_ids = lora_model.generate(**inputs, max_new_tokens=650)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output)
```
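
If you plan to run many generations, you can optionally fold the adapters into the base weights so inference runs without the PEFT wrappers. A small sketch using PEFT's `merge_and_unload`, continuing from the variables above:

```python
# Merge the LoRA weights into the base model in place and drop the PEFT
# wrappers; the result is a plain LlavaForConditionalGeneration.
merged_model = lora_model.merge_and_unload()

generate_ids = merged_model.generate(**inputs, max_new_tokens=650)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])
```

Note that merging modifies the base model's weights, so reload the base model if you want the original Pixtral behavior back.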