File size: 4,434 Bytes
40190c3
41d1bc5
 
 
 
45b4c77
41d1bc5
 
 
 
 
 
 
 
 
 
 
 
229e7eb
41d1bc5
229e7eb
41d1bc5
 
229e7eb
41d1bc5
 
 
 
 
 
 
 
 
229e7eb
41d1bc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f38260
41d1bc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ecf217a
 
41d1bc5
 
 
 
 
 
40190c3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import streamlit as st
import openai
import os
import sys
import argparse
sys.path.append('./lats')
from lats_main import lats_main

st.set_page_config(layout="wide")

# Initialize session state variables if they don't exist.
# NOTE(review): response_content is not read or written anywhere else in this
# file; kept in case other code relies on it.
if 'response_content' not in st.session_state:
    st.session_state.response_content = None

# Main container: title, description, the problem input and rendered responses.
chat_col = st.container()

chat_col.title("CodeLATS")
description = """This demo is an implementation of Language Agent Tree Search (LATS) (https://arxiv.org/abs/2310.04406) with Samba-1 in the backend. Thank you to the original authors of the demo on which this is based, [Lapis Labs](https://lapis.rocks/).

Listed below is an example programming problem (https://leetcode.com/problems/median-of-two-sorted-arrays/description/) to get started with. 

```python
Given two sorted arrays `nums1` and `nums2` of size `m` and `n` respectively, return **the median** of the two sorted arrays. The overall run time complexity should be `O(log (m+n))`. **Example 1:** **Input:** nums1 = \[1,3\], nums2 = \[2\] **Output:** 2.00000 **Explanation:** merged array = \[1,2,3\] and median is 2. **Example 2:** **Input:** nums1 = \[1,2\], nums2 = \[3,4\] **Output:** 2.50000 **Explanation:** merged array = \[1,2,3,4\] and median is (2 + 3) / 2 = 2.5. **Constraints:** * `nums1.length == m` * `nums2.length == n` * `0 <= m <= 1000` * `0 <= n <= 1000` * `1 <= m + n <= 2000` * `-10^6 <= nums1[i], nums2[i] <= 10^6`
```
"""

chat_col.markdown(description)
sidebar = st.sidebar

# Parameters Section: search settings that run_querry forwards to make_args.
sidebar.title("From SambaNova Systems")
parameters_section = sidebar.expander("Parameters", expanded=False)
tree_width = parameters_section.number_input("Tree Width", min_value=1, max_value=5, value=1)
tree_depth = parameters_section.number_input("Tree Depth", min_value=1, max_value=8, value=3)
iterations = parameters_section.number_input("Iterations", min_value=1, max_value=4, value=2)
sidebar.markdown('<hr style="margin-top: 0.5rem; margin-bottom: 0.5rem;">', unsafe_allow_html=True)

# Runtime Section: sidebar container for progress messages. run_querry also
# points sys.stdout at it while a search is running.
with sidebar:
    runtime_container = st.container()
    runtime_container.empty()

# NOTE(review): not referenced anywhere else in this file.
runtime_messages = []

def make_args(instruction, tree_depth, tree_width, iterations, argv=None):
    """Build the argument namespace passed to ``lats_main``.

    The UI-provided values become argparse defaults, so explicit command-line
    flags (if any) still override them.

    Args:
        instruction: Problem statement text, exposed as ``instruction``.
        tree_depth: Default for ``--depth`` (tree depth).
        tree_width: Default for ``--n_samples`` (nodes added per expansion).
        iterations: Default for ``--max_iters``.
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as before; pass ``[]`` to ignore the process
            command line entirely.

    Returns:
        argparse.Namespace with ``strategy``, ``language``, ``max_iters``,
        ``instruction``, ``verbose``, ``is_leetcode``, ``n_samples`` and
        ``depth``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--strategy", default="mcts", help="Strategy to use")
    parser.add_argument("--language", default="py", help="Programming language")
    # type=int so a command-line override is an int, like the UI default.
    parser.add_argument("--max_iters", type=int, default=iterations, help="Maximum iterations")
    parser.add_argument("--instruction", default=instruction, help="Instruction text")
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument("--is_leetcode", action='store_true',
                        help="To run the leetcode benchmark")  # Temporary
    parser.add_argument("--n_samples", type=int,
                        help="The number of nodes added during expansion", default=tree_width)
    parser.add_argument("--depth", type=int,
                        help="Tree depth", default=tree_depth)
    # parse_known_args: arguments this parser does not define are ignored
    # instead of triggering sys.exit() in the middle of a request.
    args, _unknown = parser.parse_known_args(argv)
    return args

def run_querry():
    """Run LATS on the current problem text and render the result.

    Reads module-level UI state: ``user_input``, ``tree_depth``,
    ``tree_width``, ``iterations``, ``chat_col`` and ``runtime_container``.
    While ``lats_main`` runs, ``sys.stdout`` is pointed at
    ``runtime_container``, so printed output appears in the sidebar.

    Returns:
        The value returned by ``lats_main``, which is also rendered as a
        Python code block. Returns ``None`` when ``user_input`` is empty.
    """
    if not user_input:
        return None

    runtime_container.write("Initiating process...")

    # Route print() output to the runtime container. Restore it in `finally`
    # so a failing search does not leave the process-wide sys.stdout pointing
    # at a stale Streamlit container.
    old_stdout = sys.stdout
    sys.stdout = runtime_container
    try:
        with chat_col:
            with st.spinner('Running...'):
                args = make_args(user_input, tree_depth, tree_width, iterations)
                # Backend model is hard-coded for this demo.
                args.model = 'samba'
                # main call
                response = lats_main(args)
    finally:
        sys.stdout = old_stdout

    runtime_container.write("Response fetched.")
    chat_col.markdown('<hr style="margin-top: 0.5rem; margin-bottom: 0.5rem;">', unsafe_allow_html=True)
    # Close the fence so the code block ends after the response.
    chat_col.write(f"```python\n{response}\n```")

    return response

# User input section at the bottom of the page.
with chat_col:
    user_input = st.text_area("Enter your message here:", placeholder="Type your message here...", label_visibility="collapsed")
    button = st.button("Send")

    if button:
        # Reject empty or whitespace-only input. run_querry only checks
        # truthiness, so "   " would otherwise launch a full search on a
        # blank problem.
        if not user_input or not user_input.strip():
            st.warning("Missing a coding problem")
        else:
            run_querry()