Spaces:
Runtime error
Runtime error
import streamlit as st | |
import openai | |
import os | |
import sys | |
import argparse | |
sys.path.append('./lats') | |
from lats_main import lats_main | |
# Switch the page to Streamlit's wide layout.
st.set_page_config(layout="wide")

# Seed per-session state only when the key has not been set yet,
# so reruns of this script keep any stored value.
if "response_content" not in st.session_state:
    st.session_state["response_content"] = None
# Main-area container holding the title, the description, the chat input
# and the generated results (runtime notifications live in the sidebar).
chat_col = st.container()
chat_col.title("CodeLATS")

# Raw string: the markdown below contains backslash escapes (\[ and \]) that
# are not valid Python escape sequences; in a normal string literal they
# trigger a DeprecationWarning (SyntaxWarning since Python 3.12). The raw
# string produces exactly the same text.
description = r"""This demo is an implementation of Language Agent Tree Search (LATS) (https://arxiv.org/abs/2310.04406) with Samba-1 in the backend. Thank you to the original authors of demo on which this is based from [Lapis Labs](https://lapis.rocks/)
Listed below is an example programming problem (https://leetcode.com/problems/median-of-two-sorted-arrays/description/) to get started with.
```python
Given two sorted arrays `nums1` and `nums2` of size `m` and `n` respectively, return **the median** of the two sorted arrays. The overall run time complexity should be `O(log (m+n))`. **Example 1:** **Input:** nums1 = \[1,3\], nums2 = \[2\] **Output:** 2.00000 **Explanation:** merged array = \[1,2,3\] and median is 2. **Example 2:** **Input:** nums1 = \[1,2\], nums2 = \[3,4\] **Output:** 2.50000 **Explanation:** merged array = \[1,2,3,4\] and median is (2 + 3) / 2 = 2.5. **Constraints:** * `nums1.length == m` * `nums2.length == n` * `0 <= m <= 1000` * `0 <= n <= 1000` * `1 <= m + n <= 2000` * `-10^6 <= nums1[i], nums2[i] <= 10^6`
```
"""
chat_col.markdown(description)
sidebar = st.sidebar

# Parameters Section: search hyper-parameters forwarded to make_args().
sidebar.title("From SambaNova Systems")
parameters_section = sidebar.expander("Parameters", expanded=False)
tree_width = parameters_section.number_input("Tree Width", min_value=1, max_value=5, value=1)
tree_depth = parameters_section.number_input("Tree Depth", min_value=1, max_value=8, value=3)
iterations = parameters_section.number_input("Iterations", min_value=1, max_value=4, value=2)
sidebar.markdown('<hr style="margin-top: 0.5rem; margin-bottom: 0.5rem;">', unsafe_allow_html=True)

# Runtime Section: a single sidebar container that run_querry() writes
# progress messages (and redirected print() output) into. It is created
# exactly once, directly in the sidebar, so no stray empty container is
# left behind in the main area.
runtime_container = sidebar.container()

# NOTE(review): not referenced by the code visible here; kept in case other
# code relies on the module-level name.
runtime_messages = []
def make_args(instruction, tree_depth, tree_width, iterations, *, argv=None):
    """Build the argparse namespace that ``lats_main`` consumes.

    The UI-supplied values become the parser defaults, so the returned
    namespace reflects the current sidebar settings.

    Args:
        instruction: Coding problem text entered by the user (``--instruction``).
        tree_depth: Maximum search-tree depth (``--depth``).
        tree_width: Number of nodes added per expansion (``--n_samples``).
        iterations: Maximum number of search iterations (``--max_iters``).
        argv: Optional explicit argument list used to override the defaults.
            When ``None`` (the default) an empty list is parsed, so the host
            process's ``sys.argv`` (Streamlit's or a test runner's own flags)
            is never read and cannot make argparse exit with an
            "unrecognized arguments" error.

    Returns:
        argparse.Namespace with the attributes ``strategy``, ``language``,
        ``max_iters``, ``instruction``, ``verbose``, ``is_leetcode``,
        ``n_samples`` and ``depth``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--strategy", default="mcts", help="Strategy to use")
    parser.add_argument("--language", default="py", help="Programming language")
    # type=int so an explicit override is converted like --n_samples/--depth.
    parser.add_argument("--max_iters", type=int, default=iterations,
                        help="Maximum iterations")
    parser.add_argument("--instruction", default=instruction, help="Instruction text")
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument("--is_leetcode", action="store_true",
                        help="To run the leetcode benchmark")  # Temporary
    parser.add_argument("--n_samples", type=int, default=tree_width,
                        help="The number of nodes added during expansion")
    parser.add_argument("--depth", type=int, default=tree_depth,
                        help="Tree depth")
    return parser.parse_args([] if argv is None else list(argv))
def run_querry():
    """Run LATS on the current ``user_input`` and render the result.

    Reads module-level state defined elsewhere in this script:
    ``user_input``, ``tree_depth``, ``tree_width``, ``iterations``,
    ``runtime_container`` and ``chat_col``. While ``lats_main`` runs,
    ``print`` output is redirected into ``runtime_container``. ``sys.stdout``
    is always restored afterwards, even if ``lats_main`` raises.

    Returns:
        Whatever ``lats_main`` returns, or ``None`` when ``user_input`` is
        empty.
    """
    if not user_input:
        return None

    runtime_container.write("Initiating process...")

    # Route print() calls made during the run into the runtime container.
    # NOTE(review): sys.stdout is process-wide, so prints from other
    # concurrent sessions are also captured while this run is active.
    old_stdout = sys.stdout
    sys.stdout = runtime_container
    try:
        with chat_col:
            with st.spinner('Running...'):
                args = make_args(user_input, tree_depth, tree_width, iterations)
                args.model = 'samba'
                response = lats_main(args)
    finally:
        # Restore unconditionally: leaving stdout pointed at a stale
        # container would break every later print in the process.
        sys.stdout = old_stdout

    runtime_container.write("Response fetched.")
    chat_col.markdown('<hr style="margin-top: 0.5rem; margin-bottom: 0.5rem;">', unsafe_allow_html=True)
    # Emit a properly closed fenced code block.
    chat_col.write(f"```python\n{response}\n```")
    return response
# User input section at the bottom of the page
with chat_col:
    user_input = st.text_area("Enter your message here:", placeholder="Type your message here...", label_visibility="collapsed")
    button = st.button("Send")

    if button:
        # Reject empty or whitespace-only submissions up front instead of
        # starting a LATS run on a blank problem.
        if not (user_input or "").strip():
            st.warning("Missing a coding problem")
        else:
            run_querry()