shivam1707 committed on
Commit
d715433
·
verified ·
1 Parent(s): 5fdc4db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -25
app.py CHANGED
@@ -1,27 +1,100 @@
1
  import gradio as gr
2
- from PIL import Image
3
- from transformers import AutoModel, AutoTokenizer
4
-
5
- # Load the VLLM model and tokenizer
6
- model_name = "mistralai/Pixtral-12B-2409"
7
- model = AutoModel.from_pretrained(model_name)
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
-
10
- # Define a function to generate text from an image
11
- def generate_text(image):
12
- inputs = tokenizer(image, return_tensors="pt")
13
- outputs = model.generate(inputs["input_ids"], num_beams=4, no_repeat_ngram_size=2)
14
- text = outputs[0].cpu().numpy().decode("utf-8")
15
- return text
16
-
17
- # Create a Gradio interface
18
- demo = gr.Interface(
19
- fn=generate_text,
20
- inputs=["image"],
21
- outputs=["text"],
22
- title="Image-to-Text Demo",
23
- description="Generate text from an image using a VLLM model",
24
- )
25
-
26
- # Launch the Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  demo.launch()
 
1
  import gradio as gr
2
+ import spaces
3
+ from mistral_inference.transformer import Transformer
4
+ from mistral_inference.generate import generate
5
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
6
+ from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
7
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest
8
+ from huggingface_hub import snapshot_download
9
+ from pathlib import Path
10
+
11
# Download the model artifacts and prepare the local cache directory
mistral_models_path = Path.home().joinpath('mistral_models', 'Pixtral')
mistral_models_path.mkdir(parents=True, exist_ok=True)

snapshot_download(repo_id="mistral-community/pixtral-12b-240910",
                  allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
                  local_dir=mistral_models_path)

# Load the tokenizer and the model from the downloaded snapshot
tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
model = Transformer.from_folder(mistral_models_path)
22
+
23
# Inference entry point (runs on GPU via the Spaces decorator)
@spaces.GPU
def mistral_inference(prompt, image_url):
    """Run Pixtral on *prompt* grounded in the image at *image_url*.

    Builds a single-turn chat request containing the image followed by the
    text prompt, generates up to 1024 tokens, and returns the decoded reply.
    """
    user_message = UserMessage(
        content=[ImageURLChunk(image_url=image_url), TextChunk(text=prompt)]
    )
    request = ChatCompletionRequest(messages=[user_message])

    encoded = tokenizer.encode_chat_completion(request)

    # generate() takes a batch: one token sequence and one image list per item.
    out_tokens, _ = generate(
        [encoded.tokens],
        model,
        images=[encoded.images],
        max_tokens=1024,
        temperature=0.35,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.decode(out_tokens[0])
38
+
39
# UI label sets, keyed by language code
def get_labels(language):
    """Return the dict of UI labels for *language* ('en', 'zh' or 'jp').

    Unknown or malformed codes fall back to English instead of raising
    KeyError, so a bad dropdown value can never crash a UI callback.
    """
    labels = {
        'en': {
            'title': "Pixtral Model Image Description",
            'text_prompt': "Text Prompt",
            'image_url': "Image URL",
            'output': "Model Output",
            'image_display': "Input Image",
            'submit': "Run Inference"
        },
        'zh': {
            'title': "Pixtral模型图像描述",
            'text_prompt': "文本提示",
            'image_url': "图片网址",
            'output': "模型输出",
            'image_display': "输入图片",
            'submit': "运行推理"
        },
        'jp': {
            'title': "Pixtralモデルによる画像説明生成",
            'text_prompt': "テキストプロンプト",
            'image_url': "画像URL",
            'output': "モデルの出力結果",
            'image_display': "入力された画像",
            'submit': "推論を実行"
        }
    }
    # .get with an English default replaces the original bare labels[language]
    return labels.get(language, labels['en'])
68
+
69
# Gradio interface callback
def process_input(text, image_url):
    """Run inference and return (model output, HTML preview of the input image).

    The user-supplied URL is HTML-escaped before interpolation into the
    <img> tag so a crafted URL (quotes, angle brackets) cannot inject
    markup into the rendered page.
    """
    import html  # stdlib; local import keeps the top-of-file import block untouched

    result = mistral_inference(text, image_url)
    safe_url = html.escape(image_url, quote=True)
    return result, f'<img src="{safe_url}" alt="Input Image" width="300">'
73
+
74
+ def update_ui(language):
75
+ labels = get_labels(language)
76
+ return labels['title'], labels['text_prompt'], labels['image_url'], labels['output'], labels['image_display'], labels['submit']
77
+
78
# Gradio UI: language selector, prompt/URL inputs, output areas, and event wiring
with gr.Blocks() as demo:
    language_choice = gr.Dropdown(choices=['en', 'zh', 'jp'], label="Select Language", value='en')

    title = gr.Markdown("## Pixtral Model Image Description")
    with gr.Row():
        text_input = gr.Textbox(label="Text Prompt", placeholder="e.g. Describe the image.")
        image_input = gr.Textbox(label="Image URL", placeholder="e.g. https://example.com/image.png")

    result_output = gr.Textbox(label="Model Output", lines=8, max_lines=20)  # sized to roughly 500px tall
    image_output = gr.HTML(label="Input Image")  # slot that renders the input image URL as HTML

    submit_button = gr.Button("Run Inference")

    submit_button.click(process_input, inputs=[text_input, image_input], outputs=[result_output, image_output])

    # Refresh the UI labels when the language selection changes
    language_choice.change(
        fn=update_ui,
        inputs=[language_choice],
        outputs=[title, text_input, image_input, result_output, image_output, submit_button]
    )

demo.launch()