hereoncollab commited on
Commit
eed0648
1 Parent(s): b7bdd8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -13
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
 
2
  import gradio as gr
3
- from transformers import pipeline
4
  from huggingface_hub import login
5
 
6
  # Read the token from the environment variable
@@ -12,23 +13,80 @@ if HUGGINGFACE_TOKEN:
12
  else:
13
  raise ValueError("Hugging Face token not found in environment variables.")
14
 
15
- # Initialize the text generation pipeline
16
- pipe = pipeline("text-generation", model="FacebookAI/xlm-roberta-base")
 
 
17
 
18
- def generate_response(user_input):
19
- # Generate text based on the user's input
20
- response = pipe(user_input, max_length=100, num_return_sequences=1)
21
- # Extract the generated text
22
- generated_text = response[0]['generated_text']
23
- return generated_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Create the Gradio interface
26
  interface = gr.Interface(
27
- fn=generate_response,
28
- inputs=gr.Textbox(label="prompt:", lines=2, placeholder="prompt"),
29
  outputs="text",
30
- title="Gemma",
31
- description="Prompt gemma-2b"
32
  )
33
 
34
  # Launch the Gradio app
 
1
  import os
2
+ import torch
3
  import gradio as gr
4
+ from transformers import MarianMTModel, MarianTokenizer, pipeline, AutoTokenizer
5
  from huggingface_hub import login
6
 
7
  # Read the token from the environment variable
 
13
  else:
14
  raise ValueError("Hugging Face token not found in environment variables.")
15
 
16
+ # Define model and tokenizer for translation between Romanian, French, and English
17
+ rfr_md = "Helsinki-NLP/opus-mt-ro-fr"
18
+ frr_md = "Helsinki-NLP/opus-mt-fr-en"
19
+ enr_md = "Helsinki-NLP/opus-mt-en-ro"
20
 
21
+ rfr_token = MarianTokenizer.from_pretrained(rfr_md)
22
+ rfr_model = MarianMTModel.from_pretrained(rfr_md)
23
+ fren_token = MarianTokenizer.from_pretrained(frr_md)
24
+ fren_model = MarianMTModel.from_pretrained(frr_md)
25
+ enr_token = MarianTokenizer.from_pretrained(enr_md)
26
+ enr_model = MarianMTModel.from_pretrained(enr_md)
27
+
28
+ # Load the Gemma model for text generation, ensuring it runs on CPU
29
+ gemma_model = "google/gemma-2-2b-it"
30
+ gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_model)
31
+
32
+ pipe = pipeline(
33
+ "text-generation",
34
+ model=gemma_model,
35
+ tokenizer=gemma_tokenizer,
36
+ device="cpu" # Use CPU
37
+ )
38
+
39
+ # Function to split text into smaller blocks for translation
40
+ def char_split(text, tokenizer, max_length=498):
41
+ tokens = tokenizer(text, return_tensors="pt", truncation=False, padding=False)["input_ids"][0]
42
+ blocks_ = []
43
+ start = 0
44
+ while start < len(tokens):
45
+ end = min(start + max_length, len(tokens))
46
+ blocks_.append(tokens[start:end])
47
+ start = end
48
+ return blocks_
49
+
50
+ # Function to translate the text block by block
51
+ def translate(text, model, tokenizer, max_length=500):
52
+ token_blocks = char_split(text, tokenizer, max_length)
53
+ text_en = ""
54
+
55
+ for blk_ in token_blocks:
56
+ blk_char = tokenizer.decode(blk_, skip_special_tokens=True)
57
+ translated = model.generate(**tokenizer(blk_char, return_tensors="pt", padding=True, truncation=True))
58
+ text_en += tokenizer.decode(translated[0], skip_special_tokens=True) + " "
59
+ return text_en.strip()
60
+
61
+ # Function to remove formatting symbols
62
+ def rm_rf(text):
63
+ import re
64
+ return re.sub(r'\*+', '', text)
65
+
66
+ # Generate text based on Romanian input
67
+ def generate(text):
68
+ fr_txt = translate(text, rfr_model, rfr_token)
69
+ en_txt = translate(fr_txt, fren_model, fren_token)
70
+
71
+ sequences = pipe(
72
+ en_txt,
73
+ max_new_tokens=2048,
74
+ do_sample=True,
75
+ return_full_text=False,
76
+ )
77
+
78
+ generated_text = sequences[0]['generated_text']
79
+ cl_txt = rm_rf(generated_text)
80
+ ro_txt = translate(cl_txt, enr_model, enr_token)
81
+ return ro_txt
82
 
83
  # Create the Gradio interface
84
  interface = gr.Interface(
85
+ fn=generate,
86
+ inputs=gr.Textbox(label="prompt:", lines=2, placeholder="prompt..."),
87
  outputs="text",
88
+ title="Gemma Romanian",
89
+ description="romanian gemma using nlps."
90
  )
91
 
92
  # Launch the Gradio app