# Hugging Face Space status (captured with this source): Build error
import json

import streamlit as st
from datasets import load_dataset
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer
# Raw-file URL of the article dataset. A ".../blob/main/..." URL serves the
# HTML file-viewer page instead of the JSON; ".../resolve/main/..." returns
# the raw file that load_dataset("json", ...) can parse.
DATASET_URL = (
    "https://huggingface.co/spaces/AreesaAshfaq/ContentGenerator/"
    "resolve/main/article-2020.json"
)
RAG_MODEL_NAME = "facebook/rag-sequence-nq"


@st.cache_resource
def _load_articles(url=DATASET_URL):
    """Download and parse the JSON article dataset, cached across reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the dataset would be re-fetched on every click.

    Args:
        url: Location of the JSON file to load.

    Returns:
        The ``"data"`` split produced by ``load_dataset``.
    """
    return load_dataset("json", data_files={"data": url})["data"]


@st.cache_resource
def _load_rag(model_name=RAG_MODEL_NAME):
    """Load the RAG tokenizer, retriever and generator, cached across reruns.

    Args:
        model_name: Hugging Face Hub id of the RAG checkpoint.

    Returns:
        A ``(tokenizer, retriever, model)`` tuple.
    """
    rag_tokenizer = RagTokenizer.from_pretrained(model_name)
    # By default the retriever loads the full wiki_dpr passage index, a very
    # large download. The exact/dummy index is the lightweight configuration
    # shown in the transformers RAG documentation.
    rag_retriever = RagRetriever.from_pretrained(
        model_name, index_name="exact", use_dummy_dataset=True
    )
    # The retriever must be attached to the model; otherwise generate() has
    # no retrieved documents to condition on and fails.
    rag_model = RagSequenceForGeneration.from_pretrained(
        model_name, retriever=rag_retriever
    )
    return rag_tokenizer, rag_retriever, rag_model


# Module-level names used by the rest of the script.
dataset = _load_articles()
tokenizer, retriever, model = _load_rag()
# Define the article generation function
def generate_article(query, dataset, max_length=200):
    """Generate text for *query* with the module-level RAG model.

    Args:
        query: Prompt text (the UI passes the selected article title).
        dataset: Article dataset. Currently unused: the previous code called
            ``retriever.set_ctx_data(dataset)``, which is not a ``RagRetriever``
            method and raised ``AttributeError`` on every call. Retrieval uses
            whatever index the model's retriever was loaded with. The
            parameter is kept so existing callers keep working.
        max_length: Maximum length of the generated sequence, in tokens.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    # Relies on the module-level ``tokenizer`` and ``model``; ``model`` needs
    # a retriever attached for ``generate`` to find context documents.
    inputs = tokenizer(query, return_tensors="pt")
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
    )
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Streamlit UI
st.title('Tech Article Generator')
# Expects every dataset record to carry a 'title' field.
topics = [item['title'] for item in dataset]
if not topics:
    # st.selectbox returns None for an empty option list, and None cannot be
    # tokenized; report the problem and halt this script run instead.
    st.warning('No topics found in the article dataset.')
    st.stop()
selected_topic = st.selectbox('Select a topic', topics)
if st.button('Generate Article'):
    query = selected_topic
    # Retrieval plus generation can take a while; show progress feedback.
    with st.spinner('Generating article...'):
        article = generate_article(query, dataset)
    st.write(article)