# google-tpu / app.py
import os

# Select the TPU backend for the PJRT runtime before torch_xla is imported.
os.environ["PJRT_DEVICE"] = "TPU"

import gradio as gr
import torch
import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer, StaticCache
from optimum.tpu.modeling import AutoModelForCausalLM
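
# Greedy sampling: take the argmax of the logits at the last position and
# return it as an int token id of shape (batch_size, 1).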
def sample_greedy(logits):
    next_logits = logits[:, -1]
    next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
    return next_token_id
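
# Single decoding step: run the model on the latest token only, reusing the
# static KV cache, and greedily pick the next token.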
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
        past_key_values=past_key_values,
    )[0]
    new_token = sample_greedy(logits)
    return new_token
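
# Optionally wrap the model in torch.compile with the OpenXLA backend.
# Compilation is only enabled when the DBG_COMPILE environment variable is set.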
def conditional_compile(func):
    if "DBG_COMPILE" in os.environ:
        compiled = torch.compile(func, backend="openxla")
        return compiled
    return func
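
# Load gemma-2b through optimum-tpu in bfloat16 and switch it to eval mode;
# model.device is used below to place the tokenized inputs on the same
# (XLA/TPU) device as the model.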
model_id = "google/gemma-2b"
torch_dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
device = model.device
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
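
# Greedy generation for the Gradio demo: prefill the prompt once, then decode
# up to max_new_tokens tokens one at a time. A StaticCache keeps tensor shapes
# fixed across steps, which avoids repeated XLA recompilation.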
def summarize(inp, model=model, tokenizer=tokenizer, device=device):
    with torch.no_grad():
        inp = inp.replace('\n', '')
        inputs = tokenizer(inp, return_tensors="pt", padding=True).to(device)
        batch_size, sequence_length = inputs["input_ids"].shape
        max_cache_length = 1024
        max_new_tokens = 64
        # Set up a static (fixed-size) KV cache for the whole generation.
        past_key_values = StaticCache(
            config=model.config,
            max_batch_size=batch_size,
            max_cache_len=max_cache_length,
            device=model.device,
            dtype=model.dtype,
        )
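
        # Pre-allocate the output buffer: prompt tokens followed by slots for
        # the generated tokens (one extra slot for the token from the prefill).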
        cache_position = torch.arange(sequence_length, device=device)
        generated_ids = torch.zeros(
            (batch_size, sequence_length + max_new_tokens + 1),
            dtype=torch.int,
            device=device,
        )
        generated_ids[:, cache_position] = inputs["input_ids"].to(torch.int)
        # Prefill: run the full prompt through the model once to populate the
        # KV cache and produce the first generated token.
        attention_mask = inputs["attention_mask"]
        pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
        logits = model(
            **inputs,
            cache_position=cache_position,
            return_dict=False,
            use_cache=True,
            position_ids=pos_ids,
            past_key_values=past_key_values,
        )[0]
        next_token = sample_greedy(logits)
        xm.mark_step()
        generated_ids[:, sequence_length] = next_token[:, 0]
        pos_ids = pos_ids.max(dim=-1)[0].unsqueeze(1) + 1
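
        # Decoding loop: generate max_new_tokens tokens one at a time.
        # xm.mark_step() after each step tells XLA to execute the lazily
        # recorded graph, so one graph runs per generated token.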
        model = conditional_compile(model)
        cache_position = torch.tensor([sequence_length], device=device)
        for i in range(max_new_tokens):
            next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position, past_key_values)
            cache_position += 1
            generated_ids[:, cache_position] = next_token
            pos_ids += 1
            xm.mark_step()
        decoded_texts = tokenizer.batch_decode(generated_ids)
        response = " ".join(decoded_texts)
    return response
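
# Minimal Gradio UI: a single text box whose contents are passed to summarize().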
gr.Interface(fn=summarize, inputs=gr.Textbox(lines=7, label="Input Text"), outputs="text", title="gemma-2b simple TPU demo").launch(inline=False)
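
# Example of querying the running demo programmatically (a sketch, assuming the
# Space is served locally on Gradio's default port and exposes the default
# "/predict" endpoint of a single-function Interface):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   print(client.predict("Some long text to summarize.", api_name="/predict"))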