import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Model identifier on the Hugging Face Hub
model_name = "ai4bharat/Airavata"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Create a BitsAndBytesConfig for 8-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,  # Load weights in 8-bit to reduce memory usage
    # Optionally, specify more quantization parameters as needed
)

# Load the model with quantization. Passing device_map="auto" lets
# Accelerate place the weights across the available devices at load
# time; a manual infer_auto_device_map() + model.to(device_map) call
# would fail here, since .to() accepts a device, not a device map,
# and 8-bit quantized models cannot be moved after loading anyway.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

# Define the inference function
def generate_text(prompt):
    # Move the input tensors to the same device as the model
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Create the Gradio interface
interface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="Airavata Text Generation Model",
    description="This is the AI4Bharat Airavata model for text generation in Indic languages.",
)

# Launch the interface
interface.launch()