Aekanun committed on
Commit
1a517f1
·
1 Parent(s): 592ad8f

fix app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -57
app.py CHANGED
@@ -1,97 +1,89 @@
 
 
 
1
  import torch
2
- from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
3
  from PIL import Image
4
  import gradio as gr
5
 
6
- # Global variables for model and processor
 
 
 
 
 
 
 
7
  model = None
8
  processor = None
9
 
10
- def load_model_and_processor():
11
  global model, processor
12
  try:
13
  model_path = "Aekanun/thai-handwriting-llm"
14
- base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
15
-
16
- print("Loading processor...")
17
- processor = AutoProcessor.from_pretrained(base_model_path)
18
 
19
- print("Loading model...")
20
- bnb_config = BitsAndBytesConfig(
21
- load_in_4bit=True,
22
- bnb_4bit_use_double_quant=True,
23
- bnb_4bit_quant_type="nf4",
24
- bnb_4bit_compute_dtype=torch.bfloat16
25
- )
26
 
27
- model = AutoModelForVision2Seq.from_pretrained(
28
- model_path,
29
- device_map="auto",
30
- torch_dtype=torch.bfloat16,
31
- quantization_config=bnb_config
32
- )
33
  return True
34
  except Exception as e:
35
  print(f"Error loading model: {str(e)}")
36
  return False
37
 
38
- def process_handwriting(image):
39
- global model, processor
40
-
41
  if image is None:
42
  return "กรุณาอัพโหลดรูปภาพ"
43
-
44
  try:
 
45
  if not isinstance(image, Image.Image):
46
  image = Image.fromarray(image)
47
 
48
- prompt = """Transcribe the Thai handwritten text from the provided image.
49
- Only return the transcription in Thai language."""
50
-
51
- messages = [
52
- {
53
- "role": "user",
54
- "content": [
55
- {"type": "text", "text": prompt},
56
- {"type": "image", "image": image}
57
- ],
58
- }
59
- ]
60
-
61
- text = processor.apply_chat_template(messages, tokenize=False)
62
- inputs = processor(text=text, images=image, return_tensors="pt")
63
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
64
 
 
 
 
 
 
65
  with torch.no_grad():
66
  outputs = model.generate(
67
  **inputs,
68
- max_new_tokens=256,
69
- do_sample=False,
70
- pad_token_id=processor.tokenizer.pad_token_id
 
71
  )
72
-
73
- transcription = processor.decode(outputs[0], skip_special_tokens=True)
74
- return transcription
 
75
 
76
  except Exception as e:
77
  return f"เกิดข้อผิดพลาด: {str(e)}"
78
 
79
- # Initialize application
80
  print("Initializing application...")
81
- model_loaded = load_model_and_processor()
82
-
83
- if model_loaded:
84
- print("Creating Gradio interface...")
85
  demo = gr.Interface(
86
- fn=process_handwriting,
87
  inputs=gr.Image(type="pil", label="อัพโหลดรูปลายมือเขียนภาษาไทย"),
88
  outputs=gr.Textbox(label="ข้อความที่แปลงได้"),
89
- title="Thai Handwriting to Text",
90
- description="อัพโหลดรูปภาพลายมือเขียนภาษาไทยเพื่อแปลงเป็นข้อความ"
 
91
  )
92
-
93
  if __name__ == "__main__":
94
- print("Launching application...")
95
  demo.launch()
96
  else:
97
- print("Failed to load model and processor. Please check the logs.")
 
1
+ import os
2
+ from huggingface_hub import login
3
+ from transformers import AutoProcessor, AutoModelForVision2Seq
4
  import torch
 
5
  from PIL import Image
6
  import gradio as gr
7
 
8
# Log in to the Hugging Face Hub when a token is supplied through the
# environment. An empty value is treated the same as a missing one so that
# login() is never handed an unusable token.
hf_token = os.environ.get('HUGGING_FACE_HUB_TOKEN')
if hf_token:
    print("Logging in to Hugging Face Hub...")
    login(token=hf_token)
else:
    print("Warning: HUGGING_FACE_HUB_TOKEN not found")

# Module-level handles: assigned by load_model(), read by process_image().
model = None
processor = None
18
 
19
def load_model():
    """Load the handwriting model and its processor into the module globals.

    Both objects are fetched from the ``Aekanun/thai-handwriting-llm``
    repository, and the model is moved to CUDA when a GPU is available.
    The globals ``model`` and ``processor`` are assigned only after every
    step has succeeded, so a failed load never leaves a half-initialised
    pair behind.

    Returns:
        bool: True when both objects were loaded; False on any error (the
        error message and its traceback are printed).
    """
    global model, processor
    try:
        model_path = "Aekanun/thai-handwriting-llm"
        print(f"Loading model and processor from {model_path}...")

        loaded_processor = AutoProcessor.from_pretrained(model_path)
        loaded_model = AutoModelForVision2Seq.from_pretrained(model_path)

        if torch.cuda.is_available():
            loaded_model = loaded_model.to("cuda")
    except Exception as e:
        # Top-level boundary: report the failure and let the caller decide.
        import traceback

        print(f"Error loading model: {str(e)}")
        traceback.print_exc()
        return False

    # Commit both handles together now that nothing can fail any more.
    model, processor = loaded_model, loaded_processor
    return True
35
 
36
def process_image(image):
    """Transcribe the handwriting in ``image`` with the loaded model.

    Args:
        image: A PIL image, or array data accepted by ``Image.fromarray``.
            ``None`` means nothing was uploaded.

    Returns:
        str: The decoded text with surrounding whitespace stripped; a Thai
        message asking for an upload when ``image`` is None; or a Thai error
        message embedding the exception text when any step fails.
    """
    if image is None:
        return "กรุณาอัพโหลดรูปภาพ"

    try:
        # Ensure image is in PIL format
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Convert to RGB if needed
        if image.mode != "RGB":
            image = image.convert("RGB")

        # NOTE(review): only the image is passed to the processor, with no
        # text prompt. Confirm that this checkpoint's processor and
        # generate() accept image-only input.
        inputs = processor(images=image, return_tensors="pt")

        # Send the inputs to whatever device the model actually lives on,
        # rather than assuming "cuda", so the two can never disagree.
        device = model.device
        inputs = {
            k: (v.to(device) if hasattr(v, "to") else v)
            for k, v in inputs.items()
        }

        # Generate text
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                num_beams=4,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
            )

        # Decode output
        predicted_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
        return predicted_text.strip()

    except Exception as e:
        # Errors are shown in the UI text box instead of crashing the app.
        return f"เกิดข้อผิดพลาด: {str(e)}"
72
 
73
# Initialize: build the Gradio interface only when the model loaded.
print("Initializing application...")
if load_model():
    # Offer only the example images that are actually present next to the
    # app; a listed-but-missing file may prevent the interface from building.
    example_rows = [
        [path]
        for path in ("example1.jpg", "example2.jpg")
        if os.path.exists(path)
    ]

    # Create Gradio interface
    demo = gr.Interface(
        fn=process_image,
        inputs=gr.Image(type="pil", label="อัพโหลดรูปลายมือเขียนภาษาไทย"),
        outputs=gr.Textbox(label="ข้อความที่แปลงได้"),
        title="Thai Handwriting Recognition",
        description="อัพโหลดรูปภาพลายมือเขียนภาษาไทยเพื่อแปลงเป็นข้อความ",
        examples=example_rows or None,
    )

    if __name__ == "__main__":
        demo.launch()
else:
    print("Failed to initialize the application")