rajistics commited on
Commit
2923d3d
1 Parent(s): 08c6275
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -14,7 +14,6 @@ from share_btn import community_icon_html, loading_icon_html, share_js, share_bt
14
  #API_URL = "https://api-inference.huggingface.co/models/bigcode/starcoder"
15
  #API_URL_BASE ="https://api-inference.huggingface.co/models/bigcode/starcoderbase"
16
  #API_URL_PLUS = "https://api-inference.huggingface.co/models/bigcode/starcoderplus"
17
- https://huggingface.co/smallcloudai/Refact-1_6B-fim/discussions
18
 
19
  from transformers import AutoModelForCausalLM, AutoTokenizer
20
 
@@ -28,7 +27,7 @@ prompt = '<fim_prefix>def print_hello_world():\n """<fim_suffix>\n print("
28
 
29
  inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
30
  outputs = model.generate(inputs, max_length=100, temperature=0.2)
31
- print("-"*80)
32
  print(tokenizer.decode(outputs[0]))
33
 
34
 
@@ -123,9 +122,10 @@ def generate(
123
  prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
124
 
125
  inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
126
- output = model.generate(inputs, max_length=100, temperature=0.2)
 
127
 
128
- return output
129
 
130
 
131
  examples = [
 
14
  #API_URL = "https://api-inference.huggingface.co/models/bigcode/starcoder"
15
  #API_URL_BASE ="https://api-inference.huggingface.co/models/bigcode/starcoderbase"
16
  #API_URL_PLUS = "https://api-inference.huggingface.co/models/bigcode/starcoderplus"
 
17
 
18
  from transformers import AutoModelForCausalLM, AutoTokenizer
19
 
 
27
 
28
  inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
29
  outputs = model.generate(inputs, max_length=100, temperature=0.2)
30
+ #print("-"*80)
31
  print(tokenizer.decode(outputs[0]))
32
 
33
 
 
122
  prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
123
 
124
  inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
125
+ outputs = model.generate(inputs, max_length=100, temperature=0.2)
126
+ final = tokenizer.decode(outputs[0])
127
 
128
+ return final
129
 
130
 
131
  examples = [