florentgbelidji HF staff commited on
Commit
b810035
1 Parent(s): 421bc1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -39,17 +39,17 @@ def conditional_compile(func):
39
  compiled = torch.compile(func, backend="openxla")
40
  return compiled
41
  return func
42
-
43
 
 
 
 
 
 
 
44
 
45
 
46
- def summarize(inp):
47
- model_id = "google/gemma-2b"
48
- torch_dtype = torch.bfloat16
49
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
50
- device = model.device
51
- model = model.eval()
52
- tokenizer = AutoTokenizer.from_pretrained(model_id)
53
  with torch.no_grad():
54
  inp = inp.replace('\n','')
55
  inputs = tokenizer(inp, return_tensors="pt", padding=True).to(device)
@@ -106,3 +106,4 @@ def summarize(inp):
106
  return response
107
 
108
  gr.Interface(fn=summarize, inputs=gr.Textbox(lines=7, label="Input Text"), outputs="text", title="gemma-2b Demo").launch(inline=False)
 
 
39
  compiled = torch.compile(func, backend="openxla")
40
  return compiled
41
  return func
 
42
 
43
+ model_id = "google/gemma-2b"
44
+ torch_dtype = torch.bfloat16
45
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
46
+ device = model.device
47
+ model = model.eval()
48
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
49
 
50
 
51
+
52
+ def summarize(inp, model=model, tokenizer=tokenizer, device=device):
 
 
 
 
 
53
  with torch.no_grad():
54
  inp = inp.replace('\n','')
55
  inputs = tokenizer(inp, return_tensors="pt", padding=True).to(device)
 
106
  return response
107
 
108
  gr.Interface(fn=summarize, inputs=gr.Textbox(lines=7, label="Input Text"), outputs="text", title="gemma-2b Demo").launch(inline=False)
109
+