# app.py
"""English -> Indic translation service.

Loads the ``ai4bharat/indictrans2-en-indic-1B`` seq2seq model at import time
and exposes it through a FastAPI app (``/health`` and ``/translate``).
Input/output text is run through IndicTransToolkit's ``IndicProcessor``.
"""
import json
from typing import List

import streamlit as st
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from IndicTransToolkit import IndicProcessor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hugging Face model id shared by the model and its tokenizer.
MODEL_NAME = "ai4bharat/indictrans2-en-indic-1B"
# IndicTrans2 code for the (fixed) source language of this service.
SRC_LANG = "eng_Latn"

# Initialize FastAPI
app = FastAPI()

# Add CORS middleware (every origin, method and header is allowed).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize models and processors. This runs at import time, so importing
# this module loads (and, on first use, downloads) the model weights.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
ip = IndicProcessor(inference=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(DEVICE)


def translate_text(sentences: List[str], target_lang: str):
    """Translate a batch of English sentences into ``target_lang``.

    Args:
        sentences: English input sentences, translated together as one batch.
        target_lang: IndicTrans2 language code of the output, e.g. "hin_Deva".

    Returns:
        A dict with keys "translations" (the post-processed model outputs),
        "source_language" (always "eng_Latn") and "target_language".
        An empty ``sentences`` list yields an empty "translations" list.

    Raises:
        Exception: "Translation failed: ..." wrapping any error raised during
            preprocessing, tokenization, generation or postprocessing. The
            original error is chained as ``__cause__``.
    """
    src_lang = SRC_LANG
    if not sentences:
        # Nothing to translate. Skip the processor/tokenizer/model, which are
        # not guaranteed to accept an empty batch.
        return {
            "translations": [],
            "source_language": src_lang,
            "target_language": target_lang,
        }
    try:
        batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode with the target-side vocabulary.
        with tokenizer.as_target_tokenizer():
            generated_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        translations = ip.postprocess_batch(generated_tokens, lang=target_lang)
        return {
            "translations": translations,
            "source_language": src_lang,
            "target_language": target_lang,
        }
    except Exception as e:
        raise Exception(f"Translation failed: {str(e)}") from e


# FastAPI routes
@app.get("/health")
async def health_check():
    """Liveness probe; always reports ``{"status": "healthy"}``."""
    return {"status": "healthy"}


@app.post("/translate")
async def translate_endpoint(sentences: List[str], target_lang: str):
    """Translate ``sentences`` into ``target_lang`` and return the result dict.

    Any failure is reported as HTTP 500 with the error text as ``detail``.
    """
    # NOTE(review): FastAPI reads a bare ``List[str]`` parameter from the
    # request body (a JSON array) and a plain ``str`` from the query string,
    # so requests look like ``POST /translate?target_lang=hin_Deva`` with body
    # ``["Hello"]``. That differs from the ``{"sentences": ..., "target_lang":
    # ...}`` object in the commented-out API docs below. Confirm which
    # contract clients rely on before changing either.
    # NOTE(review): translate_text is synchronous and model-bound. Called from
    # an async handler with no await, it blocks the event loop for the whole
    # request, so requests (including /health) are served one at a time.
    try:
        return translate_text(sentences=sentences, target_lang=target_lang)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# # Streamlit interface
# def main():
#     st.title("Indic Language Translator")
#
#     # Input text
#     text_input = st.text_area("Enter text to translate:", "Hello, how are you?")
#
#     # Language selection
#     target_languages = {
#         "Hindi": "hin_Deva",
#         "Bengali": "ben_Beng",
#         "Tamil": "tam_Taml",
#         "Telugu": "tel_Telu",
#         "Marathi": "mar_Deva",
#         "Gujarati": "guj_Gujr",
#         "Kannada": "kan_Knda",
#         "Malayalam": "mal_Mlym",
#         "Punjabi": "pan_Guru",
#         "Odia": "ori_Orya",
#     }
#
#     target_lang = st.selectbox(
#         "Select target language:", options=list(target_languages.keys())
#     )
#
#     if st.button("Translate"):
#         try:
#             result = translate_text(
#                 sentences=[text_input], target_lang=target_languages[target_lang]
#             )
#             st.success("Translation:")
#             st.write(result["translations"][0])
#         except Exception as e:
#             st.error(f"Translation failed: {str(e)}")
#
#     # Add API documentation
#     st.markdown("---")
#     st.header("API Documentation")
#     st.markdown(
#         """
#     To use the translation API, send POST requests to:
#     ```
#     https://darshankr-trans-en-indic.hf.space/translate
#     ```
#     Request body format:
#     ```json
#     {
#         "sentences": ["Your text here"],
#         "target_lang": "hin_Deva"
#     }
#     ```
#     """
#     )
#
#     st.markdown("Available target languages:")
#     for lang, code in target_languages.items():
#         st.markdown(f"- {lang}: `{code}`")
#
#
# if __name__ == "__main__":
#     # Run both Streamlit and FastAPI
#     import threading
#
#     def run_fastapi():
#         # Serve the FastAPI instance defined above (`app`).
#         uvicorn.run(app, host="0.0.0.0", port=8000)
#
#     # Start FastAPI in a separate thread
#     api_thread = threading.Thread(target=run_fastapi)
#     api_thread.start()
#
#     # Run Streamlit
#     main()