AreesaAshfaq committed on
Commit
0eaa49c
1 Parent(s): fd66d73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -29
app.py CHANGED
@@ -1,34 +1,47 @@
1
- import streamlit as st
2
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
3
  import json
4
- from datasets import load_dataset
 
 
 
 
 
 
5
 
6
- # Load the dataset
7
- dataset = load_dataset('json', data_files={'data': 'https://huggingface.co/spaces/AreesaAshfaq/ContentGenerator/blob/main/article-2020.json'})['data']
 
8
 
9
- # Load the model and tokenizer
10
- tokenizer = RagTokenizer.from_pretrained('facebook/rag-sequence-nq')
11
- retriever = RagRetriever.from_pretrained('facebook/rag-sequence-nq')
12
- model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq')
13
 
14
- # Define the article generation function
15
- def generate_article(query, dataset):
 
 
 
 
16
  inputs = tokenizer(query, return_tensors="pt")
17
- retriever.set_ctx_data(dataset)
18
- generated_ids = model.generate(
19
- input_ids=inputs["input_ids"],
20
- attention_mask=inputs["attention_mask"],
21
- max_length=200
22
- )
23
- return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
24
-
25
- # Streamlit UI
26
- st.title('Tech Article Generator')
27
-
28
- topics = [item['title'] for item in dataset]
29
- selected_topic = st.selectbox('Select a topic', topics)
30
-
31
- if st.button('Generate Article'):
32
- query = selected_topic
33
- article = generate_article(query, dataset)
34
- st.write(article)
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import requests
3
  import json
4
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
5
+ import streamlit as st
6
+
7
# Fetch the article dataset from the Space repository.
url = 'https://huggingface.co/spaces/AreesaAshfaq/ContentGenerator/raw/main/article-2020.json'
# A timeout keeps the app from hanging forever on a stalled connection, and
# raise_for_status() surfaces HTTP errors (404/5xx) directly instead of
# letting them show up later as a confusing JSON decode failure.
response = requests.get(url, timeout=30)
response.raise_for_status()
data = response.json()

# Convert the parsed JSON into a DataFrame.
# NOTE(review): `df` is not referenced by any other code visible in this
# file; confirm whether it is still needed.
df = pd.DataFrame(data)
15
 
 
 
 
 
16
 
17
# Initialize the tokenizer, retriever, and model from the same checkpoint.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq")
# The retriever must be attached to the model: generate() is called with
# only input_ids, so the model has to fetch the context documents itself.
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq", retriever=retriever
)
21
+
22
def generate_article(query):
    """Generate text for *query* with the module-level RAG model.

    Args:
        query: Topic or keyword string to generate from.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens removed.
    """
    # Imported here because the file's import block does not bring torch
    # into scope; without it the no_grad() context raised NameError.
    import torch

    inputs = tokenizer(query, return_tensors="pt")

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_beams=4,
            max_length=200,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
32
+
33
+
34
# Set up the Streamlit app
st.title('Tech Article Generator using RAG')

# User input for the query. Surrounding whitespace is dropped so that an
# entry consisting only of spaces is treated as empty instead of being
# sent to the model.
query = st.text_input("Enter a topic or keyword:").strip()

if st.button("Generate Article"):
    if query:
        # Generate the article and show it under its own heading.
        article = generate_article(query)
        st.subheader("Generated Article:")
        st.write(article)
    else:
        st.error("Please enter a query to generate an article.")