from pydantic import BaseModel
from llama_cpp import Llama
import os
import tempfile
import gradio as gr  # Not suitable for production
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import spaces
import asyncio
import random
from peft import PeftModel, LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from multiprocessing import Process, Queue
from google.cloud import storage
import json

app = FastAPI()
load_dotenv()

HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
GOOGLE_CLOUD_BUCKET = os.getenv("GOOGLE_CLOUD_BUCKET")
GOOGLE_CLOUD_CREDENTIALS = os.getenv("GOOGLE_CLOUD_CREDENTIALS")

gcp_credentials = json.loads(GOOGLE_CLOUD_CREDENTIALS)
storage_client = storage.Client.from_service_account_info(gcp_credentials)
bucket = storage_client.bucket(GOOGLE_CLOUD_BUCKET)

MODEL_NAMES = {
    "starcoder": "starcoder2-3b-q2_k.gguf",
    "gemma_2b_it": "gemma-2-2b-it-q2_k.gguf",
    "llama_3_2_1b": "Llama-3.2-1B.Q2_K.gguf",
    "gemma_2b_imat": "gemma-2-2b-iq1_s-imat.gguf",
    "phi_3_mini": "phi-3-mini-128k-instruct-iq2_xxs-imat.gguf",
    "qwen2_0_5b": "qwen2-0.5b-iq1_s-imat.gguf",
}


class ModelManager:
    def __init__(self):
        # Mixed load-time and sampling defaults; only the relevant subset is passed to each call.
        self.params = {
            "n_ctx": 2048,
            "n_batch": 512,
            "n_predict": 512,
            "repeat_penalty": 1.1,
            "n_threads": 1,
            "seed": -1,
            "stop": [""],
            "tokens": [],
        }
        self.request_queue = Queue()
        self.response_queue = Queue()
        self.models = {}  # Dictionary to hold multiple models
        self.load_models()
        self.start_processing_processes()

    def load_model_from_bucket(self, bucket_path):
        # Llama expects a local file path, so download the GGUF blob to disk first.
        blob = bucket.blob(bucket_path)
        local_path = os.path.join(tempfile.gettempdir(), os.path.basename(bucket_path))
        try:
            blob.download_to_filename(local_path)
            model = Llama(
                model_path=local_path,
                n_ctx=self.params["n_ctx"],
                n_batch=self.params["n_batch"],
                n_threads=self.params["n_threads"],
                seed=self.params["seed"],
            )
            return model
        except Exception as e:
            print(f"Error loading model: {e}")
            return None

    def load_models(self):
        for name, path in MODEL_NAMES.items():
            model = self.load_model_from_bucket(path)
            if model:
                self.models[name] = model

    def save_model_to_bucket(self, model, bucket_prefix):
        # save_pretrained writes adapter files to a local directory; upload each file under the prefix.
        try:
            local_dir = os.path.join(tempfile.gettempdir(), "finetuned_model")
            model.save_pretrained(local_dir)
            for file_name in os.listdir(local_dir):
                blob = bucket.blob(f"{bucket_prefix}/{file_name}")
                blob.upload_from_filename(os.path.join(local_dir, file_name))
        except Exception as e:
            print(f"Error saving model: {e}")

    def train_model(self):
        # This function needs a complete overhaul for production use. This is a placeholder.
        # NOTE: PEFT/LoRA needs a PyTorch (transformers) base model; it cannot be applied to a
        # GGUF file loaded through llama.cpp, so a Hugging Face model id is used here instead.
        config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        base_model_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder id matching the llama-2-7b-chat checkpoint
        try:
            tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=HUGGINGFACE_TOKEN)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            base_model = AutoModelForCausalLM.from_pretrained(base_model_id, token=HUGGINGFACE_TOKEN)
            model = get_peft_model(base_model, config)
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
            # Placeholder training data - needs a robust data loading mechanism.
            for batch in [{"question": ["a"], "answer": ["b"]}, {"question": ["c"], "answer": ["d"]}]:
                # For causal LM training the labels must align with the input ids, so question
                # and answer are concatenated into a single sequence.
                texts = [q + " " + a for q, a in zip(batch["question"], batch["answer"])]
                inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
                outputs = model(**inputs, labels=inputs.input_ids)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            self.save_model_to_bucket(model, "llama_finetuned")
            del model
            del base_model
        except Exception as e:
            print(f"Error during training: {e}")

    def generate_text(self, prompt, model_name, max_tokens=None, top_p=0.9, top_k=50, temperature=0.7):
        if model_name not in self.models:
            return "Error: Model not found."
        model = self.models[model_name]
        # llama.cpp models tokenize and generate in a single call; no separate tokenizer is needed.
        output = model(
            prompt,
            max_tokens=max_tokens or self.params["n_predict"],
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repeat_penalty=self.params["repeat_penalty"],
        )
        return output["choices"][0]["text"]

    def start_processing_processes(self):
        p = Process(target=self.process_requests)
        p.start()

    def process_requests(self):
        while True:
            request_data = self.request_queue.get()
            if request_data is None:
                break
            inputs, model_name, top_p, top_k, temperature, max_tokens = request_data
            try:
                response = self.generate_text(
                    inputs, model_name, max_tokens=max_tokens, top_p=top_p, top_k=top_k, temperature=temperature
                )
                self.response_queue.put(response)
            except Exception as e:
                print(f"Error during inference: {e}")
                self.response_queue.put("Error generating text.")


model_manager = ModelManager()


class ChatRequest(BaseModel):
    message: str
    model_name: str


async def process_message(message, model_name):
    # Shared by the REST endpoint and the Gradio UI; returns the generated text as a string.
    inputs = message.strip()
    top_p = 0.9
    top_k = 50
    temperature = 0.7
    model = model_manager.models.get(model_name)
    prompt_tokens = len(model.tokenize(inputs.encode("utf-8"))) if model else 0
    max_tokens = model_manager.params["n_ctx"] - prompt_tokens
    model_manager.request_queue.put((inputs, model_name, top_p, top_k, temperature, max_tokens))
    # Blocking queue get; acceptable for this demo but not for concurrent production traffic.
    return model_manager.response_queue.get()


@spaces.GPU()
async def generate_streaming_response(inputs, model_name):
    full_text = await process_message(inputs, model_name)

    async def stream_response():
        yield full_text

    return StreamingResponse(stream_response(), media_type="text/plain")


@app.post("/generate_multimodel")
async def api_generate_multimodel(request: Request):
    data = await request.json()
    message = data["message"]
    model_name = data.get("model_name", list(MODEL_NAMES.keys())[0])
    if model_name not in MODEL_NAMES:
        return {"error": "Invalid model name"}
    return await generate_streaming_response(message, model_name)


# Gradio front-end; as noted above, not suitable for production.
iface = gr.Interface(
    fn=process_message,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your message here..."),
        gr.Dropdown(list(MODEL_NAMES.keys()), label="Select Model"),
    ],
    outputs=gr.Markdown(),
    title="Unified Multi-Model API",
    description="Enter a message to get responses from a unified model.",
)

if __name__ == "__main__":
    # Launching only the Gradio interface does not serve the FastAPI routes above; mount the
    # interface onto `app` (e.g. gr.mount_gradio_app) and run uvicorn to expose both.
    iface.launch()
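
# Example call against the /generate_multimodel route, assuming the FastAPI app is actually
# being served (e.g. `uvicorn app:app --port 8000`; the module name, host, and port here are
# illustrative assumptions, not part of this file):
#
#   curl -X POST http://localhost:8000/generate_multimodel \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Write a haiku about GPUs", "model_name": "qwen2_0_5b"}'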