# Hugging Face Spaces status banner (page residue, not code):
# Spaces:
# Build error
# Build error
import json

import pandas as pd
import requests
import streamlit as st
import torch
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer
# Source dataset of tech articles.
# NOTE(review): if article-2020.json is stored via Git LFS, the /raw/ endpoint
# returns the LFS pointer file instead of the JSON; /resolve/ serves the real
# content in both cases -- confirm how the file is stored.
url = 'https://huggingface.co/spaces/AreesaAshfaq/ContentGenerator/raw/main/article-2020.json'

HF_MODEL_NAME = "facebook/rag-sequence-nq"


@st.cache_data
def _load_articles(source_url):
    """Download the article JSON at *source_url* and return it as a DataFrame.

    Streamlit re-executes the whole script on every widget interaction;
    ``st.cache_data`` keeps the download from being repeated on each rerun.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within 30 seconds.
    """
    response = requests.get(source_url, timeout=30)
    # Fail with the HTTP status rather than an obscure JSON decode error.
    response.raise_for_status()
    return pd.DataFrame(response.json())


@st.cache_resource
def _load_rag(model_name=HF_MODEL_NAME):
    """Load and return ``(tokenizer, retriever, model)`` for *model_name*.

    Cached with ``st.cache_resource`` so the checkpoint is loaded once per
    server process instead of on every script rerun.
    """
    rag_tokenizer = RagTokenizer.from_pretrained(model_name)
    # NOTE(review): with no index arguments the retriever uses the index
    # configured in the checkpoint, which may be a very large download; for a
    # lightweight demo consider index_name="exact", use_dummy_dataset=True.
    rag_retriever = RagRetriever.from_pretrained(model_name)
    # Attach the retriever to the model: without it, generate() would need
    # precomputed context_input_ids, which this app never supplies.
    rag_model = RagSequenceForGeneration.from_pretrained(
        model_name, retriever=rag_retriever
    )
    return rag_tokenizer, rag_retriever, rag_model


# NOTE(review): df is not used by the RAG pipeline yet -- retrieval runs
# against the checkpoint's own index, not these articles.
df = _load_articles(url)
tokenizer, retriever, model = _load_rag()
def generate_article(query, num_beams=4, max_length=200):
    """Generate text for *query* with the module-level RAG ``tokenizer``/``model``.

    NOTE(review): ``model`` is expected to have a retriever attached so it can
    fetch context passages for the query -- confirm where it is loaded.

    Args:
        query: Topic or keyword entered by the user.
        num_beams: Beam-search width passed to ``model.generate``.
        max_length: Maximum generated sequence length passed to ``model.generate``.

    Returns:
        The decoded text of the first (highest-ranked) generated sequence,
        with special tokens removed.
    """
    inputs = tokenizer(query, return_tensors="pt")
    # Inference only: disabling autograd saves memory and time.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            num_beams=num_beams,
            max_length=max_length,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Streamlit UI: the script reruns top to bottom on every interaction, so this
# section both renders the page and handles the button click.
st.title('Tech Article Generator using RAG')

# Strip surrounding whitespace so a blank-looking query counts as empty.
query = st.text_input("Enter a topic or keyword:").strip()

if st.button("Generate Article"):
    if query:
        # Generation can take several seconds; show progress meanwhile.
        with st.spinner("Generating article..."):
            article = generate_article(query)
        st.subheader("Generated Article:")
        st.write(article)
    else:
        st.error("Please enter a query to generate an article.")