Commit a187193 by Aekanun
Parent(s): 1a517f1

fixed app.py

Files changed (1): app.py (+86 -43)
app.py CHANGED
@@ -1,81 +1,124 @@
 import os
-from huggingface_hub import login
-from transformers import AutoProcessor, AutoModelForVision2Seq
+import warnings
 import torch
+import gc
+from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
 from PIL import Image
 import gradio as gr
+from huggingface_hub import login
 
-# Login to Hugging Face Hub
-if 'HUGGING_FACE_HUB_TOKEN' in os.environ:
-    print("Logging in to Hugging Face Hub...")
-    login(token=os.environ['HUGGING_FACE_HUB_TOKEN'])
-else:
-    print("Warning: HUGGING_FACE_HUB_TOKEN not found")
+# Basic settings
+warnings.filterwarnings('ignore')
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
 # Global variables
 model = None
 processor = None
 
-def load_model():
+# Clear CUDA cache
+if torch.cuda.is_available():
+    torch.cuda.empty_cache()
+    gc.collect()
+    print("เคลียร์ CUDA cache เรียบร้อยแล้ว")
+
+# Login to Hugging Face Hub
+if 'HUGGING_FACE_HUB_TOKEN' in os.environ:
+    print("กำลังเข้าสู่ระบบ Hugging Face Hub...")
+    login(token=os.environ['HUGGING_FACE_HUB_TOKEN'])
+else:
+    print("คำเตือน: ไม่พบ HUGGING_FACE_HUB_TOKEN")
+
+def load_model_and_processor():
+    """โหลดโมเดลและ processor"""
     global model, processor
+    print("กำลังโหลดโมเดลและ processor...")
+
     try:
-        model_path = "Aekanun/thai-handwriting-llm"
-        print(f"Loading model and processor from {model_path}...")
+        # Model paths
+        base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+        hub_model_path = "Aekanun/thai-handwriting-llm"
+
+        # BitsAndBytes config
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16
+        )
+
+        # Load processor from base model
+        processor = AutoProcessor.from_pretrained(base_model_path)
 
-        processor = AutoProcessor.from_pretrained(model_path)
-        model = AutoModelForVision2Seq.from_pretrained(model_path)
+        # Load model from Hub
+        print("กำลังโหลดโมเดลจาก Hub...")
+        model = AutoModelForVision2Seq.from_pretrained(
+            hub_model_path,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            quantization_config=bnb_config,
+            trust_remote_code=True
+        )
+        print("โหลดโมเดลสำเร็จ!")
 
-        if torch.cuda.is_available():
-            model = model.to("cuda")
-
         return True
     except Exception as e:
-        print(f"Error loading model: {str(e)}")
+        print(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}")
        return False
 
-def process_image(image):
+def process_handwriting(image):
+    """ฟังก์ชันสำหรับ Gradio interface"""
+    global model, processor
+
     if image is None:
         return "กรุณาอัพโหลดรูปภาพ"
-
+
     try:
         # Ensure image is in PIL format
         if not isinstance(image, Image.Image):
             image = Image.fromarray(image)
 
-        # Convert to RGB if needed
-        if image.mode != "RGB":
-            image = image.convert("RGB")
-
-        # Process image
-        inputs = processor(images=image, return_tensors="pt")
-
-        # Move to GPU if available
-        if torch.cuda.is_available():
-            inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-        # Generate text
+        # Create prompt
+        prompt = """Transcribe the Thai handwritten text from the provided image.
+Only return the transcription in Thai language."""
+
+        # Create model inputs
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prompt},
+                    {"type": "image", "image": image}
+                ],
+            }
+        ]
+
+        # Process with model
+        text = processor.apply_chat_template(messages, tokenize=False)
+        inputs = processor(text=text, images=image, return_tensors="pt")
+        inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+        # Generate
         with torch.no_grad():
             outputs = model.generate(
                 **inputs,
-                max_new_tokens=100,
-                num_beams=4,
-                pad_token_id=processor.tokenizer.pad_token_id,
-                eos_token_id=processor.tokenizer.eos_token_id
+                max_new_tokens=256,
+                do_sample=False,
+                pad_token_id=processor.tokenizer.pad_token_id
             )
-
+
         # Decode output
-        predicted_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
-        return predicted_text.strip()
+        transcription = processor.decode(outputs[0], skip_special_tokens=True)
+        return transcription.strip()
 
     except Exception as e:
         return f"เกิดข้อผิดพลาด: {str(e)}"
 
-# Initialize
-print("Initializing application...")
-if load_model():
+# Initialize application
+print("กำลังเริ่มต้นแอปพลิเคชัน...")
+if load_model_and_processor():
     # Create Gradio interface
     demo = gr.Interface(
-        fn=process_image,
+        fn=process_handwriting,
         inputs=gr.Image(type="pil", label="อัพโหลดรูปลายมือเขียนภาษาไทย"),
         outputs=gr.Textbox(label="ข้อความที่แปลงได้"),
         title="Thai Handwriting Recognition",
@@ -86,4 +129,4 @@ if load_model():
     if __name__ == "__main__":
         demo.launch()
 else:
-    print("Failed to initialize the application")
+    print("ไม่สามารถเริ่มต้นแอปพลิเคชันได้")