darshankr committed · verified
Commit 0b7c166 · 1 Parent(s): df8f230

Update app.py

Files changed (1)
  1. app.py +42 -73
app.py CHANGED
@@ -1,71 +1,46 @@
 # app.py
 import streamlit as st
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from typing import List
 import torch
+import asyncio
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from IndicTransToolkit import IndicProcessor
-from typing import List
-import sys
-from starlette.applications import Starlette
-from starlette.routing import Mount
-from starlette.staticfiles import StaticFiles
-import nest_asyncio
-from api import app
-
-# Enable nested event loops
-nest_asyncio.apply()
+import requests
+import json
 
-# Initialize models and processors (lazy loading)
-@st.cache_resource
-def load_models():
-    model = AutoModelForSeq2SeqLM.from_pretrained(
-        "ai4bharat/indictrans2-en-indic-1B",
-        trust_remote_code=True
-    )
-    tokenizer = AutoTokenizer.from_pretrained(
-        "ai4bharat/indictrans2-en-indic-1B",
-        trust_remote_code=True
-    )
-    ip = IndicProcessor(inference=True)
-    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-    model = model.to(DEVICE)
-    return model, tokenizer, ip, DEVICE
+# Initialize models and processors
+model = AutoModelForSeq2SeqLM.from_pretrained("ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True)
+ip = IndicProcessor(inference=True)
 
-# Global variables for models
-model, tokenizer, ip, DEVICE = load_models()
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(DEVICE)
 
 def translate_text(sentences: List[str], target_lang: str):
     try:
         src_lang = "eng_Latn"
-        batch = ip.preprocess_batch(
-            sentences,
-            src_lang=src_lang,
-            tgt_lang=target_lang
-        )
-        inputs = tokenizer(
-            batch,
-            truncation=True,
-            padding="longest",
-            return_tensors="pt",
-            return_attention_mask=True
-        ).to(DEVICE)
-
+        batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)
+        inputs = tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(DEVICE)
+
         with torch.no_grad():
             generated_tokens = model.generate(
-                **inputs,
+                inputs,
                 use_cache=True,
                 min_length=0,
                 max_length=256,
                 num_beams=5,
                 num_return_sequences=1
             )
-
+
         with tokenizer.as_target_tokenizer():
             generated_tokens = tokenizer.batch_decode(
                 generated_tokens.detach().cpu().tolist(),
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=True
             )
-
+
         translations = ip.postprocess_batch(generated_tokens, lang=target_lang)
         return {
             "translations": translations,
@@ -75,12 +50,13 @@ def translate_text(sentences: List[str], target_lang: str):
     except Exception as e:
         raise Exception(f"Translation failed: {str(e)}")
 
-def streamlit_app():
+# Streamlit interface
+def main():
     st.title("Indic Language Translator")
-
+
     # Input text
     text_input = st.text_area("Enter text to translate:", "Hello, how are you?")
-
+
     # Language selection
     target_languages = {
         "Hindi": "hin_Deva",
@@ -95,17 +71,13 @@ def streamlit_app():
         "Odia": "ori_Orya"
     }
 
-    target_lang = st.selectbox(
-        "Select target language:",
-        options=list(target_languages.keys())
-    )
-
+    target_lang = st.selectbox("Select target language:", options=list(target_languages.keys()))
+
     if st.button("Translate"):
         try:
-            result = translate_text(
-                sentences=[text_input],
-                target_lang=target_languages[target_lang]
-            )
+            result = translate_text(sentences=[text_input], target_lang=target_languages[target_lang])
+
+            # Display result
             st.success("Translation:")
             st.write(result["translations"][0])
         except Exception as e:
@@ -116,9 +88,8 @@ def streamlit_app():
     st.header("API Documentation")
     st.markdown("""
     To use the translation API, send POST requests to:
-    ```
-    https://YOUR-SPACE-NAME.hf.space/api/translate
-    ```
+    https://USERNAME-SPACE_NAME.hf.space/translate
+
     Request body format:
     ```json
     {
@@ -126,21 +97,19 @@ def streamlit_app():
         "target_lang": "hin_Deva"
     }
     ```
-    """)
-    st.markdown("Available target languages:")
-    for lang, code in target_languages.items():
-        st.markdown(f"- {lang}: `{code}`")
 
-def create_app():
-    routes = [
-        Mount("/api", app),
-        Mount("/", StaticFiles(directory="static", html=True), name="static"),
-    ]
-    return Starlette(routes=routes)
+    Available target languages:
+    - Hindi: hin_Deva
+    - Bengali: ben_Beng
+    - Tamil: tam_Taml
+    - Telugu: tel_Telu
+    - Marathi: mar_Deva
+    - Gujarati: guj_Gujr
+    - Kannada: kan_Knda
+    - Malayalam: mal_Mlym
+    - Punjabi: pan_Guru
+    - Odia: ori_Orya
+    """)
 
 if __name__ == "__main__":
-    if "streamlit" in sys.argv[0]:
-        streamlit_app()
-    else:
-        import uvicorn
-        uvicorn.run(create_app(), host="0.0.0.0", port=7860)
+    main()
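
The translate path in the new app.py runs preprocess → tokenize → generate → decode → postprocess, spread across several hunks above. Below is a minimal, self-contained sketch of that round trip using the same model IDs as the diff; it keeps the pre-change convention of unpacking the tokenizer output with `**inputs` (the usual way to hand `input_ids`/`attention_mask` to `generate`). The example sentence and target language are illustrative, not part of the commit.

```python
# Illustrative sketch only, not the committed code.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor

# Same setup as the new app.py (module-level, no caching).
model = AutoModelForSeq2SeqLM.from_pretrained("ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True)
ip = IndicProcessor(inference=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(DEVICE)

sentences = ["Hello, how are you?"]
src_lang, tgt_lang = "eng_Latn", "hin_Deva"

# Normalize the input and add the language tags IndicTrans2 expects.
batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)

# Tokenize to tensors on the model's device.
inputs = tokenizer(
    batch,
    truncation=True,
    padding="longest",
    return_tensors="pt",
    return_attention_mask=True,
).to(DEVICE)

# Beam-search generation; **inputs unpacks input_ids / attention_mask
# as keyword arguments, as the pre-change code did.
with torch.no_grad():
    generated_tokens = model.generate(
        **inputs,
        use_cache=True,
        min_length=0,
        max_length=256,
        num_beams=5,
        num_return_sequences=1,
    )

# Decode with the target-side tokenizer, then restore the native script.
with tokenizer.as_target_tokenizer():
    decoded = tokenizer.batch_decode(
        generated_tokens.detach().cpu().tolist(),
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

print(ip.postprocess_batch(decoded, lang=tgt_lang)[0])
```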
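
The in-app markdown documents a JSON POST API. A hedged client example follows: the host is the placeholder from the docstring, the `/translate` route itself is not part of this diff, and the `"sentences"` field name is assumed from `translate_text`'s signature (the first key of the documented request body falls outside the shown hunks).

```python
# Hypothetical client call for the documented endpoint; adjust the URL to the
# real Space host. The "sentences" key is an assumption, not shown in the diff.
import requests

payload = {
    "sentences": ["Hello, how are you?"],  # assumed field name
    "target_lang": "hin_Deva",             # language code from the documented body
}

response = requests.post(
    "https://USERNAME-SPACE_NAME.hf.space/translate",  # placeholder from the docstring
    json=payload,
    timeout=120,
)
response.raise_for_status()
print(response.json())  # expected to contain a "translations" list
```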