Aekanun commited on
Commit
78744e1
·
1 Parent(s): 069ee6d
Files changed (1) hide show
  1. app.py +53 -108
app.py CHANGED
@@ -1,123 +1,68 @@
1
  import os
2
- import warnings
3
  import torch
4
- import gc
5
- from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
6
  from PIL import Image
7
  import gradio as gr
8
- from huggingface_hub import login
9
-
10
- warnings.filterwarnings('ignore')
11
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
12
-
13
- # Global variables
14
- model = None
15
- processor = None
16
 
17
- if torch.cuda.is_available():
18
- torch.cuda.empty_cache()
19
- gc.collect()
20
- print("เคลียร์ CUDA cache เรียบร้อยแล้ว")
 
21
 
22
- def load_model_and_processor():
23
- """โหลดโมเดลและ processor"""
24
- global model, processor
25
- print("กำลังโหลดโมเดลและ processor...")
26
 
27
- try:
28
- # กำหนด paths
29
- base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
30
- hub_model_path = "Aekanun/thai-handwriting-llm"
31
-
32
- # ตั้งค่า BitsAndBytes
33
- bnb_config = BitsAndBytesConfig(
34
- load_in_4bit=True,
35
- bnb_4bit_use_double_quant=True,
36
- bnb_4bit_quant_type="nf4",
37
- bnb_4bit_compute_dtype=torch.bfloat16
38
- )
39
-
40
- # โหลด processor จาก base model
41
- processor = AutoProcessor.from_pretrained(
42
- base_model_path,
43
- token=os.environ.get('HUGGING_FACE_HUB_TOKEN')
44
- )
45
-
46
- # โหลดโมเดลจาก Hub
47
- print("กำลังโหลดโมเดลจาก Hub...")
48
- model = AutoModelForVision2Seq.from_pretrained(
49
- hub_model_path,
50
- device_map="auto",
51
- torch_dtype=torch.bfloat16,
52
- quantization_config=bnb_config,
53
- token=os.environ.get('HUGGING_FACE_HUB_TOKEN')
54
- )
55
- print("โหลดโมเดลจาก Hub สำเร็จ!")
56
-
57
- return True
58
- except Exception as e:
59
- print(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}")
60
- return False
61
-
62
- def process_handwriting(image):
63
- """ฟังก์ชันสำหรับ Gradio interface"""
64
- global model, processor
65
 
 
 
 
 
 
66
  if image is None:
67
  return "กรุณาอัพโหลดรูปภาพ"
 
 
 
68
 
69
- try:
70
- # Ensure image is in PIL format
71
- if not isinstance(image, Image.Image):
72
- image = Image.fromarray(image)
73
-
74
- # Convert to RGB if needed
75
- if image.mode != "RGB":
76
- image = image.convert("RGB")
77
-
78
- prompt = """Transcribe the Thai handwritten text from the provided image.
79
- Only return the transcription in Thai language."""
80
-
81
- messages = [
82
- {
83
- "role": "user",
84
- "content": [
85
- {"type": "text", "text": prompt},
86
- {"type": "image", "image": image}
87
- ],
88
- }
89
- ]
90
-
91
- text = processor.apply_chat_template(messages, tokenize=False)
92
- inputs = processor(text=text, images=image, return_tensors="pt")
93
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
94
 
95
- with torch.no_grad():
96
- outputs = model.generate(
97
- **inputs,
98
- max_new_tokens=256,
99
- do_sample=False,
100
- pad_token_id=processor.tokenizer.pad_token_id
101
- )
 
 
 
 
102
 
103
- transcription = processor.decode(outputs[0], skip_special_tokens=True)
104
- return transcription.strip()
105
-
106
- except Exception as e:
107
- return f"เกิดข้อผิดพลาด: {str(e)}"
 
 
 
 
 
 
 
 
108
 
109
- # Initialize application
110
- print("กำลังเริ่มต้นแอปพลิเคชัน...")
111
- if load_model_and_processor():
112
- demo = gr.Interface(
113
- fn=process_handwriting,
114
- inputs=gr.Image(type="pil", label="อัพโหลดรูปลายมือเขียนภาษาไทย"),
115
- outputs=gr.Textbox(label="ข้อความที่แปลงได้"),
116
- title="Thai Handwriting Recognition",
117
- description="อัพโหลดรูปภาพลายมือเขียนภาษาไทยเพื่อแปลงเป็นข้อความ"
118
- )
119
 
120
- if __name__ == "__main__":
121
- demo.launch()
122
- else:
123
- print("ไม่สามารถเริ่มต้นแอปพลิเคชันได้")
 
1
  import os
 
2
  import torch
3
+ from transformers import AutoModelForVision2Seq, AutoProcessor
 
4
  from PIL import Image
5
  import gradio as gr
 
 
 
 
 
 
 
 
6
 
7
+ # Login to Hugging Face Hub
8
+ from huggingface_hub import login
9
+ token = os.environ.get('HUGGING_FACE_HUB_TOKEN')
10
+ if token:
11
+ login(token=token)
12
 
13
+ def load_model():
14
+ base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
15
+ hub_model_path = "Aekanun/thai-handwriting-llm"
 
16
 
17
+ processor = AutoProcessor.from_pretrained(base_model_path, token=token)
18
+ model = AutoModelForVision2Seq.from_pretrained(hub_model_path, token=token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ return model, processor
21
+
22
+ model, processor = load_model()
23
+
24
+ def process_image(image):
25
  if image is None:
26
  return "กรุณาอัพโหลดรูปภาพ"
27
+
28
+ if not isinstance(image, Image.Image):
29
+ image = Image.fromarray(image)
30
 
31
+ if image.mode != "RGB":
32
+ image = image.convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ prompt = "Transcribe the Thai handwritten text from the provided image.\nOnly return the transcription in Thai language."
35
+
36
+ messages = [
37
+ {
38
+ "role": "user",
39
+ "content": [
40
+ {"type": "text", "text": prompt},
41
+ {"type": "image", "image": image}
42
+ ],
43
+ }
44
+ ]
45
 
46
+ text = processor.apply_chat_template(messages, tokenize=False)
47
+ inputs = processor(text=text, images=image, return_tensors="pt")
48
+
49
+ with torch.no_grad():
50
+ outputs = model.generate(
51
+ **inputs,
52
+ max_new_tokens=256,
53
+ do_sample=False,
54
+ pad_token_id=processor.tokenizer.pad_token_id
55
+ )
56
+
57
+ transcription = processor.decode(outputs[0], skip_special_tokens=True)
58
+ return transcription.strip()
59
 
60
+ demo = gr.Interface(
61
+ fn=process_image,
62
+ inputs=gr.Image(type="pil"),
63
+ outputs="text",
64
+ title="Thai Handwriting OCR",
65
+ )
 
 
 
 
66
 
67
+ if __name__ == "__main__":
68
+ demo.launch()