Aekanun committed
Commit 948e2eb · 1 Parent(s): 925d635

fixing app.py

Files changed (1):
app.py +29 -19
app.py CHANGED

@@ -2,12 +2,12 @@ import os
 import warnings
 import torch
 import gc
-from transformers import LlavaForConditionalGeneration, LlavaProcessor
+from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
 from PIL import Image
 import gradio as gr
 from huggingface_hub import login
 
-# Basic settings
+# Basic setup
 warnings.filterwarnings('ignore')
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
@@ -15,7 +15,7 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 model = None
 processor = None
 
-# Clear CUDA cache
+# Clear the CUDA cache
 if torch.cuda.is_available():
     torch.cuda.empty_cache()
     gc.collect()
@@ -34,24 +34,35 @@ def load_model_and_processor():
     print("Loading model and processor...")
 
     try:
-        # Model paths
+        # Define the paths
+        base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
         hub_model_path = "Aekanun/thai-handwriting-llm"
 
-        # Load processor and model directly using LLaVA classes
-        processor = LlavaProcessor.from_pretrained(
-            hub_model_path,
-            trust_remote_code=True
+        # Configure BitsAndBytes
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16
+        )
+
+        # Load the processor from the base model
+        processor = AutoProcessor.from_pretrained(
+            base_model_path,
+            use_auth_token=os.environ.get('HUGGING_FACE_HUB_TOKEN')
         )
 
+        # Load the model from the Hub
         print("Loading model from the Hub...")
-        model = LlavaForConditionalGeneration.from_pretrained(
+        model = AutoModelForVision2Seq.from_pretrained(
             hub_model_path,
             device_map="auto",
             torch_dtype=torch.bfloat16,
+            quantization_config=bnb_config,
             trust_remote_code=True,
-            load_in_4bit=True
+            use_auth_token=os.environ.get('HUGGING_FACE_HUB_TOKEN')
         )
-        print("Model loaded successfully!")
+        print("Model loaded from the Hub successfully!")
 
         return True
     except Exception as e:
@@ -73,12 +84,12 @@ def process_handwriting(image):
     # Convert to RGB if needed
     if image.mode != "RGB":
         image = image.convert("RGB")
-
-    # Create prompt
+
+    # Create a prompt for the transcription
     prompt = """Transcribe the Thai handwritten text from the provided image.
 Only return the transcription in Thai language."""
 
-    # Create model inputs
+    # Create the input for the model
     messages = [
         {
             "role": "user",
@@ -89,12 +100,12 @@ Only return the transcription in Thai language."""
         }
     ]
 
-    # Process with model
+    # Build the inputs directly from the processor
    text = processor.apply_chat_template(messages, tokenize=False)
     inputs = processor(text=text, images=image, return_tensors="pt")
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
-    # Generate
+    # Predict
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
@@ -103,7 +114,7 @@ Only return the transcription in Thai language."""
             pad_token_id=processor.tokenizer.pad_token_id
         )
 
-    # Decode output
+    # Convert the output
     transcription = processor.decode(outputs[0], skip_special_tokens=True)
     return transcription.strip()
 
@@ -126,5 +137,4 @@ if load_model_and_processor():
     if __name__ == "__main__":
         demo.launch()
 else:
-    print("Could not start the application")
-
+    print("Could not start the application")
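One note on the new loading code: recent releases of transformers deprecate the `use_auth_token` argument in favor of `token`. A minimal sketch of the same loading path under that assumption, reusing the model IDs and the 4-bit NF4 config from the commit:

```python
import os
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig

# Same 4-bit NF4 quantization settings as in the commit
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# `token=` replaces the deprecated `use_auth_token=` in newer transformers
processor = AutoProcessor.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    token=os.environ.get("HUGGING_FACE_HUB_TOKEN"),
)
model = AutoModelForVision2Seq.from_pretrained(
    "Aekanun/thai-handwriting-llm",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    trust_remote_code=True,
    token=os.environ.get("HUGGING_FACE_HUB_TOKEN"),
)
```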
 
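The hunks also reference a Gradio `demo` object (`demo.launch()`) whose definition falls outside the diff. For context, a plausible minimal wiring of `process_handwriting` into Gradio; the labels and title here are illustrative assumptions, not taken from the commit:

```python
import gradio as gr

# Hypothetical sketch: the real `demo` definition sits outside the changed hunks
demo = gr.Interface(
    fn=process_handwriting,            # returns the Thai transcription string
    inputs=gr.Image(type="pil"),       # hands the function a PIL image
    outputs=gr.Textbox(label="Transcription"),
    title="Thai Handwriting Recognition",
)
```

With `type="pil"`, Gradio passes a `PIL.Image` into `process_handwriting`, which matches the `image.mode` / `image.convert("RGB")` handling in the diff.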