"""Gradio app: Romanian chat via pivot translation.

Romanian input is translated to English through French using Marian MT
models, fed to an English-language chat model, and the generated reply
is translated back to Romanian.
"""

import os
import re

import gradio as gr
import torch  # noqa: F401 -- backend for transformers model execution
from huggingface_hub import login
from transformers import AutoTokenizer, MarianMTModel, MarianTokenizer, pipeline

# Read the token from the environment variable and authenticate.
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if HUGGINGFACE_TOKEN:
    login(token=HUGGINGFACE_TOKEN)
else:
    raise ValueError("Hugging Face token not found in environment variables.")

# Marian translation models: ro->fr and fr->en for the forward pivot,
# en->ro for translating the generated reply back.
rfr_md = "Helsinki-NLP/opus-mt-ro-fr"
frr_md = "Helsinki-NLP/opus-mt-fr-en"
enr_md = "Helsinki-NLP/opus-mt-en-ro"

rfr_token = MarianTokenizer.from_pretrained(rfr_md)
rfr_model = MarianMTModel.from_pretrained(rfr_md)
fren_token = MarianTokenizer.from_pretrained(frr_md)
fren_model = MarianMTModel.from_pretrained(frr_md)
enr_token = MarianTokenizer.from_pretrained(enr_md)
enr_model = MarianMTModel.from_pretrained(enr_md)

# Chat model for English text generation, forced onto CPU.
# NOTE(review): despite the variable names, this checkpoint is StableLM,
# not Gemma -- confirm which model is actually intended.
gemma_model = "stabilityai/stablelm-2-1_6b-chat"
gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_model)
pipe = pipeline(
    "text-generation",
    model=gemma_model,
    tokenizer=gemma_tokenizer,
    device="cpu",  # Use CPU
)

# Pre-compiled pattern matching runs of '*' (markdown emphasis the chat
# model tends to emit); hoisted so it is compiled once, not per call.
_ASTERISK_RUNS = re.compile(r"\*+")


def char_split(text, tokenizer, max_length=498):
    """Split *text* into token blocks of at most *max_length* tokens.

    The default of 498 leaves headroom under the Marian 512-position
    limit for special tokens added when a block is re-tokenized in
    ``translate`` -- TODO confirm against the model's actual max length.

    Returns a list of 1-D token-id tensors.
    """
    tokens = tokenizer(
        text, return_tensors="pt", truncation=False, padding=False
    )["input_ids"][0]
    # Fixed-size slicing; the final block may be shorter.
    return [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]


def translate(text, model, tokenizer, max_length=500):
    """Translate *text* block by block with a Marian model.

    Each token block is decoded back to a string so the tokenizer can
    re-add the special tokens the model expects, then translated; the
    partial translations are joined with single spaces.
    """
    pieces = []
    for block in char_split(text, tokenizer, max_length):
        block_text = tokenizer.decode(block, skip_special_tokens=True)
        inputs = tokenizer(
            block_text, return_tensors="pt", padding=True, truncation=True
        )
        translated = model.generate(**inputs)
        pieces.append(tokenizer.decode(translated[0], skip_special_tokens=True))
    # join() instead of repeated '+=' avoids quadratic string building.
    return " ".join(pieces).strip()


def rm_rf(text):
    """Remove markdown '*' formatting symbols from generated text."""
    return _ASTERISK_RUNS.sub("", text)


def generate(text):
    """Generate a Romanian reply for a Romanian prompt.

    Pipeline: ro -> fr -> en pivot translation, English chat completion,
    formatting cleanup, then en -> ro back-translation.
    """
    fr_txt = translate(text, rfr_model, rfr_token)
    en_txt = translate(fr_txt, fren_model, fren_token)
    sequences = pipe(
        en_txt,
        max_new_tokens=2048,
        do_sample=True,
        return_full_text=False,
    )
    generated_text = sequences[0]["generated_text"]
    cl_txt = rm_rf(generated_text)
    return translate(cl_txt, enr_model, enr_token)


# Create the Gradio interface and launch the app.
interface = gr.Interface(
    fn=generate,
    inputs=gr.Textbox(label="prompt:", lines=2, placeholder="prompt..."),
    outputs="text",
    title="Gemma Romanian",
    description="romanian gemma using nlps.",
)

interface.launch()