# Hugging Face Space status when this copy was captured: Runtime error
import sys | |
import os | |
import gradio as gr | |
from gradio.themes.utils import sizes | |
from text_generation import Client | |
# todo: remove and replace by the actual js file instead | |
from share_btn import (share_js) | |
from utils import ( | |
get_file_as_string, | |
get_sections, | |
get_url_from_env_or_default_path, | |
preview | |
) | |
from constants import ( | |
DEFAULT_STARCODER_API_PATH, | |
DEFAULT_STARCODER_BASE_API_PATH, | |
FIM_MIDDLE, | |
FIM_PREFIX, | |
FIM_SUFFIX, | |
END_OF_TEXT, | |
MIN_TEMPERATURE, | |
) | |
# The token is sent as a bearer credential to both inference endpoints, so the
# app is useless without it. Rather than raising (and dumping a traceback),
# explain what to do on stderr and stop with a non-zero exit status.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    ERR_MSG = """
Please set the HF_TOKEN environment variable with your Hugging Face API token.
You can get one by signing up at https://huggingface.co/join and then visiting
https://huggingface.co/settings/tokens."""
    sys.stderr.write(f"{ERR_MSG}\n")
    sys.exit(1)
# Inference endpoints for the two selectable checkpoints. The helper presumably
# prefers the named environment variable and otherwise falls back to the
# default path from ``constants`` (TODO confirm in ``utils``).
API_URL = get_url_from_env_or_default_path(
    "STARCODER_API", DEFAULT_STARCODER_API_PATH
)
API_URL_BASE = get_url_from_env_or_default_path(
    "STARCODER_BASE_API", DEFAULT_STARCODER_BASE_API_PATH
)

# Echo the effective configuration at start-up. The token is passed with
# ``ofuscate=True`` so that ``preview`` can hide it (presumably masked).
preview("StarCoder Model's URL", API_URL)
preview("StarCoderBase Model's URL", API_URL_BASE)
preview("HF Token", HF_TOKEN, ofuscate=True)

DEFAULT_PORT = 7860
# Marker a user types into the prompt to request fill-in-the-middle generation.
FIM_INDICATOR = "<FILL_HERE>"

# Static assets shipped with the app, each read whole into a string once at
# import time (formats page, stylesheet, and the two share-button icons).
STATIC_PATH = "static"
FORMATS, CSS, community_icon_svg, loading_icon_svg = [
    get_file_as_string(asset_name, path=STATIC_PATH)
    for asset_name in (
        "formats.md",
        "styles.css",
        "community_icon.svg",
        "loading_icon.svg",
    )
]

# TODO: evaluate making STATIC_PATH the default path instead of the current one.
README = get_file_as_string("README.md")
# ``get_sections`` presumably splits the README on "---" separators. The first
# three pieces are the manifest, the page description and the disclaimer
# rendered in the UI.
readme_sections = get_sections(README, "---")
manifest, description, disclaimer = readme_sections[:3]
# Page theme: Monochrome base with indigo/blue accents and slate neutrals,
# small corner radius, large text, and the Rubik web font with system
# sans-serif fallbacks.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=sizes.radius_sm,
    text_size=sizes.text_lg,
    font=[gr.themes.GoogleFont("Rubik"), "ui-sans-serif", "system-ui", "sans-serif"],
)

# Both inference clients authenticate with the same bearer token.
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}

# One client per checkpoint, created in order: StarCoder, then StarCoderBase.
client, client_base = (
    Client(endpoint, headers=HEADERS) for endpoint in (API_URL, API_URL_BASE)
)
def generate(
    prompt,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
    version="StarCoder",
):
    """Stream a completion for ``prompt``, yielding the accumulated text.

    When ``prompt`` contains ``FIM_INDICATOR`` the request is built for
    fill-in-the-middle generation:

    - the text before the marker is the prefix, and the text after it is the
      suffix;
    - the yielded output is ``prefix + generated middle``;
    - the suffix is appended once the model emits ``END_OF_TEXT``.

    Otherwise the prompt is echoed back, followed by the generated
    continuation.

    Args:
        prompt: Code to complete. May contain at most one ``FIM_INDICATOR``.
        temperature: Sampling temperature. It is raised to at least
            ``MIN_TEMPERATURE`` because sampling needs a strictly positive
            value.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass. Values >= 1.0 mean "keep
            the whole distribution" and disable nucleus filtering.
        repetition_penalty: Penalty applied to repeated tokens.
        version: ``"StarCoder"`` selects ``client``. Any other value selects
            ``client_base``.

    Yields:
        The output text accumulated so far, after each streamed token.

    Returns:
        The final output text, as the generator's ``StopIteration`` value.

    Raises:
        ValueError: If the prompt contains more than one ``FIM_INDICATOR``.
    """
    # MIN_TEMPERATURE is a floor. The UI slider can reach 0.0, which is not a
    # valid sampling temperature for the text-generation client.
    temperature = max(float(temperature), MIN_TEMPERATURE)

    top_p = float(top_p)
    if top_p >= 1.0:
        # top_p == 1.0 keeps every token, i.e. no nucleus filtering. Send None
        # so the client does not reject the boundary value (it expects
        # 0 < top_p < 1).
        top_p = None

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),
        top_p=top_p,
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        # A fixed seed makes sampling repeatable for identical requests.
        seed=42,
    )

    fim_mode = FIM_INDICATOR in prompt
    if fim_mode:
        if prompt.count(FIM_INDICATOR) != 1:
            raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
        prefix, suffix = prompt.split(FIM_INDICATOR)
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
        # Only the user's own text is shown, never the FIM control tokens.
        output = prefix
    else:
        output = prompt

    model_client = client if version == "StarCoder" else client_base
    for response in model_client.generate_stream(prompt, **generate_kwargs):
        token_text = response.token.text
        if token_text == END_OF_TEXT:
            if fim_mode:
                # The middle is complete: close it with the user's suffix.
                output += suffix
                yield output
            return output
        output += token_text
        yield output
    return output
# TODO: move the examples into the README too.
# Prompts offered through gr.Examples. The last one contains the "<FILL_HERE>"
# marker, so it exercises fill-in-the-middle mode. Its code keeps real Python
# indentation so that the prefix/suffix around the marker are valid source.
examples = [
    # sklearn's train_test_split returns X_train, X_test, y_train, y_test.
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)\n\n# Train a logistic regression model, predict the labels on the test set and compute the accuracy score",
    "// Returns every other value in the array as a new array.\nfunction everyOther(arr) {",
    (
        "def alternating(list1, list2):\n"
        "    results = []\n"
        "    for i in range(min(len(list1), len(list2))):\n"
        "        results.append(list1[i])\n"
        "        results.append(list2[i])\n"
        "    if len(list1) > len(list2):\n"
        "        <FILL_HERE>\n"
        "    else:\n"
        "        results.extend(list2[i+1:])\n"
        "    return results"
    ),
]
def process_example(args):
    """Run ``generate`` on an example prompt and return only its final text.

    ``generate`` is a generator that yields the growing output. Examples only
    need the end result, so the stream is drained here.

    Args:
        args: The example prompt (the value placed in the "Code" textbox).

    Returns:
        The last text yielded by ``generate``. If it yields nothing (e.g. the
        model ends immediately), the generator's own return value is used
        instead of raising ``UnboundLocalError``.
    """
    result = None
    stream = generate(args)
    while True:
        try:
            result = next(stream)
        except StopIteration as finished:
            # ``generate`` returns its final output; prefer it when present.
            return result if finished.value is None else finished.value
# Page layout and event wiring. Gradio places each component in the layout
# context that is active when it is created, so statement order and nesting
# below define the UI.
# NOTE(review): nesting was reconstructed from a copy with flattened
# indentation. Confirm it against the deployed layout.
with gr.Blocks(theme=theme, analytics_enabled=False, css=CSS) as demo:
    with gr.Column():
        # Intro text sliced from the README.
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                # Prompt input. A "<FILL_HERE>" marker switches generate() to
                # fill-in-the-middle mode.
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                # Receives the streamed text yielded by generate().
                output = gr.Code(elem_id="q-output", lines=30)
                with gr.Row():
                    with gr.Column():
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                # Two side-by-side columns, created first and
                                # then filled through the ``with`` blocks.
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=256,
                                        minimum=0,
                                        maximum=8192,
                                        step=64,
                                        interactive=True,
                                        info="The maximum numbers of new tokens",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.2,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
                    with gr.Column():
                        # generate() uses ``client`` for "StarCoder" and
                        # ``client_base`` for any other choice.
                        version = gr.Dropdown(
                            ["StarCoderBase", "StarCoder"],
                            value="StarCoder",
                            label="Version",
                            info="",
                        )
                gr.Markdown(disclaimer)
                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(community_icon_svg, visible=True)
                    loading_icon = gr.HTML(loading_icon_svg, visible=True)
                    share_button = gr.Button(
                        "Share to community", elem_id="share-btn", visible=True
                    )
                # Clicking an example fills the prompt box. With caching
                # disabled, outputs are not precomputed at start-up.
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )
                gr.Markdown(FORMATS)
    # The inputs list is passed positionally and must match generate()'s
    # parameter order:
    # (prompt, temperature, max_new_tokens, top_p, repetition_penalty, version).
    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, version],
        outputs=[output],
    )
    # No Python callback (fn=None). The share action runs as client-side JS.
    share_button.click(None, [], [], _js=share_js)
# The queue is needed so generator handlers like generate() can stream their
# output. concurrency_count and _js are Gradio 3.x parameters.
demo.queue(concurrency_count=16).launch(debug=True, server_port=DEFAULT_PORT)