# Hugging Face Space: Mamba selective-state-space-model text-generation demo.
# Third-party imports for the Mamba text-generation demo.
# NOTE(review): F, rearrange and AutoModelForCausalLM are unused in the code
# visible here; kept in case other parts of the project rely on this module.
import torch
import torch.nn.functional as F
import gradio as gr
from einops import rearrange
from transformers import AutoTokenizer, AutoModelForCausalLM

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
# Runtime configuration: CUDA device, the GPT-NeoX tokenizer (the tokenizer
# the Mamba checkpoints were trained with), and the 2.8B SlimPajama Mamba
# model loaded in fp16 to fit comfortably on a single GPU.
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-2.8b-slimpj", device=device, dtype=torch.float16
)
# Number of new tokens to generate per request.
genlen = 500
def pred(text_in):
    """Generate a continuation of *text_in* with the Mamba model.

    Parameters
    ----------
    text_in : str
        Prompt text entered in the Gradio textbox.

    Returns
    -------
    str
        Decoded output for the single batch item: the prompt followed by
        up to ``genlen`` sampled tokens.
    """
    tokens = tokenizer(text_in, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    # Total sequence length budget = prompt length + generation length.
    max_length = input_ids.shape[1] + genlen
    # Call generate directly (the original wrapped this in a no-op lambda)
    # and drop the attention mask the original built but never used:
    # MambaLMHeadModel.generate does not take an attention_mask argument.
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,  # CUDA-graph-cached decoding for faster generation
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=0.9,
        top_p=0.7,
    )
    text_out = tokenizer.batch_decode(out.sequences.tolist())
    # Batch size is 1, so return the single decoded string.
    return text_out[0]
# Wire the prediction function into a minimal text-in / text-out Gradio UI.
demo = gr.Interface(
    fn=pred,
    inputs="text",
    outputs="text",
    title="Mamba: Selective State Space Model",
    description="A demo for [Mamba](https://github.com/state-spaces/mamba) by Albert & Tri.",
)

if __name__ == "__main__":
    demo.launch()