Add multiprocessing to inference and clean up code
app.py CHANGED
@@ -1,29 +1,27 @@
-import streamlit as st
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from transformers import pipeline
-import torch
 import json
 import pandas as pd
 import requests
+from multiprocessing import Pool
+from functools import partial
+import streamlit as st
+
 
 GITHUB_CODE = "https://huggingface.co/datasets/lvwerra/github-code"
 INCODER_IMG = "https://huggingface.co/datasets/loubnabnl/repo-images/raw/main/incoder.png"
 
-@st.cache(allow_output_mutation=True)
-def load_tokenizer(model_ckpt):
-    return AutoTokenizer.from_pretrained(model_ckpt)
-
-@st.cache(allow_output_mutation=True)
-def load_model(model_ckpt):
-    model = AutoModelForCausalLM.from_pretrained(model_ckpt, low_cpu_mem_usage=True)
-    return model
-
 @st.cache()
 def load_examples():
     with open("utils/examples.json", "r") as f:
         examples = json.load(f)
     return examples
 
+def generate_code(model_name, gen_prompt, max_new_tokens, temperature, seed):
+    url = f'https://hf.space/embed/loubnabnl/{model_name.lower()}-subspace/+/api/predict/'
+    r = requests.post(url=url, json={"data": [gen_prompt, max_new_tokens, temperature, seed]})
+    generated_text = r.json()['data'][0]
+    st.markdown(model_name)
+    st.code(generated_text)
+
 st.set_page_config(page_icon=":laptop:", layout="wide")
 
 st.sidebar.header("Models")
@@ -84,9 +82,11 @@ elif selected_task == "Code generation":
     gen_prompt = st.text_area("Generate code with prompt:", value=example_text, height=220,).strip()
     if st.button("Generate code!"):
         with st.spinner("Generating code..."):
-
-
-
-
-
-
+            # Create a multiprocessing Pool
+            pool = Pool()
+            generate_parallel=partial(generate_code,
+                                      gen_prompt=gen_prompt,
+                                      max_new_tokens=max_new_tokens,
+                                      temperature=temperature,
+                                      seed=seed)
+            pool.map(generate_parallel, selected_models)