# app.py — Hugging Face Space by AreesaAshfaq
# Last commit: 278f16c (verified), "Update app.py"; malware scan: no virus found; size: 1.25 kB
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import json
from datasets import load_dataset
# Load the dataset
@st.cache_resource
def _load_dataset(url):
    """Download the article JSON at *url* and return its 'data' split.

    Cached because Streamlit re-executes this whole script on every widget
    interaction; without caching the file would be re-fetched on each click.
    """
    return load_dataset('json', data_files={'data': url})['data']


@st.cache_resource
def _load_rag(model_name):
    """Load the RAG tokenizer, retriever and sequence model for *model_name*.

    Returns:
        A ``(tokenizer, retriever, model)`` tuple. The model is constructed
        with the retriever attached, because RAG's ``generate()`` looks up
        context passages through the model's retriever. The original code
        never attached one.
    """
    rag_tokenizer = RagTokenizer.from_pretrained(model_name)
    # NOTE(review): with the checkpoint's default config this presumably
    # downloads the full Wikipedia DPR passage index, which is very large.
    # Confirm that this is intended for the deployment target.
    rag_retriever = RagRetriever.from_pretrained(model_name)
    rag_model = RagSequenceForGeneration.from_pretrained(model_name, retriever=rag_retriever)
    return rag_tokenizer, rag_retriever, rag_model


# The Hub serves raw file contents under `/resolve/<revision>/`. The
# `/blob/<revision>/` form points at the HTML file viewer instead, which the
# JSON loader cannot parse.
dataset = _load_dataset(
    'https://huggingface.co/spaces/AreesaAshfaq/ContentGenerator/resolve/main/article-2020.json'
)
# Load the model and tokenizer
tokenizer, retriever, model = _load_rag('facebook/rag-sequence-nq')
# Define the article generation function
def generate_article(query, dataset, max_length=200):
    """Generate text for *query* with the RAG model and return it as a string.

    Args:
        query: Prompt text to encode. The UI passes the selected article title.
        dataset: Accepted for backward compatibility but currently unused.
            ``RagRetriever`` has no ``set_ctx_data`` method, so the previous
            call raised ``AttributeError``. Retrieving from this dataset instead
            of the checkpoint's own index would require building a
            FAISS-indexed passages dataset and loading the retriever with
            ``index_name="custom"``.
        max_length: Upper bound on the generated sequence length, in tokens.
            Defaults to the previously hard-coded 200.

    Returns:
        The first decoded generation, with special tokens stripped.
    """
    inputs = tokenizer(query, return_tensors="pt")
    # (Re)attach the module-level retriever so generate() can fetch context
    # passages for input_ids.
    model.set_retriever(retriever)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
    )
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Streamlit UI
st.title('Tech Article Generator')
# NOTE(review): assumes every record in the dataset has a 'title' field.
# Confirm this against article-2020.json.
topics = [item['title'] for item in dataset]
selected_topic = st.selectbox('Select a topic', topics)
if st.button('Generate Article'):
    query = selected_topic
    if query is None:
        # st.selectbox returns None when it has no options (empty dataset).
        # Passing None on to the tokenizer would crash the app.
        st.warning('No topics are available to generate an article from.')
    else:
        # Generation is slow, so show progress instead of a frozen page.
        with st.spinner('Generating article...'):
            article = generate_article(query, dataset)
        st.write(article)