{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "5339272e", "metadata": { "execution": { "iopub.execute_input": "2023-03-22T17:00:42.285169Z", "iopub.status.busy": "2023-03-22T17:00:42.284973Z", "iopub.status.idle": "2023-03-22T17:00:42.291455Z", "shell.execute_reply": "2023-03-22T17:00:42.290871Z" }, "tags": [ "hide_inp" ] }, "outputs": [], "source": [ "desc = \"\"\"\n", "### Book QA\n", "\n", "Chain that does question answering with Hugging Face embeddings. [[Code](https://github.com/srush/MiniChain/blob/main/examples/gatsby.py)]\n", "\n", "(Adapted from the [LlamaIndex example](https://github.com/jerryjliu/gpt_index/blob/main/examples/gatsby/TestGatsby.ipynb).)\n", "\"\"\"" ] }, { "cell_type": "markdown", "id": "f966010d", "metadata": {}, "source": [ "$" ] }, { "cell_type": "code", "execution_count": 2, "id": "5391c476", "metadata": { "execution": { "iopub.execute_input": "2023-03-22T17:00:42.293927Z", "iopub.status.busy": "2023-03-22T17:00:42.293738Z", "iopub.status.idle": "2023-03-22T17:00:43.695402Z", "shell.execute_reply": "2023-03-22T17:00:43.694722Z" } }, "outputs": [], "source": [ "import datasets\n", "import numpy as np\n", "from minichain import prompt, show, HuggingFaceEmbed, OpenAI" ] }, { "cell_type": "markdown", "id": "5b1b2a82", "metadata": {}, "source": [ "Load data with embeddings (computed beforehand)" ] }, { "cell_type": "code", "execution_count": 3, "id": "a54cf84e", "metadata": { "execution": { "iopub.execute_input": "2023-03-22T17:00:43.698121Z", "iopub.status.busy": "2023-03-22T17:00:43.697792Z", "iopub.status.idle": "2023-03-22T17:00:43.730349Z", "shell.execute_reply": "2023-03-22T17:00:43.729747Z" }, "lines_to_next_cell": 1 }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "86d0a2ceb7ad4f99978e37c2719f2960", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Dataset({\n", " features: ['passages', 'embeddings'],\n", " num_rows: 52\n", "})" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gatsby = datasets.load_from_disk(\"gatsby\")\n", "gatsby.add_faiss_index(\"embeddings\")" ] }, { "cell_type": "markdown", "id": "a9a08061", "metadata": {}, "source": [ "Fast KNN retieval prompt" ] }, { "cell_type": "code", "execution_count": 4, "id": "3e69cd6b", "metadata": { "execution": { "iopub.execute_input": "2023-03-22T17:00:43.735462Z", "iopub.status.busy": "2023-03-22T17:00:43.735139Z", "iopub.status.idle": "2023-03-22T17:00:43.738420Z", "shell.execute_reply": "2023-03-22T17:00:43.737964Z" }, "lines_to_next_cell": 1 }, "outputs": [], "source": [ "@prompt(HuggingFaceEmbed(\"sentence-transformers/all-mpnet-base-v2\"))\n", "def get_neighbors(model, inp, k=1):\n", " embedding = model(inp)\n", " res = olympics.get_nearest_examples(\"embeddings\", np.array(embedding), k)\n", " return res.examples[\"passages\"]" ] }, { "cell_type": "code", "execution_count": 5, "id": "14e22d0a", "metadata": { "execution": { "iopub.execute_input": "2023-03-22T17:00:43.740824Z", "iopub.status.busy": "2023-03-22T17:00:43.740339Z", "iopub.status.idle": "2023-03-22T17:00:43.743342Z", "shell.execute_reply": "2023-03-22T17:00:43.742905Z" }, "lines_to_next_cell": 1 }, "outputs": [], "source": [ "@prompt(OpenAI(),\n", " template_file=\"gatsby.pmpt.tpl\")\n", "def ask(model, query, neighbors):\n", " return model(dict(question=query, docs=neighbors))" ] }, { "cell_type": "code", "execution_count": 6, "id": "bfca0bea", "metadata": { "execution": { "iopub.execute_input": "2023-03-22T17:00:43.745377Z", "iopub.status.busy": "2023-03-22T17:00:43.745056Z", "iopub.status.idle": "2023-03-22T17:00:43.747768Z", "shell.execute_reply": "2023-03-22T17:00:43.747352Z" }, "lines_to_next_cell": 2 }, "outputs": [], "source": [ "def gatsby(query):\n", " n = get_neighbors(query)\n", " return ask(query, n)" ] }, { "cell_type": "markdown", "id": "159b0b85", "metadata": { "lines_to_next_cell": 2 }, "source": [ "$" ] }, { "cell_type": "code", "execution_count": 7, "id": "8e3f74d0", "metadata": { "execution": { "iopub.execute_input": "2023-03-22T17:00:43.749935Z", "iopub.status.busy": "2023-03-22T17:00:43.749750Z", "iopub.status.idle": "2023-03-22T17:00:44.094814Z", "shell.execute_reply": "2023-03-22T17:00:44.094179Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7861\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "