import os

from transformers import pipeline
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from PIL import Image


class StoryGenerator:
    """Turn an image into a short children's story.

    An ``image-to-text`` pipeline captions the image. The caption is then used
    as the ``{scenario}`` of a prompt sent to a Hugging Face endpoint text model.
    """

    def __init__(self, image_model="Salesforce/blip-image-captioning-base"):
        """Create the captioning pipeline and register the selectable text models.

        Args:
            image_model: Hugging Face model id for the ``image-to-text`` pipeline.
        """
        self.image_model = image_model
        # Built eagerly, so constructing the class may download/load the model.
        self.image_to_text = pipeline("image-to-text", model=self.image_model)
        # Display name -> Hugging Face repo id; the keys are what get_llm() accepts.
        self.text_models = {
            "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
            "FLAN-T5": "google/flan-t5-large",
            "MPT-7B": "mosaicml/mpt-7b-instruct",
            "Falcon-7B": "tiiuae/falcon-7b-instruct",
        }
        self.prompt_template = PromptTemplate.from_template("""
You are a kids story writer.
Provide a coherent story for kids using this simple instruction: {scenario}.
The story should have a clear beginning, middle, and end.
The story should be interesting and engaging for kids.
The story should be maximum 200 words long.
Do not include any adult or polemic content.
Story:
""")

    def get_llm(self, model_name):
        """Build a streaming ``HuggingFaceEndpoint`` for a registered model.

        Args:
            model_name: a key of ``self.text_models`` (e.g. ``"Mistral-7B"``).

        Returns:
            A ``HuggingFaceEndpoint`` configured with temperature 0.5 and streaming.

        Raises:
            KeyError: if ``model_name`` is not a key of ``self.text_models``.
        """
        try:
            repo_id = self.text_models[model_name]
        except KeyError:
            # Keep KeyError for existing callers, but say which names are valid.
            known = ", ".join(self.text_models)
            raise KeyError(
                f"Unknown model {model_name!r}; expected one of: {known}"
            ) from None
        return HuggingFaceEndpoint(
            repo_id=repo_id,
            temperature=0.5,
            streaming=True,
        )

    def img2txt(self, image_path):
        """Caption an image with the image-to-text pipeline.

        Args:
            image_path: a local file path or URL string, or a ``PIL.Image.Image``.
                The transformers image-to-text pipeline accepts all of these.

        Returns:
            The ``generated_text`` of the first pipeline result.
        """
        text = self.image_to_text(image_path)[0]["generated_text"]
        print(f"Image caption: {text}")
        return text

    def generate_story(self, scenario, model_name):
        """Generate a story from a text scenario with the chosen text model.

        Args:
            scenario: text substituted for ``{scenario}`` in the prompt.
            model_name: a key of ``self.text_models``.

        Returns:
            The model output with surrounding whitespace stripped.

        Raises:
            KeyError: if ``model_name`` is unknown (see ``get_llm``).
        """
        llm = self.get_llm(model_name)
        chain = self.prompt_template | llm
        generated_story = chain.invoke(input={"scenario": scenario})
        # NOTE(review): if a model emits a trailing end-of-sequence token such as
        # "</s>", remove it with str.removesuffix(); str.rstrip() treats its
        # argument as a character set and would also strip trailing letters.
        return generated_story.strip()

    def generate_story_from_image(self, image, model_name):
        """Caption an image and turn the caption into a story.

        Args:
            image: a file path/URL string or a ``PIL.Image.Image``. Either form is
                handed straight to the captioning pipeline. This avoids writing a
                shared temporary file and avoids JPEG-saving failures for
                RGBA/palette images.
            model_name: a key of ``self.text_models``.

        Returns:
            The generated story text.

        Raises:
            KeyError: if ``model_name`` is unknown (see ``get_llm``).
        """
        print(f"Received image: {image}")
        print(f"Image type: {type(image)}")
        scenario = self.img2txt(image)
        return self.generate_story(scenario, model_name)


# Example usage
if __name__ == "__main__":
    example_image_path = os.path.join("assets", "image.jpg")
    # Check for the example image before loading the captioning model.
    if os.path.exists(example_image_path):
        generator = StoryGenerator()
        story = generator.generate_story_from_image(example_image_path, "Mistral-7B")
        print(story)
    else:
        print(f"Example image not found at {example_image_path}")