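# Hugging Face Space app: a minimal Gradio demo that runs google/gemma-2b
# greedy text generation on a TPU via optimum-tpu and PyTorch/XLA,
# using a StaticCache so tensor shapes stay fixed across decode steps.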
import os

# Select the TPU backend for PyTorch/XLA before torch_xla initializes a device.
os.environ["PJRT_DEVICE"] = "TPU"

import gradio as gr
import torch
import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer, StaticCache
from optimum.tpu.modeling import AutoModelForCausalLM

def sample_greedy(logits):
    next_logits = logits[:, -1]
    next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
    return next_token_id

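# Single decode step: run the model on the newest token only, reusing the
# pre-allocated StaticCache, and greedily pick the next token id.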
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
        past_key_values=past_key_values,
    )[0]
    new_token = sample_greedy(logits)
    return new_token

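# Opt-in torch.compile: only compile when DBG_COMPILE is set in the
# environment, using the openxla backend provided by torch_xla.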
def conditional_compile(func):
    if "DBG_COMPILE" in os.environ:
        compiled = torch.compile(func, backend="openxla")
        return compiled
    return func

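# Load gemma-2b in bfloat16; optimum-tpu's AutoModelForCausalLM places the
# weights on the XLA (TPU) device, so model.device is reused for all tensors
# created below.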
model_id = "google/gemma-2b"
torch_dtype = torch.bfloat16

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
device = model.device
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

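# Gradio callback: tokenize the prompt, prefill a static KV cache in one
# forward pass, then generate max_new_tokens tokens one at a time, calling
# xm.mark_step() after each device-side update so XLA executes the step.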
def summarize(inp, model=model, tokenizer=tokenizer, device=device):
    with torch.no_grad():
        # Flatten newlines so the prompt is passed as a single line of text.
        inp = inp.replace('\n', ' ')
        inputs = tokenizer(inp, return_tensors="pt", padding=True).to(device)
        batch_size, sequence_length = inputs["input_ids"].shape
        max_cache_length = 1024
        max_new_tokens = 64
        # Set up a static KV cache with a fixed length so shapes stay constant.
        past_key_values = StaticCache(
            config=model.config,
            max_batch_size=batch_size,
            max_cache_len=max_cache_length,
            device=model.device,
            dtype=model.dtype,
        )
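        # cache_position tracks which cache slots the prompt occupies; the output
        # buffer holds the prompt, the prefill token, and max_new_tokens tokens.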
        cache_position = torch.arange(sequence_length, device=device)
        generated_ids = torch.zeros(
            (batch_size, sequence_length + max_new_tokens + 1),
            dtype=torch.int,
            device=device,
        )
        generated_ids[:, cache_position] = inputs["input_ids"].to(torch.int)
        # Prefill: run the whole prompt through the model once to populate the
        # KV cache, then sample the first generated token.
        attention_mask = inputs["attention_mask"]
        pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
        logits = model(
            **inputs,
            cache_position=cache_position,
            return_dict=False,
            use_cache=True,
            position_ids=pos_ids,
            past_key_values=past_key_values,
        )[0]
        next_token = sample_greedy(logits)
        xm.mark_step()
        generated_ids[:, sequence_length] = next_token[:, 0]
        # Next position id is one past the last non-padding position of each row.
        pos_ids = pos_ids.max(dim=-1)[0].unsqueeze(1) + 1
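        # Decode loop: optionally compile the model, then generate one token per
        # iteration, writing it into the buffer and advancing cache/position ids.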
        model = conditional_compile(model)
        cache_position = torch.tensor([sequence_length], device=device)
        for _ in range(max_new_tokens):
            next_token = decode_one_tokens(
                model, next_token.clone(), pos_ids, cache_position, past_key_values
            )
            cache_position += 1
            generated_ids[:, cache_position] = next_token
            pos_ids += 1
            xm.mark_step()
        decoded_texts = tokenizer.batch_decode(generated_ids)
        response = " ".join(decoded_texts)
    return response

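# Quick local sanity check (hypothetical usage, assuming a TPU VM with
# optimum-tpu, torch_xla and gradio installed); bypasses the web UI:
#   print(summarize("The tallest mountain on Earth is"))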
gr.Interface(
    fn=summarize,
    inputs=gr.Textbox(lines=7, label="Input Text"),
    outputs="text",
    title="gemma-2b simple TPU demo",
).launch(inline=False)