import requests
import streamlit as st
import torch
from transformers import pipeline


def get_story(image_path):
    # Let the user pick which model turns the caption into a story.
    model_name = st.selectbox('Select the Model', ['alpaca-lora', 'flan-t5-base'])

    # Caption the image with the ViT-GPT2 image-captioning pipeline.
    image_to_text = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
    caption = image_to_text(image_path)
    caption = caption[0]['generated_text']
    st.write(f"Generated Caption: {caption}")

    # Prompt asking the selected model to turn the caption into a short story.
    input_string = f"""Question: Generate a 100-word story from this text
'{caption}' Answer:"""
    if model_name == 'flan-t5-base':
        from transformers import T5ForConditionalGeneration, AutoTokenizer

        # Load FLAN-T5 in 8-bit (requires bitsandbytes) and let Accelerate place the weights.
        model = T5ForConditionalGeneration.from_pretrained(
            "google/flan-t5-base", device_map="auto", load_in_8bit=True
        )
        tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

        # Tokenize the prompt and move it to the device the model was placed on.
        inputs = tokenizer(input_string, return_tensors="pt").input_ids.to(model.device)
        outputs = model.generate(inputs, max_length=1000)
        outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
    else:
        # Query the hosted Alpaca-LoRA Gradio Space instead of loading a model locally.
        response = requests.post(
            "https://tloen-alpaca-lora.hf.space/run/predict",
            json={
                "data": [
                    "Write a story about this image caption",  # instruction
                    caption,                                    # input text
                    0.1,                                        # temperature
                    0.75,                                       # top_p
                    40,                                         # top_k
                    4,                                          # beams
                    128,                                        # max tokens
                ]
            },
        ).json()

        outputs = response["data"][0]

    return outputs
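
# Example wiring for the Streamlit app (a minimal sketch; the uploader widget and the
# temporary file name below are assumptions, not part of the original code).
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
if uploaded_image is not None:
    # Persist the upload so the captioning pipeline can read it from a file path.
    temp_path = "uploaded_image.jpg"
    with open(temp_path, "wb") as f:
        f.write(uploaded_image.getbuffer())
    st.image(temp_path)
    st.write(get_story(temp_path))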