import os import warnings import torch import gc from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig from PIL import Image import gradio as gr warnings.filterwarnings('ignore') os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Global variables model = None processor = None if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() print("เคลียร์ CUDA cache เรียบร้อยแล้ว") def load_model_and_processor(): """โหลดโมเดลและ processor""" global model, processor print("กำลังโหลดโมเดลและ processor...") try: base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct" hub_model_path = "Aekanun/thai-handwriting-llm" # ตั้งค่า BitsAndBytes แบบเดียวกับต้นฉบับ bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 ) # โหลด processor แบบเดียวกับต้นฉบับ (ไม่มี token) processor = AutoProcessor.from_pretrained(base_model_path) # โหลดโมเดลจาก Hub แบบเดียวกับต้นฉบับ print("กำลังโหลดโมเดลจาก Hub...") model = AutoModelForVision2Seq.from_pretrained( hub_model_path, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=bnb_config, trust_remote_code=True ) print("โหลดโมเดลจาก Hub สำเร็จ!") return True except Exception as e: print(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}") return False def process_handwriting(image): if image is None: return "กรุณาอัพโหลดรูปภาพ" try: if not isinstance(image, Image.Image): image = Image.fromarray(image) if image.mode != "RGB": image = image.convert("RGB") prompt = """Transcribe the Thai handwritten text from the provided image. Only return the transcription in Thai language.""" messages = [ { "role": "user", "content": [ {"type": "text", "text": prompt}, {"type": "image", "image": image} ], } ] text = processor.apply_chat_template(messages, tokenize=False) inputs = processor(text=text, images=image, return_tensors="pt") inputs = {k: v.to(model.device) for k, v in inputs.items()} with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=256, do_sample=False, pad_token_id=processor.tokenizer.pad_token_id ) transcription = processor.decode(outputs[0], skip_special_tokens=True) return transcription.strip() except Exception as e: return f"เกิดข้อผิดพลาด: {str(e)}" print("กำลังเริ่มต้นแอปพลิเคชัน...") if load_model_and_processor(): demo = gr.Interface( fn=process_handwriting, inputs=gr.Image(type="pil", label="อัพโหลดรูปลายมือเขียนภาษาไทย"), outputs=gr.Textbox(label="ข้อความที่แปลงได้"), title="Thai Handwriting Recognition", description="อัพโหลดรูปภาพลายมือเขียนภาษาไทยเพื่อแปลงเป็นข้อความ" ) if __name__ == "__main__": demo.launch() else: print("ไม่สามารถเริ่มต้นแอปพลิเคชันได้")