DrishtiSharma
committed on
Commit
•
765a4ee
1
Parent(s):
3ebdd6b
Upload 6 files
Browse files- few_shot.py +44 -0
- llm_helper.py +16 -0
- main.py +42 -0
- post_generator.py +52 -0
- preprocess.py +85 -0
- requirements.txt +6 -0
few_shot.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import json
|
3 |
+
|
4 |
+
|
5 |
+
class FewShotPosts:
    """In-memory store of pre-processed posts used as few-shot examples.

    The JSON file must hold a list of post objects. This class reads the
    ``line_count``, ``language`` and ``tags`` fields of every post.
    """

    def __init__(self, file_path="data/processed_posts.json"):
        """Load the posts stored at *file_path* and collect their tags."""
        self.df = None
        self.unique_tags = None
        self.load_posts(file_path)

    def load_posts(self, file_path):
        """Read the JSON file into ``self.df`` and rebuild ``self.unique_tags``.

        A derived ``length`` column ("Short" / "Medium" / "Long") is computed
        from each post's ``line_count``.
        """
        with open(file_path, encoding="utf-8") as f:
            posts = json.load(f)
        self.df = pd.json_normalize(posts)
        self.df['length'] = self.df['line_count'].apply(self.categorize_length)
        # Collect the distinct tags across all posts. Sorting keeps the order
        # stable between runs (set iteration order is not), and the set
        # comprehension avoids concatenating every tag list via ``sum``.
        self.unique_tags = sorted(
            {tag for tags in self.df['tags'] for tag in tags}
        )

    def get_filtered_posts(self, length, language, tag):
        """Return the posts matching all three criteria as a list of dicts.

        Args:
            length: Length category to match against the ``length`` column.
            language: Value to match against the ``language`` column.
            tag: Tag that must be present in the post's ``tags`` list.
        """
        has_tag = self.df['tags'].apply(lambda tags: tag in tags)
        matches = self.df[
            has_tag
            & (self.df['language'] == language)
            & (self.df['length'] == length)
        ]
        return matches.to_dict(orient='records')

    def categorize_length(self, line_count):
        """Map a line count to "Short" (< 5), "Medium" (5-10) or "Long" (> 10)."""
        if line_count < 5:
            return "Short"
        if line_count <= 10:
            return "Medium"
        return "Long"

    def get_tags(self):
        """Return the list of distinct tags found in the loaded posts."""
        return self.unique_tags
|
38 |
+
|
39 |
+
|
40 |
+
if __name__ == "__main__":
    # Manual smoke test: print the stored posts for one filter combination.
    store = FewShotPosts()
    print(store.get_filtered_posts("Medium", "Hinglish", "Job Search"))
|
llm_helper.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_groq import ChatGroq
|
2 |
+
import os
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
|
5 |
+
# Load variables from a .env file into the environment before they are read.
load_dotenv()

# Shared Groq chat client imported by the other modules. The model can be
# overridden through the GROQ_MODEL environment variable so a retired preview
# model does not require a code change; the default is unchanged.
llm = ChatGroq(
    groq_api_key=os.getenv("GROQ_API_KEY"),
    model_name=os.getenv("GROQ_MODEL", "llama-3.2-90b-text-preview"),
)


if __name__ == "__main__":
    # Manual smoke test: send one prompt and print the reply text.
    response = llm.invoke("Two most important ingradient in samosa are ")
    print(response.content)
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
main.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from few_shot import FewShotPosts
|
3 |
+
from post_generator import generate_post
|
4 |
+
|
5 |
+
|
6 |
+
# Options for length and language
|
7 |
+
length_options = ["Short", "Medium", "Long"]
|
8 |
+
language_options = ["English", "Hinglish"]
|
9 |
+
|
10 |
+
|
11 |
+
# Main app layout
|
12 |
+
def main():
    """Render the page: topic, length and language dropdowns plus a Generate button."""
    st.subheader("LinkedIn Post Generator: Codebasics")

    # One column per dropdown.
    topic_col, length_col, language_col = st.columns(3)

    available_tags = FewShotPosts().get_tags()

    selected_tag = topic_col.selectbox("Topic", options=available_tags)
    selected_length = length_col.selectbox("Length", options=length_options)
    selected_language = language_col.selectbox("Language", options=language_options)

    # Nothing more to show unless the button was pressed.
    if not st.button("Generate"):
        return

    st.write(generate_post(selected_length, selected_language, selected_tag))
|
38 |
+
|
39 |
+
|
40 |
+
# Run the app
if __name__ == "__main__":
    # Build the page only when this file is the entry script, not on import.
    main()
|
post_generator.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llm_helper import llm
|
2 |
+
from few_shot import FewShotPosts
|
3 |
+
|
4 |
+
few_shot = FewShotPosts()
|
5 |
+
|
6 |
+
|
7 |
+
def get_length_str(length):
    """Translate a length category into the line range quoted in the prompt.

    Returns ``None`` when *length* is not "Short", "Medium" or "Long".
    """
    line_ranges = (
        ("Short", "1 to 5 lines"),
        ("Medium", "6 to 10 lines"),
        ("Long", "11 to 15 lines"),
    )
    for category, description in line_ranges:
        if length == category:
            return description
    return None
|
14 |
+
|
15 |
+
|
16 |
+
def generate_post(length, language, tag):
    """Build the prompt for the given options and return the LLM reply's content."""
    return llm.invoke(get_prompt(length, language, tag)).content
|
20 |
+
|
21 |
+
|
22 |
+
def get_prompt(length, language, tag):
    """Assemble the generation prompt, appending at most two stored examples.

    The fixed part states topic, length and language. When the few-shot store
    has posts matching the same length, language and tag, their text is added
    as style examples.
    """
    header_lines = [
        "",
        "Generate a LinkedIn post using the below information. No preamble.",
        "",
        f"1) Topic: {tag}",
        f"2) Length: {get_length_str(length)}",
        f"3) Language: {language}",
        "If Language is Hinglish then it means it is a mix of Hindi and English.",
        "The script for the generated post should always be English.",
        "",
    ]
    pieces = ["\n".join(header_lines)]

    samples = few_shot.get_filtered_posts(length, language, tag)
    if samples:
        pieces.append("4) Use the writing style as per the following examples.")

    # Only the first two matching posts are used as examples.
    for number, sample in enumerate(samples[:2], start=1):
        pieces.append(f'\n\n Example {number}: \n\n {sample["text"]}')

    return "".join(pieces)
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == "__main__":
    # Manual smoke test: generate one post and print it.
    sample_post = generate_post("Medium", "English", "Mental Health")
    print(sample_post)
|
preprocess.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from llm_helper import llm
|
3 |
+
from langchain_core.prompts import PromptTemplate
|
4 |
+
from langchain_core.output_parsers import JsonOutputParser
|
5 |
+
from langchain_core.exceptions import OutputParserException
|
6 |
+
|
7 |
+
|
8 |
+
def process_posts(raw_file_path, processed_file_path=None):
    """Enrich raw posts with LLM-extracted metadata and unified tags.

    Args:
        raw_file_path: JSON file holding a list of posts, each with a
            ``text`` key.
        processed_file_path: Destination JSON file for the enriched posts.
            When ``None`` (the default) nothing is written.

    Returns:
        The list of enriched post dicts.
    """
    with open(raw_file_path, encoding='utf-8') as file:
        posts = json.load(file)

    # Merge the extracted metadata into a copy of each post; on a key clash
    # the metadata value wins.
    enriched_posts = [post | extract_metadata(post['text']) for post in posts]

    unified_tags = get_unified_tags(enriched_posts)
    for post in enriched_posts:
        # Keep the original tag when the mapping does not cover it, rather
        # than aborting the whole run with a KeyError.
        post['tags'] = list({unified_tags.get(tag, tag) for tag in post['tags']})

    # The default path is None; passing that to open() would raise TypeError
    # only after all the LLM work above had already been done.
    if processed_file_path is not None:
        with open(processed_file_path, encoding='utf-8', mode="w") as outfile:
            json.dump(enriched_posts, outfile, indent=4)

    return enriched_posts
|
25 |
+
|
26 |
+
|
27 |
+
def extract_metadata(post):
    """Ask the LLM for the line count, language and tags of one post.

    Args:
        post: The post text to analyse.

    Returns:
        The object parsed from the LLM's JSON reply. The prompt requests the
        keys ``line_count``, ``language`` and ``tags``.

    Raises:
        OutputParserException: If the reply cannot be parsed as JSON.
    """
    template = (
        "\n"
        "You are given a LinkedIn post. You need to extract number of lines, language of the post and tags.\n"
        "1. Return a valid JSON. No preamble.\n"
        "2. JSON object should have exactly three keys: line_count, language and tags.\n"
        "3. tags is an array of text tags. Extract maximum two tags.\n"
        "4. Language should be English or Hinglish (Hinglish means hindi + english)\n"
        "\n"
        "Here is the actual post on which you need to perform this task:\n"
        "{post}\n"
    )

    chain = PromptTemplate.from_template(template) | llm
    response = chain.invoke(input={"post": post})

    try:
        return JsonOutputParser().parse(response.content)
    except OutputParserException as err:
        # Same exception type as before, with a message that describes this
        # task and the original parser error kept as the cause.
        raise OutputParserException(
            "Unable to parse post metadata from the LLM response as JSON."
        ) from err
|
49 |
+
|
50 |
+
|
51 |
+
def get_unified_tags(posts_with_metadata):
    """Ask the LLM to merge similar tags and return the resulting mapping.

    Args:
        posts_with_metadata: Iterable of post dicts, each with a ``tags`` list.

    Returns:
        The object parsed from the LLM's JSON reply; the prompt requests a
        mapping from each original tag to its unified tag. An empty dict is
        returned without calling the LLM when the posts carry no tags.

    Raises:
        OutputParserException: If the reply cannot be parsed as JSON.
    """
    unique_tags = set()
    for post in posts_with_metadata:
        unique_tags.update(post['tags'])

    # Nothing to unify, so skip the LLM round-trip.
    if not unique_tags:
        return {}

    # Sorted so the same tag set always produces the same prompt.
    unique_tags_list = ','.join(sorted(unique_tags))

    # Doubled braces are literal braces for PromptTemplate; {tags} is the
    # only placeholder.
    template = (
        'I will give you a list of tags. You need to unify tags with the following requirements,\n'
        '1. Tags are unified and merged to create a shorter list.\n'
        'Example 1: "Jobseekers", "Job Hunting" can be all merged into a single tag "Job Search".\n'
        'Example 2: "Motivation", "Inspiration", "Drive" can be mapped to "Motivation"\n'
        'Example 3: "Personal Growth", "Personal Development", "Self Improvement" can be mapped to "Self Improvement"\n'
        'Example 4: "Scam Alert", "Job Scam" etc. can be mapped to "Scams"\n'
        '2. Each tag should be follow title case convention. example: "Motivation", "Job Search"\n'
        '3. Output should be a JSON object, No preamble\n'
        '4. Output should have mapping of original tag and the unified tag.\n'
        'For example: {{"Jobseekers": "Job Search", "Job Hunting": "Job Search", "Motivation": "Motivation"}}\n'
        '\n'
        'Here is the list of tags:\n'
        '{tags}\n'
    )

    chain = PromptTemplate.from_template(template) | llm
    response = chain.invoke(input={"tags": unique_tags_list})

    try:
        return JsonOutputParser().parse(response.content)
    except OutputParserException as err:
        # Same exception type as before, with a message that describes this
        # task and the original parser error kept as the cause.
        raise OutputParserException(
            "Unable to parse the unified tag mapping from the LLM response as JSON."
        ) from err
|
82 |
+
|
83 |
+
|
84 |
+
if __name__ == "__main__":
    # Regenerate the processed dataset from the raw posts.
    process_posts(
        raw_file_path="data/raw_posts.json",
        processed_file_path="data/processed_posts.json",
    )
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.35.0
langchain==0.2.14
langchain-core==0.2.39
langchain-community==0.2.12
langchain_groq==0.1.9
pandas==2.0.2
# Imported by llm_helper.py (`from dotenv import load_dotenv`).
python-dotenv==1.0.1
|