GPU inference
Files changed:
- app.py +1 -1
- requirements.txt +2 -1
- utils/generation.py +6 -1
app.py
CHANGED
@@ -51,7 +51,7 @@ outro_text ="""
 - [ ] support FIM task for better model context
 - [x] include some context for prompt (title, comments before a function) - now takes all comments directly before a function as well as all comments at the beginning inside a function. (misses comments between the argument list and the body)
 - [ ] gradio examples
-- [ ] use GPU if available, respect memory restrictions.
+- [x] use GPU if available, respect memory restrictions (implemented via accelerate.Accelerator.device in utils/generation.py), tested successfully with an A750!
 - [x] stream model generation (maybe in a new window?) - janky solution, and it only sometimes hangs up
 - [ ] 2nd iFrame needs a lot of fixing (I am not a web developer, need help) BUG: background is white, so colors are wrong. Shadertoy uses a black background (or we ignore alpha).
 - [ ] (optional) filtering the dataset by license?
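The newly checked-off GPU item relies on accelerate picking the best available device automatically. As a minimal standalone sketch of that behaviour (illustration only, not code from this Space):

```python
# Minimal sketch: how accelerate resolves a device (assumes accelerate is installed).
from accelerate import Accelerator

accelerator = Accelerator()
device = accelerator.device  # e.g. cuda:0, xpu:0 (Intel Arc such as the A750), mps, or cpu
print(f"accelerate selected: {device}")
```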
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ torch
 pillow
 gradio
 jupylet
-tree-sitter
+tree-sitter
+accelerate
utils/generation.py
CHANGED
@@ -1,3 +1,4 @@
+from accelerate import Accelerator
 from transformers import TextIteratorStreamer
 from threading import Thread
 from .tree_utils import full_func_head, grab_before_comments
@@ -15,17 +16,21 @@ def combine_generation_kwargs(temperature=2.0, max_new_tokens=512, top_p=0.95, r
 
 
 def stream_generation(prompt:str, pipe, gen_kwargs:dict):
     """
     Text generation function
     Args:
         prompt (str): The context to start generation from.
-        pipe (Pipeline): The pipeline to use for generation
+        pipe (Pipeline): The pipeline to use for generation (we take the model and tokenizer from it)
         gen_kwargs (dict): The generation kwargs.
     Returns:
         str: The generated text. (it iterates over time)
     """
+    accelerator = Accelerator()
+    device = accelerator.device
     # Tokenize the model_context
     model_inputs = pipe.tokenizer(prompt, return_tensors="pt")
+    model_inputs.to(device)
+    model = pipe.model.to(device)  # is this also required?
 
     # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
     # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.