Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import os | |
import easyocr | |
import numpy as np | |
from PIL import Image | |
import torch | |
# Đường dẫn mô hình | |
MODEL_PATH = "content/envit5-translation" | |
# Khởi tạo EasyOCR reader với GPU nếu có | |
gpu = True if torch.cuda.is_available() else False | |
reader = None # Khởi tạo reader là None trước | |
def initialize_ocr(): | |
global reader | |
if reader is None: | |
reader = easyocr.Reader(['vi', 'en'], gpu=gpu) | |
return reader | |
# Hàm kiểm tra và tải mô hình nếu cần | |
def check_and_download_model(): | |
if not os.path.exists(MODEL_PATH): | |
model_name = "VietAI/envit5-translation" | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name) | |
# Lưu mô hình | |
tokenizer.save_pretrained(MODEL_PATH) | |
model.save_pretrained(MODEL_PATH) | |
# Hàm trích xuất text từ ảnh | |
def extract_text_from_image(image): | |
if image is None: | |
return "Vui lòng tải lên một ảnh!" | |
try: | |
global reader | |
if reader is None: | |
reader = initialize_ocr() | |
# Chuyển đổi ảnh sang định dạng phù hợp | |
if isinstance(image, np.ndarray): | |
results = reader.readtext(image) | |
else: | |
img = Image.open(image) | |
img_array = np.array(img) | |
results = reader.readtext(img_array) | |
# Kết hợp tất cả text tìm được | |
extracted_text = ' '.join([text[1] for text in results]) | |
return extracted_text if extracted_text.strip() else "Không tìm thấy text trong ảnh!" | |
except Exception as e: | |
return f"Lỗi khi xử lý ảnh: {str(e)}" | |
# Hàm dịch | |
def translate_text(input_text, direction): | |
if not input_text or not input_text.strip(): | |
return "Vui lòng nhập văn bản cần dịch!" | |
try: | |
# Tải mô hình và tokenizer | |
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) | |
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH) | |
# Chuẩn bị đầu vào với định hướng dịch | |
prefix = "vi: " if direction == "Vie → Eng" else "en: " | |
inputs = [f"{prefix}{input_text}"] | |
encoded_inputs = tokenizer(inputs, return_tensors="pt", padding=True).input_ids | |
# Dịch văn bản | |
outputs = model.generate(encoded_inputs, max_length=512) | |
translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] | |
return translated_text | |
except Exception as e: | |
return f"Lỗi: {str(e)}" | |
# Hàm tạo file tải xuống | |
def create_txt_file(output_text): | |
if not output_text: | |
return None | |
file_path = "translation_output.txt" | |
with open(file_path, "w", encoding="utf-8") as f: | |
f.write(output_text) | |
return file_path | |
# Hàm xử lý khi tải ảnh lên | |
def handle_image_upload(image): | |
if image is None: | |
return "Vui lòng tải lên một ảnh!" | |
return extract_text_from_image(image) | |
# Tải mô hình khi khởi chạy | |
check_and_download_model() | |
# Tạo giao diện Gradio | |
with gr.Blocks() as interface: | |
gr.Markdown("## Ứng dụng Dịch Máy VietAI với OCR") | |
gr.Markdown("Công cụ dịch tự động từ tiếng Việt sang tiếng Anh hoặc ngược lại, hỗ trợ OCR và tải kết quả dưới dạng file .txt.") | |
with gr.Tabs(): | |
with gr.Tab("Nhập Text"): | |
with gr.Row(): | |
text_input = gr.Textbox(label="Văn bản đầu vào", lines=8, placeholder="Nhập văn bản cần dịch...") | |
text_output = gr.Textbox(label="Kết quả dịch", lines=8, placeholder="Bản dịch sẽ hiện ở đây...") | |
with gr.Tab("OCR từ Ảnh"): | |
with gr.Column(): | |
image_input = gr.Image(label="Tải ảnh lên", type="filepath") | |
ocr_button = gr.Button("Trích xuất Text") | |
ocr_output = gr.Textbox(label="Text trích xuất từ ảnh", lines=8, placeholder="Text trích xuất từ ảnh sẽ hiện ở đây...") | |
direction = gr.Radio( | |
["Vie → Eng", "Eng → Vie"], | |
label="Hướng dịch", | |
value="Vie → Eng" | |
) | |
translate_button = gr.Button("Dịch") | |
download_button = gr.Button("Tải xuống kết quả") | |
# Kết nối các thành phần | |
ocr_button.click( | |
handle_image_upload, | |
inputs=[image_input], | |
outputs=ocr_output | |
) | |
# Xử lý dịch cho cả hai tab | |
def translate_wrapper(text_input, ocr_input, active_tab, direction): | |
input_text = text_input if active_tab == "Nhập Text" else ocr_input | |
return translate_text(input_text, direction) | |
translate_button.click( | |
translate_wrapper, | |
inputs=[text_input, ocr_output, gr.State("Nhập Text"), direction], | |
outputs=text_output | |
) | |
download_button.click( | |
create_txt_file, | |
inputs=[text_output], | |
outputs=gr.File(label="Tải xuống") | |
) | |
if __name__ == "__main__": | |
# Khởi tạo OCR reader | |
initialize_ocr() | |
# Chạy ứng dụng | |
interface.launch() |