# StorySeed / model.py
import requests
import streamlit as st
from transformers import pipeline


def get_story(image_path):
    """Caption an image and expand the caption into a short story."""
    # Let the user pick which text-generation backend to use.
    model_name = st.selectbox('Select the Model', ['alpaca-lora', 'flan-t5-base'])

    # Caption the image with the ViT-GPT2 image-captioning pipeline.
    image_to_text = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
    caption = image_to_text(image_path)[0]['generated_text']
    st.write(f"Generated Caption: {caption}")

    # Prompt asking the language model to turn the caption into a story.
    input_string = f"Question: Generate 100 words story on this text '{caption}' Answer:"

    if model_name == 'flan-t5-base':
        # Run Flan-T5 locally with 8-bit weights, placed automatically on the available device.
        from transformers import T5ForConditionalGeneration, AutoTokenizer
        model = T5ForConditionalGeneration.from_pretrained(
            "google/flan-t5-base", device_map="auto", load_in_8bit=True
        )
        tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
        # Keep the input ids on the same device as the model to avoid a device mismatch.
        inputs = tokenizer(input_string, return_tensors="pt").input_ids.to(model.device)
        outputs = model.generate(inputs, max_length=1000)
        outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
    else:
        # Call the hosted alpaca-lora Gradio Space instead of running a model locally.
        response = requests.post(
            "https://tloen-alpaca-lora.hf.space/run/predict",
            json={
                "data": [
                    "Write a story about this image caption",  # instruction
                    caption,                                   # input
                    0.1,                                       # temperature
                    0.75,                                      # top_p
                    40,                                        # top_k
                    4,                                         # num_beams
                    128,                                       # max new tokens
                ]
            },
        ).json()
        outputs = response["data"][0]

    return outputs
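

# Usage sketch (assumption, not part of the original file): one way get_story
# could be wired into a Streamlit page. The uploader widget and temp-file
# handling below are illustrative; the repo's actual app code is not shown here.
if __name__ == "__main__":
    uploaded = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
    if uploaded is not None:
        import tempfile

        # The captioning pipeline accepts a local file path, so persist the
        # uploaded bytes to a temporary file and pass its path through.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            tmp.write(uploaded.getbuffer())
            story = get_story(tmp.name)
        st.write(story)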