Translate / app.py
ArrcttacsrjksX's picture
Update app.py
8bcc9bb verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
import easyocr
import numpy as np
from PIL import Image
import torch
# Đường dẫn mô hình
MODEL_PATH = "content/envit5-translation"
# Khởi tạo EasyOCR reader với GPU nếu có
gpu = True if torch.cuda.is_available() else False
reader = None # Khởi tạo reader là None trước
def initialize_ocr():
global reader
if reader is None:
reader = easyocr.Reader(['vi', 'en'], gpu=gpu)
return reader
# Hàm kiểm tra và tải mô hình nếu cần
def check_and_download_model():
if not os.path.exists(MODEL_PATH):
model_name = "VietAI/envit5-translation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Lưu mô hình
tokenizer.save_pretrained(MODEL_PATH)
model.save_pretrained(MODEL_PATH)
# Hàm trích xuất text từ ảnh
def extract_text_from_image(image):
if image is None:
return "Vui lòng tải lên một ảnh!"
try:
global reader
if reader is None:
reader = initialize_ocr()
# Chuyển đổi ảnh sang định dạng phù hợp
if isinstance(image, np.ndarray):
results = reader.readtext(image)
else:
img = Image.open(image)
img_array = np.array(img)
results = reader.readtext(img_array)
# Kết hợp tất cả text tìm được
extracted_text = ' '.join([text[1] for text in results])
return extracted_text if extracted_text.strip() else "Không tìm thấy text trong ảnh!"
except Exception as e:
return f"Lỗi khi xử lý ảnh: {str(e)}"
# Hàm dịch
def translate_text(input_text, direction):
if not input_text or not input_text.strip():
return "Vui lòng nhập văn bản cần dịch!"
try:
# Tải mô hình và tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
# Chuẩn bị đầu vào với định hướng dịch
prefix = "vi: " if direction == "Vie → Eng" else "en: "
inputs = [f"{prefix}{input_text}"]
encoded_inputs = tokenizer(inputs, return_tensors="pt", padding=True).input_ids
# Dịch văn bản
outputs = model.generate(encoded_inputs, max_length=512)
translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
return translated_text
except Exception as e:
return f"Lỗi: {str(e)}"
# Hàm tạo file tải xuống
def create_txt_file(output_text):
if not output_text:
return None
file_path = "translation_output.txt"
with open(file_path, "w", encoding="utf-8") as f:
f.write(output_text)
return file_path
# Hàm xử lý khi tải ảnh lên
def handle_image_upload(image):
if image is None:
return "Vui lòng tải lên một ảnh!"
return extract_text_from_image(image)
# Tải mô hình khi khởi chạy
check_and_download_model()
# Tạo giao diện Gradio
with gr.Blocks() as interface:
gr.Markdown("## Ứng dụng Dịch Máy VietAI với OCR")
gr.Markdown("Công cụ dịch tự động từ tiếng Việt sang tiếng Anh hoặc ngược lại, hỗ trợ OCR và tải kết quả dưới dạng file .txt.")
with gr.Tabs():
with gr.Tab("Nhập Text"):
with gr.Row():
text_input = gr.Textbox(label="Văn bản đầu vào", lines=8, placeholder="Nhập văn bản cần dịch...")
text_output = gr.Textbox(label="Kết quả dịch", lines=8, placeholder="Bản dịch sẽ hiện ở đây...")
with gr.Tab("OCR từ Ảnh"):
with gr.Column():
image_input = gr.Image(label="Tải ảnh lên", type="filepath")
ocr_button = gr.Button("Trích xuất Text")
ocr_output = gr.Textbox(label="Text trích xuất từ ảnh", lines=8, placeholder="Text trích xuất từ ảnh sẽ hiện ở đây...")
direction = gr.Radio(
["Vie → Eng", "Eng → Vie"],
label="Hướng dịch",
value="Vie → Eng"
)
translate_button = gr.Button("Dịch")
download_button = gr.Button("Tải xuống kết quả")
# Kết nối các thành phần
ocr_button.click(
handle_image_upload,
inputs=[image_input],
outputs=ocr_output
)
# Xử lý dịch cho cả hai tab
def translate_wrapper(text_input, ocr_input, active_tab, direction):
input_text = text_input if active_tab == "Nhập Text" else ocr_input
return translate_text(input_text, direction)
translate_button.click(
translate_wrapper,
inputs=[text_input, ocr_output, gr.State("Nhập Text"), direction],
outputs=text_output
)
download_button.click(
create_txt_file,
inputs=[text_output],
outputs=gr.File(label="Tải xuống")
)
if __name__ == "__main__":
# Khởi tạo OCR reader
initialize_ocr()
# Chạy ứng dụng
interface.launch()