Aekanun commited on
Commit
76059b0
1 Parent(s): f9d68b0

Fix app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -96
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import warnings
3
  import torch
4
  import gc
5
- from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
6
  from PIL import Image
7
  import gradio as gr
8
  from huggingface_hub import login
@@ -17,115 +17,106 @@ processor = None
17
 
18
  # Clear CUDA cache
19
  if torch.cuda.is_available():
20
- torch.cuda.empty_cache()
21
- gc.collect()
22
  print("เคลียร์ CUDA cache เรียบร้อยแล้ว")
23
 
24
  # Login to Hugging Face Hub
25
  if 'HUGGING_FACE_HUB_TOKEN' in os.environ:
26
- print("กำลังเข้าสู่ระบบ Hugging Face Hub...")
27
- login(token=os.environ['HUGGING_FACE_HUB_TOKEN'])
28
  else:
29
- print("คำเตือน: ไม่พบ HUGGING_FACE_HUB_TOKEN")
30
 
31
  def load_model_and_processor():
32
- """โหลดโมเดลและ processor"""
33
- global model, processor
34
- print("กำลังโหลดโมเดลและ processor...")
35
- try:
36
- # Model paths
37
- base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
38
- hub_model_path = "Aekanun/thai-handwriting-llm"
39
 
40
- # BitsAndBytes config
41
- bnb_config = BitsAndBytesConfig(
42
- load_in_4bit=True,
43
- bnb_4bit_use_double_quant=True,
44
- bnb_4bit_quant_type="nf4",
45
- bnb_4bit_compute_dtype=torch.bfloat16
46
- )
47
 
48
- # Load processor from base model
49
- print("กำลังโหลด processor...")
50
- processor = AutoProcessor.from_pretrained(base_model_path, use_auth_token=True)
51
-
52
- # Load model from Hub
53
- print("กำลังโหลดโมเดลจาก Hub...")
54
- model = AutoModelForVision2Seq.from_pretrained(
55
- hub_model_path,
56
- device_map="auto",
57
- torch_dtype=torch.bfloat16,
58
- quantization_config=bnb_config,
59
- trust_remote_code=True,
60
- use_auth_token=True
61
- )
62
- print("โหลดโมเดลสำเร็จ!")
63
- return True
64
- except Exception as e:
65
- print(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}")
66
- return False
67
 
68
  def process_handwriting(image):
69
- """ฟังก์ชันสำหรับ Gradio interface"""
70
- global model, processor
71
-
72
- if image is None:
73
- return "กรุณาอัพโหลดรูปภาพ"
74
-
75
- try:
76
- # Ensure image is in PIL format
77
- if not isinstance(image, Image.Image):
78
- image = Image.fromarray(image)
79
-
80
- # Create prompt
81
- prompt = """Transcribe the Thai handwritten text from the provided image.
82
  Only return the transcription in Thai language."""
83
-
84
- # Create model inputs
85
- messages = [
86
- {
87
- "role": "user",
88
- "content": [
89
- {"type": "text", "text": prompt},
90
- {"type": "image", "image": image}
91
- ],
92
- }
93
- ]
94
-
95
- # Process with model
96
- text = processor.apply_chat_template(messages, tokenize=False)
97
- inputs = processor(text=text, images=image, return_tensors="pt")
98
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
99
-
100
- # Generate
101
- with torch.no_grad():
102
- outputs = model.generate(
103
- **inputs,
104
- max_new_tokens=256,
105
- do_sample=False,
106
- pad_token_id=processor.tokenizer.pad_token_id
107
- )
108
-
109
- # Decode output
110
- transcription = processor.decode(outputs[0], skip_special_tokens=True)
111
- return transcription.strip()
112
- except Exception as e:
113
- return f"เกิดข้อผิดพลาด: {str(e)}"
114
 
115
  # Initialize application
116
  print("กำลังเริ่มต้นแอปพลิเคชัน...")
117
  if load_model_and_processor():
118
- # Create Gradio interface
119
- demo = gr.Interface(
120
- fn=process_handwriting,
121
- inputs=gr.Image(type="pil", label="อัพโหลดรูปลายมือเขียนภาษาไทย"),
122
- outputs=gr.Textbox(label="ข้อความที่แปลงได้"),
123
- title="Thai Handwriting Recognition",
124
- description="อัพโหลดรูปภาพลายมือเขียนภาษาไทยเพื่อแปลงเป็นข้อความ",
125
- examples=[["example1.jpg"], ["example2.jpg"]]
126
- )
127
-
128
- if __name__ == "__main__":
129
- demo.launch()
130
  else:
131
- print("ไม่สามารถเริ่มต้นแอปพลิเคชันได้")
 
2
  import warnings
3
  import torch
4
  import gc
5
+ from transformers import AutoModelForVision2Seq, AutoProcessor
6
  from PIL import Image
7
  import gradio as gr
8
  from huggingface_hub import login
 
17
 
18
  # Clear CUDA cache
19
  if torch.cuda.is_available():
20
+ torch.cuda.empty_cache()
21
+ gc.collect()
22
  print("เคลียร์ CUDA cache เรียบร้อยแล้ว")
23
 
24
  # Login to Hugging Face Hub
25
  if 'HUGGING_FACE_HUB_TOKEN' in os.environ:
26
+ print("กำลังเข้าสู่ระบบ Hugging Face Hub...")
27
+ login(token=os.environ['HUGGING_FACE_HUB_TOKEN'])
28
  else:
29
+ print("คำเตือน: ไม่พบ HUGGING_FACE_HUB_TOKEN")
30
 
31
  def load_model_and_processor():
32
+ """โหลดโมเดลและ processor"""
33
+ global model, processor
34
+ print("กำลังโหลดโมเดลและ processor...")
35
+ try:
36
+ # Model paths
37
+ base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
38
+ hub_model_path = "Aekanun/thai-handwriting-llm"
39
 
40
+ # Load processor from base model
41
+ print("กำลังโหลด processor...")
42
+ processor = AutoProcessor.from_pretrained(base_model_path, use_auth_token=True)
 
 
 
 
43
 
44
+ # Load model from Hub
45
+ print("กำลังโหลดโมเดลจาก Hub...")
46
+ model = AutoModelForVision2Seq.from_pretrained(
47
+ hub_model_path,
48
+ device_map="auto",
49
+ torch_dtype=torch.bfloat16,
50
+ trust_remote_code=True,
51
+ use_auth_token=True
52
+ )
53
+ print("โหลดโมเดลสำเร็จ!")
54
+ return True
55
+ except Exception as e:
56
+ print(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}")
57
+ return False
 
 
 
 
 
58
 
59
  def process_handwriting(image):
60
+ """ฟังก์ชันสำหรับ Gradio interface"""
61
+ global model, processor
62
+
63
+ if image is None:
64
+ return "กรุณาอัพโหลดรูปภาพ"
65
+
66
+ try:
67
+ # Ensure image is in PIL format
68
+ if not isinstance(image, Image.Image):
69
+ image = Image.fromarray(image)
70
+
71
+ # Create prompt
72
+ prompt = """Transcribe the Thai handwritten text from the provided image.
73
  Only return the transcription in Thai language."""
74
+
75
+ # Create model inputs
76
+ messages = [
77
+ {
78
+ "role": "user",
79
+ "content": [
80
+ {"type": "text", "text": prompt},
81
+ {"type": "image", "image": image}
82
+ ],
83
+ }
84
+ ]
85
+
86
+ # Process with model
87
+ text = processor.apply_chat_template(messages, tokenize=False)
88
+ inputs = processor(text=text, images=image, return_tensors="pt")
89
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
90
+
91
+ # Generate
92
+ with torch.no_grad():
93
+ outputs = model.generate(
94
+ **inputs,
95
+ max_new_tokens=256,
96
+ do_sample=False,
97
+ pad_token_id=processor.tokenizer.pad_token_id
98
+ )
99
+
100
+ # Decode output
101
+ transcription = processor.decode(outputs[0], skip_special_tokens=True)
102
+ return transcription.strip()
103
+ except Exception as e:
104
+ return f"เกิดข้อผิดพลาด: {str(e)}"
105
 
106
  # Initialize application
107
  print("กำลังเริ่มต้นแอปพลิเคชัน...")
108
  if load_model_and_processor():
109
+ # Create Gradio interface
110
+ demo = gr.Interface(
111
+ fn=process_handwriting,
112
+ inputs=gr.Image(type="pil", label="อัพโหลดรูปลายมือเขียนภาษาไทย"),
113
+ outputs=gr.Textbox(label="ข้อความที่แปลงได้"),
114
+ title="Thai Handwriting Recognition",
115
+ description="อัพโหลดรูปภาพลายมือเขียนภาษาไทยเพื่อแปลงเป็นข้อความ",
116
+ examples=[["example1.jpg"], ["example2.jpg"]]
117
+ )
118
+
119
+ if __name__ == "__main__":
120
+ demo.launch()
121
  else:
122
+ print("ไม่สามารถเริ่มต้นแอปพลิเคชันได้")