import logging
from typing import List, Dict

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Serve text generation from a causal language model.

    The instance is called with a request payload dict and returns
    ``[{"generated_text": ...}]`` on success, or ``{"error": ...}`` on failure.
    NOTE(review): presumably this is a Hugging Face Inference Endpoints custom
    handler; confirm the expected payload/response contract with the deployer.
    """

    # Maximum number of prompt tokens kept after tokenization; longer prompts
    # are cut by the tokenizer (truncation side is the tokenizer's setting).
    MAX_INPUT_LENGTH = 512

    def __init__(self, path: str):
        """Load the tokenizer and model from *path*.

        Args:
            path: Model directory or identifier accepted by ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Many causal-LM tokenizers ship without a pad token, and tokenizing
        # with ``padding=True`` raises in that case. Reuse EOS as the pad token
        # so requests do not fail; this must happen before ``default_params``
        # captures ``pad_token_id`` below.
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float32,  # Use float32 for CPU
            device_map="auto",
        )

        # Default keyword arguments for ``model.generate``; per-request
        # ``parameters`` override these key by key. Note that ``max_length``
        # counts prompt tokens as well as generated ones.
        self.default_params = {
            "max_length": 1000,
            "temperature": 0.7,
            "top_p": 0.7,
            "top_k": 50,
            "repetition_penalty": 1.0,
            "do_sample": True,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

    @staticmethod
    def _build_prompt(inputs):
        """Turn the request's ``inputs`` into prompt text, or None if there are no messages.

        Accepted shapes:
            * a plain string, used verbatim;
            * a dict with a ``"messages"`` list;
            * a list of message dicts.
        Each message is rendered as ``"<role>: <content>"``, one per line.
        """
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict):
            messages = inputs.get("messages", [])
        elif isinstance(inputs, list):
            messages = inputs
        else:
            # Missing / None / unsupported inputs: nothing to build from.
            messages = []
        if not messages:
            return None
        return "\n".join(
            f'{msg.get("role", "")}: {msg.get("content", "")}' for msg in messages
        )

    def __call__(self, data: Dict):
        """Generate text for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is a string, a dict with
                ``"messages"``, or a list of ``{"role", "content"}`` dicts.
                Optional ``data["parameters"]`` overrides ``default_params``
                and is forwarded to ``model.generate``.

        Returns:
            ``[{"generated_text": str}]`` on success. For decoder-only models
            the decoded sequence starts with the prompt, so the text includes
            it. On failure, ``{"error": str}``; exceptions are logged, not raised.
        """
        try:
            input_text = self._build_prompt(data.get("inputs"))
            if input_text is None:
                return {"error": "No messages provided"}

            # Merge once; a caller-supplied pad_token_id (or any other key)
            # wins over the defaults. ``or {}`` tolerates "parameters": null.
            params = {**self.default_params, **(data.get("parameters") or {})}

            # Move the tensors to the model's device: with device_map="auto"
            # the weights may not live on the CPU where the tokenizer puts them.
            encoded = self.tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.MAX_INPUT_LENGTH,
                return_attention_mask=True,
            ).to(self.model.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    encoded["input_ids"],
                    attention_mask=encoded["attention_mask"],
                    **params,
                )

            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Always return an array as required by the endpoint
            return [{"generated_text": generated_text}]
        except Exception as e:
            # Top-level boundary: report the failure to the client instead of
            # crashing the handler, but keep the traceback in the logs.
            logger.exception("Error in generation: %s", e)
            return {"error": str(e)}

    def preprocess(self, request):
        """Prepare request for inference.

        Validates that *request* declares a JSON body and returns
        ``request.json``. Media-type parameters such as ``; charset=utf-8``
        are ignored and the comparison is case-insensitive.

        Raises:
            ValueError: if the media type is not ``application/json``.
        """
        media_type = (request.content_type or "").split(";", 1)[0].strip().lower()
        if media_type != "application/json":
            raise ValueError("Content type must be application/json")
        return request.json

    def postprocess(self, data):
        """Post-process model output (currently returned unchanged)."""
        return data