import gradio as gr from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import os import easyocr import numpy as np from PIL import Image import torch # Đường dẫn mô hình MODEL_PATH = "content/envit5-translation" # Khởi tạo EasyOCR reader với GPU nếu có gpu = True if torch.cuda.is_available() else False reader = None # Khởi tạo reader là None trước def initialize_ocr(): global reader if reader is None: reader = easyocr.Reader(['vi', 'en'], gpu=gpu) return reader # Hàm kiểm tra và tải mô hình nếu cần def check_and_download_model(): if not os.path.exists(MODEL_PATH): model_name = "VietAI/envit5-translation" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSeq2SeqLM.from_pretrained(model_name) # Lưu mô hình tokenizer.save_pretrained(MODEL_PATH) model.save_pretrained(MODEL_PATH) # Hàm trích xuất text từ ảnh def extract_text_from_image(image): if image is None: return "Vui lòng tải lên một ảnh!" try: global reader if reader is None: reader = initialize_ocr() # Chuyển đổi ảnh sang định dạng phù hợp if isinstance(image, np.ndarray): results = reader.readtext(image) else: img = Image.open(image) img_array = np.array(img) results = reader.readtext(img_array) # Kết hợp tất cả text tìm được extracted_text = ' '.join([text[1] for text in results]) return extracted_text if extracted_text.strip() else "Không tìm thấy text trong ảnh!" except Exception as e: return f"Lỗi khi xử lý ảnh: {str(e)}" # Hàm dịch def translate_text(input_text, direction): if not input_text or not input_text.strip(): return "Vui lòng nhập văn bản cần dịch!" try: # Tải mô hình và tokenizer tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH) # Chuẩn bị đầu vào với định hướng dịch prefix = "vi: " if direction == "Vie → Eng" else "en: " inputs = [f"{prefix}{input_text}"] encoded_inputs = tokenizer(inputs, return_tensors="pt", padding=True).input_ids # Dịch văn bản outputs = model.generate(encoded_inputs, max_length=512) translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] return translated_text except Exception as e: return f"Lỗi: {str(e)}" # Hàm tạo file tải xuống def create_txt_file(output_text): if not output_text: return None file_path = "translation_output.txt" with open(file_path, "w", encoding="utf-8") as f: f.write(output_text) return file_path # Hàm xử lý khi tải ảnh lên def handle_image_upload(image): if image is None: return "Vui lòng tải lên một ảnh!" return extract_text_from_image(image) # Tải mô hình khi khởi chạy check_and_download_model() # Tạo giao diện Gradio with gr.Blocks() as interface: gr.Markdown("## Ứng dụng Dịch Máy VietAI với OCR") gr.Markdown("Công cụ dịch tự động từ tiếng Việt sang tiếng Anh hoặc ngược lại, hỗ trợ OCR và tải kết quả dưới dạng file .txt.") with gr.Tabs(): with gr.Tab("Nhập Text"): with gr.Row(): text_input = gr.Textbox(label="Văn bản đầu vào", lines=8, placeholder="Nhập văn bản cần dịch...") text_output = gr.Textbox(label="Kết quả dịch", lines=8, placeholder="Bản dịch sẽ hiện ở đây...") with gr.Tab("OCR từ Ảnh"): with gr.Column(): image_input = gr.Image(label="Tải ảnh lên", type="filepath") ocr_button = gr.Button("Trích xuất Text") ocr_output = gr.Textbox(label="Text trích xuất từ ảnh", lines=8, placeholder="Text trích xuất từ ảnh sẽ hiện ở đây...") direction = gr.Radio( ["Vie → Eng", "Eng → Vie"], label="Hướng dịch", value="Vie → Eng" ) translate_button = gr.Button("Dịch") download_button = gr.Button("Tải xuống kết quả") # Kết nối các thành phần ocr_button.click( handle_image_upload, inputs=[image_input], outputs=ocr_output ) # Xử lý dịch cho cả hai tab def translate_wrapper(text_input, ocr_input, active_tab, direction): input_text = text_input if active_tab == "Nhập Text" else ocr_input return translate_text(input_text, direction) translate_button.click( translate_wrapper, inputs=[text_input, ocr_output, gr.State("Nhập Text"), direction], outputs=text_output ) download_button.click( create_txt_file, inputs=[text_output], outputs=gr.File(label="Tải xuống") ) if __name__ == "__main__": # Khởi tạo OCR reader initialize_ocr() # Chạy ứng dụng interface.launch()