Aekanun committed on
Commit
1c8a6bd
·
1 Parent(s): a187193

fixed app.py with specific model type

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import warnings
3
  import torch
4
  import gc
5
- from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
6
  from PIL import Image
7
  import gradio as gr
8
  from huggingface_hub import login
@@ -35,7 +35,6 @@ def load_model_and_processor():
35
 
36
  try:
37
  # Model paths
38
- base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
39
  hub_model_path = "Aekanun/thai-handwriting-llm"
40
 
41
  # BitsAndBytes config
@@ -45,14 +44,18 @@ def load_model_and_processor():
45
  bnb_4bit_quant_type="nf4",
46
  bnb_4bit_compute_dtype=torch.bfloat16
47
  )
 
 
 
 
48
 
49
- # Load processor from base model
50
- processor = AutoProcessor.from_pretrained(base_model_path)
51
 
52
- # Load model from Hub
53
  print("กำลังโหลดโมเดลจาก Hub...")
54
  model = AutoModelForVision2Seq.from_pretrained(
55
  hub_model_path,
 
56
  device_map="auto",
57
  torch_dtype=torch.bfloat16,
58
  quantization_config=bnb_config,
@@ -76,6 +79,10 @@ def process_handwriting(image):
76
  # Ensure image is in PIL format
77
  if not isinstance(image, Image.Image):
78
  image = Image.fromarray(image)
 
 
 
 
79
 
80
  # Create prompt
81
  prompt = """Transcribe the Thai handwritten text from the provided image.
 
2
  import warnings
3
  import torch
4
  import gc
5
+ from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig, AutoConfig
6
  from PIL import Image
7
  import gradio as gr
8
  from huggingface_hub import login
 
35
 
36
  try:
37
  # Model paths
 
38
  hub_model_path = "Aekanun/thai-handwriting-llm"
39
 
40
  # BitsAndBytes config
 
44
  bnb_4bit_quant_type="nf4",
45
  bnb_4bit_compute_dtype=torch.bfloat16
46
  )
47
+
48
+ # Load model configuration
49
+ config = AutoConfig.from_pretrained(hub_model_path, trust_remote_code=True)
50
+ config.model_type = "llava" # กำหนด model_type
51
 
52
+ # Load processor and model
53
+ processor = AutoProcessor.from_pretrained(hub_model_path, trust_remote_code=True)
54
 
 
55
  print("กำลังโหลดโมเดลจาก Hub...")
56
  model = AutoModelForVision2Seq.from_pretrained(
57
  hub_model_path,
58
+ config=config,
59
  device_map="auto",
60
  torch_dtype=torch.bfloat16,
61
  quantization_config=bnb_config,
 
79
  # Ensure image is in PIL format
80
  if not isinstance(image, Image.Image):
81
  image = Image.fromarray(image)
82
+
83
+ # Convert to RGB if needed
84
+ if image.mode != "RGB":
85
+ image = image.convert("RGB")
86
 
87
  # Create prompt
88
  prompt = """Transcribe the Thai handwritten text from the provided image.