File size: 2,689 Bytes
9d8add7
 
 
 
 
 
 
 
 
 
265c07a
9d8add7
 
 
 
 
 
 
 
 
 
 
 
 
3253a7e
9d8add7
 
 
 
 
 
 
 
 
3253a7e
 
 
 
 
9d8add7
 
 
3253a7e
9d8add7
 
 
 
 
 
 
3253a7e
 
 
9d8add7
 
 
3253a7e
9d8add7
265c07a
9d8add7
 
265c07a
9d8add7
265c07a
9d8add7
3253a7e
 
 
 
 
9d8add7
265c07a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import requests
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.prompts import (
    PromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    ChatPromptTemplate,
)
from config import app_config
import mongo_utils as mongo


def __image2text(image, *, timeout=60):
    """Generate a short text description (caption) of an image.

    Posts the raw image bytes to ``app_config.I2T_API_URL`` and extracts
    ``generated_text`` from the first element of the JSON response.

    Args:
        image: Raw image content, sent as the request body.
        timeout: Seconds to wait for the endpoint before giving up
            (keyword-only; without it a stalled endpoint blocks forever).

    Returns:
        The generated caption string, or ``""`` if the request fails or the
        response does not have the expected ``[{"generated_text": ...}]``
        shape. The error is printed in that case.
    """
    # NOTE(review): the HF inference API expects "Bearer <token>"; confirm
    # that app_config.HF_TOKEN already includes the "Bearer " prefix.
    headers = {"Authorization": app_config.HF_TOKEN}
    try:
        response = requests.post(
            app_config.I2T_API_URL, headers=headers, data=image, timeout=timeout
        )
        # Turn HTTP errors (e.g. model still loading) into an exception instead
        # of trying to index into an error payload.
        response.raise_for_status()
        return response.json()[0]["generated_text"]
    except (
        requests.RequestException,  # network errors, timeouts, HTTP status
        ValueError,  # body is not valid JSON
        KeyError,  # JSON lacks "generated_text" / is an error dict
        IndexError,  # empty JSON list
        TypeError,  # JSON has an unexpected structure
    ) as e:
        # Best-effort: report and fall back to an empty caption. Previously
        # this path returned an unbound name or the raw Response object.
        print(e)
        return ""


def __text2story(image_desc, genre, style, word_count, creativity):
    """Generate a short story using an image description as context.

    Args:
        image_desc: Text describing the image; sent as the "story-context".
        genre: Genre the story should be written in.
        style: Writing style the story should use.
        word_count: Maximum story length (in words) requested in the prompt.
        creativity: Sampling temperature for the chat model.

    Returns:
        The text returned by running the LLM chain.
    """
    chat_model = ChatOpenAI(
        model="gpt-3.5-turbo",
        openai_api_key=app_config.OPENAI_KEY,
        temperature=creativity,
    )

    # System instructions. This text is spelled out explicitly so that it
    # stays byte-identical to the prompt the model has always received,
    # including the embedded newlines and indentation.
    system_template = (
        "You are an expert story writer, write a maximum of {word_count} \n"
        "        words long story in {genre} genre in {style} writing style, "
        "based on the user \n"
        "        provided story-context.\n"
        "        "
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate(
                prompt=PromptTemplate(
                    template=system_template,
                    input_variables=["word_count", "genre", "style"],
                )
            ),
            HumanMessagePromptTemplate(
                prompt=PromptTemplate(
                    template="story-context: {context}",
                    input_variables=["context"],
                )
            ),
        ]
    )

    chain = LLMChain(llm=chat_model, prompt=prompt)
    return chain.run(
        genre=genre, style=style, word_count=word_count, context=image_desc
    )


def generate_story(image_file, genre, style, word_count, creativity):
    """Turn an image file into a short story and report OpenAI usage counters.

    The image is captioned first. The caption becomes the context for the
    story model. Afterwards the OpenAI access counter is incremented through
    ``mongo_utils``.

    Args:
        image_file: Path of the image to read (opened in binary mode).
        genre: Story genre forwarded to the story prompt.
        style: Writing style forwarded to the story prompt.
        word_count: Maximum story length forwarded to the story prompt.
        creativity: Sampling temperature for the chat model.

    Returns:
        A 4-tuple ``(story, max_count, curr_count, available_count)``, where
        ``available_count`` is ``max_count - curr_count``.
    """
    # Read the image as raw bytes.
    with open(image_file, "rb") as image_fh:
        image_bytes = image_fh.read()

    caption = __image2text(image=image_bytes)
    story_text = __text2story(
        image_desc=caption,
        genre=genre,
        style=style,
        word_count=word_count,
        creativity=creativity,
    )

    # Record this OpenAI call, then read the counters back from app_config.
    # NOTE(review): assumes increment_curr_access_count refreshes
    # app_config.openai_curr_access_count — confirm in mongo_utils.
    mongo.increment_curr_access_count()
    limit = app_config.openai_max_access_count
    used = app_config.openai_curr_access_count
    return story_text, limit, used, limit - used