fargerm commited on
Commit
cce785c
1 Parent(s): e1452a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -1
app.py CHANGED
@@ -15,6 +15,20 @@ MODELS = {
15
  'Bengali': "Helsinki-NLP/opus-mt-en-bn",
16
  }
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def translate_text(text, target_lang):
19
  # Load the appropriate model and tokenizer for the target language
20
  model_name = MODELS.get(target_lang)
@@ -26,8 +40,15 @@ def translate_text(text, target_lang):
26
 
27
  # Encode the text and prepare it for translation
28
  encoded_text = tokenizer(text, return_tensors="pt")
 
 
 
 
 
 
29
  # Translate text
30
- translated = model.generate(**encoded_text, forced_bos_token_id=tokenizer.get_lang_id(target_lang))
 
31
  # Decode the translated text
32
  translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
33
  return translated_text
@@ -47,3 +68,4 @@ if st.button('Translate'):
47
  else:
48
  st.error("Please enter text to translate.")
49
 
 
 
15
  'Bengali': "Helsinki-NLP/opus-mt-en-bn",
16
  }
17
 
18
+ # Manually defined language codes for different language models
19
+ LANG_CODE_MAP = {
20
+ 'French': 'fr',
21
+ 'Spanish': 'es',
22
+ 'German': 'de',
23
+ 'Chinese': 'zh',
24
+ 'Russian': 'ru',
25
+ 'Japanese': 'ja',
26
+ 'Arabic': 'ar',
27
+ 'Urdu': 'ur',
28
+ 'Hindi': 'hi',
29
+ 'Bengali': 'bn',
30
+ }
31
+
32
  def translate_text(text, target_lang):
33
  # Load the appropriate model and tokenizer for the target language
34
  model_name = MODELS.get(target_lang)
 
40
 
41
  # Encode the text and prepare it for translation
42
  encoded_text = tokenizer(text, return_tensors="pt")
43
+
44
+ # Get the language code for forced_bos_token_id
45
+ lang_code = LANG_CODE_MAP.get(target_lang)
46
+ if not lang_code:
47
+ return "Error: Language code not found."
48
+
49
  # Translate text
50
+ translated = model.generate(**encoded_text, forced_bos_token_id=tokenizer.lang_code_to_id.get(lang_code))
51
+
52
  # Decode the translated text
53
  translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
54
  return translated_text
 
68
  else:
69
  st.error("Please enter text to translate.")
70
 
71
+