# Import necessary libraries
import nest_asyncio
import gradio as gr
import requests
from huggingface_hub import InferenceClient
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.document_loaders import TextLoader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.document_loaders import AsyncChromiumLoader
from langchain.document_transformers import Html2TextTransformer
from langchain.schema.runnable import RunnablePassthrough
# Apply nest_asyncio for asynchronous operations in environments like Jupyter notebooks
nest_asyncio.apply()
# Initialize the InferenceClient with the specified model
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
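# Optional smoke test (assumes the Hugging Face Inference API is reachable);
# uncomment to verify the client works before building the rest of the app
# print(client.text_generation("[INST] Say hello. [/INST]", max_new_tokens=20))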
# Set up a prompt template for the model (customize as needed); it expects
# the retrieved context and the user's question as inputs
prompt_template = PromptTemplate.from_template(
    "Answer the question using the context below.\n\nContext:\n{context}\n\nQuestion: {question}"
)
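# Example of how the template renders (hypothetical inputs):
# prompt_template.format(context="Jones is a strong week 10 start.",
#                        question="Who should I start at RB?")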
# Define the list of articles to index
articles = [
    "https://www.fantasypros.com/2023/11/rival-fantasy-nfl-week-10/",
    "https://www.fantasypros.com/2023/11/5-stats-to-know-before-setting-your-fantasy-lineup-week-10/",
    "https://www.fantasypros.com/2023/11/nfl-week-10-sleeper-picks-player-predictions-2023/",
    "https://www.fantasypros.com/2023/11/nfl-dfs-week-10-stacking-advice-picks-2023-fantasy-football/",
    "https://www.fantasypros.com/2023/11/players-to-buy-low-sell-high-trade-advice-2023-fantasy-football/",
]
# Scrape the articles above (AsyncChromiumLoader drives a headless Chromium
# browser via Playwright, which must be installed separately)
loader = AsyncChromiumLoader(articles)
docs = loader.load()
# Converts HTML to plain text
html2text = Html2TextTransformer()
docs_transformed = html2text.transform_documents(docs)
# Chunk text
text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=10)
chunked_documents = text_splitter.split_documents(docs_transformed)
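# Quick sanity check: with chunk_size=100 the articles split into many small
# chunks, and printing the count confirms the splitter ran as expected
print(f"Indexed {len(chunked_documents)} chunks from {len(docs)} articles")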
# Load chunked documents into the FAISS index
db = FAISS.from_documents(
    chunked_documents,
    HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2'),
)
retriever = db.as_retriever()
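# Optional sanity check (hypothetical query): confirm the retriever surfaces
# relevant chunks before wiring it into the chain
# for doc in retriever.get_relevant_documents("Week 10 sleeper picks"):
#     print(doc.page_content[:100])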
# Create the RAG chain: retrieved chunks fill the {context} slot and the
# user's question passes through into the prompt template (this assumes a
# LangChain version with LCEL support, where retrievers and prompts are
# Runnables; the InferenceClient does the actual generation in generate() below)
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

rag_chain = {"context": retriever | format_docs, "question": RunnablePassthrough()} | prompt_template
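# Example (hypothetical question): render the final prompt the model will see
# print(rag_chain.invoke("Which players should I buy low?").to_string())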
# Define the generation function for the Gradio interface; it streams tokens
# so the chat window updates as the model writes
def generate(
    prompt, history, temperature=0.7, max_new_tokens=256, top_p=0.95, repetition_penalty=1.1,
):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # Build the Mistral instruction-format prompt from the chat history, then
    # append the current turn with retrieved context filled in by the RAG chain
    formatted_prompt = "<s>"
    for user_prompt, bot_response in history:
        formatted_prompt += f"[INST] {user_prompt} [/INST]"
        formatted_prompt += f" {bot_response}</s> "
    formatted_prompt += f"[INST] {rag_chain.invoke(prompt).to_string()} [/INST]"

    # Stream tokens from the Inference API, yielding the partial output so far
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        output += response.token.text
        yield output
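# Outside Gradio, the generator can be driven manually (hypothetical prompt,
# empty history) to test retrieval and streaming end to end:
# for partial in generate("Give me a week 10 sleeper pick.", []):
#     print(partial)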
# Define additional input components for the Gradio interface
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.7,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.95,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.1,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# Define CSS for styling the Gradio interface
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
# Create the Gradio interface with the chat component
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>Mistral 7B Instruct</center></h1>")
    gr.HTML("<h3><center>In this demo, you can chat with the <a href='https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1'>Mistral-7B-Instruct</a> model. 📜</center></h3>")
    gr.HTML("<h3><center>Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mistral'>here</a>. 📚</center></h3>")
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
        examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."]],
    )

# Launch the Gradio interface with debugging enabled
demo.queue().launch(debug=True)