import base64
from pathlib import Path

import gradio as gr
import spaces
from huggingface_hub import snapshot_download

from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import (
    UserMessage,
    AssistantMessage,
    TextChunk,
    ImageURLChunk,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest

# Download only the files mistral_inference needs: params.json (model config),
# consolidated.safetensors (weights), and tekken.json (tokenizer).
models_path = Path.home().joinpath('pixtral', 'Pixtral')
models_path.mkdir(parents=True, exist_ok=True)

snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=models_path,
)

tokenizer = MistralTokenizer.from_file(f"{models_path}/tekken.json")
model = Transformer.from_folder(models_path)
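# Loading the checkpoint onto the GPU needs on the order of 25 GB of memory
# (12B parameters in bf16 at 2 bytes each).

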
def image_to_base64(image_path):
    # Read the image from disk and wrap it in a base64 data URL, the form
    # ImageURLChunk accepts for local files.
    with open(image_path, 'rb') as img:
        encoded_string = base64.b64encode(img.read()).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded_string}"

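# Note: the data URL above hard-codes image/jpeg even though Gradio uploads
# may just as well be PNGs; the mismatch usually goes unnoticed because image
# decoders sniff the actual bytes, but a safer variant (a sketch using the
# standard library's mimetypes module) would guess the type from the
# extension:
#
#   mime, _ = mimetypes.guess_type(image_path)
#   return f"data:{mime or 'image/jpeg'};base64,{encoded_string}"

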
@spaces.GPU(duration=60)
def run_inference(message, history):
    # Rebuild the conversation in mistral_common's message types. In a
    # multimodal ChatInterface, history stores image uploads as
    # ((filepath, ...), None) tuples and text turns as (text, response) pairs,
    # so uploaded images are buffered until the next text turn.
    messages = []
    images = []
    for user_turn, assistant_turn in history:
        if isinstance(user_turn, tuple):
            images += user_turn
        elif user_turn:
            messages.append(UserMessage(content=[ImageURLChunk(image_url=image_to_base64(path)) for path in images] + [TextChunk(text=user_turn)]))
            messages.append(AssistantMessage(content=assistant_turn))
            images = []

    # Append the current turn: newly uploaded files plus the typed text.
    messages.append(UserMessage(content=[ImageURLChunk(image_url=image_to_base64(file["path"])) for file in message["files"]] + [TextChunk(text=message["text"])]))

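    # For reference, the incoming `message` is a multimodal payload along the
    # lines of {"text": "What's in this image?", "files": [{"path": "..."}]}
    # (exact keys vary across Gradio versions). `messages` now alternates
    # user/assistant turns and ends with the new user turn, which is the
    # shape ChatCompletionRequest expects.
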
    completion_request = ChatCompletionRequest(messages=messages)

    # Tokenize the chat; the tokenizer also returns the processed images that
    # the vision encoder consumes alongside the tokens.
    encoded = tokenizer.encode_chat_completion(completion_request)
    images = encoded.images
    tokens = encoded.tokens

    out_tokens, _ = generate(
        [tokens],
        model,
        images=[images],
        max_tokens=512,
        temperature=0.45,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.decode(out_tokens[0])


demo = gr.ChatInterface(
    fn=run_inference,
    title="Pixtral 12B",
    multimodal=True,
    description="A demo chat interface with Pixtral 12B, deployed using Mistral Inference.",
)
demo.queue().launch()
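
# Note: @spaces.GPU requests a ZeroGPU worker per call when this runs on
# Hugging Face Spaces (duration=60 asks for up to a 60-second allocation);
# outside a Space the decorator is designed to be a no-op, in which case a
# local GPU with room for the 12B bf16 weights is assumed.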