File size: 1,268 Bytes
b5c7703
 
7c9d0ce
b5c7703
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae0b701
b5c7703
 
 
 
 
 
 
 
 
 
 
 
 
ae0b701
b5c7703
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from gradio import Interface, components
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub repository that provides both the tokenizer and the weights.
_MODEL_REPO = "raincandy-u/TinyStories-656K"

# Load the tokenizer and the causal-LM model once, at module import time, so
# every call to generate_story reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO)
model = AutoModelForCausalLM.from_pretrained(_MODEL_REPO)

# Application callback invoked by the Gradio interface.
def generate_story(input_text):
    """Continue a story opening with the TinyStories model.

    The user's text is prefixed with the ``<|start_story|>`` marker, tokenized
    into PyTorch tensors, and then passed to ``model.generate``. Generation
    samples up to 512 new tokens with top-k 40, top-p 0.9 and temperature 0.65.

    Args:
        input_text: The beginning of the story, as entered by the user.

    Returns:
        The first generated sequence decoded back to text with special tokens
        skipped. For decoder-only models ``generate`` normally returns the
        prompt followed by the continuation, so the prompt text is presumably
        included in the result. Whether the ``<|start_story|>`` marker is also
        removed depends on whether the tokenizer registers it as a special
        token (TODO confirm).
    """
    prompt = f"<|start_story|>{input_text}"
    model_inputs = tokenizer(prompt, return_tensors="pt")

    # Stochastic decoding settings.
    sampling_kwargs = dict(
        max_new_tokens=512,
        do_sample=True,
        top_k=40,
        top_p=0.9,
        temperature=0.65,
    )

    # The EOS id doubles as the padding id to silence the missing-pad warning.
    generated = model.generate(
        **model_inputs,
        pad_token_id=tokenizer.eos_token_id,
        **sampling_kwargs,
    )

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

# Input widget: a 10-line textbox where the user types the story's opening.
# (A previously defined `components.Label` duplicated the description text and
# was never attached to the interface, so it has been removed.)
input_component = components.Textbox(lines=10)

# Build the web UI around generate_story. The examples give users clickable
# story openings to start from.
interface = Interface(
    fn=generate_story,
    inputs=input_component,
    outputs="textbox",
    title="TinyStories-656K",
    description="Try it!\nNote: Most of the time the default beginning works well.",
    examples=[['Once upon a time, there was a girl '], ['Long time ago, ']],
    theme="gradio/light"
)

# Start the web server only when this file is executed as a script, so that
# merely importing the module does not launch the app.
if __name__ == "__main__":
    interface.launch()