jmercat committed
Commit
5ae38c5
1 Parent(s): 0a1de44

first commit

Files changed (2)
  1. app.py +144 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,144 @@
+ import gradio as gr
+ from threading import Thread
+
+ import torch
+ from open_lm.hf import *  # registers the open_lm architecture with transformers' Auto* classes
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+
+ # Define model options
+ MODEL_OPTIONS = {
+     "TRI-ML/DCLM-1B": "TRI-ML/DCLM-1B",
+     "Apple DCLM-Baseline-7B": "apple/DCLM-Baseline-7B",
+ }
+
+ # Global variables for the currently loaded model and tokenizer
+ current_model = None
+ current_tokenizer = None
+
+ def load_model(model_name):
+     global current_model, current_tokenizer
+     current_tokenizer = AutoTokenizer.from_pretrained(MODEL_OPTIONS[model_name])
+     current_model = AutoModelForCausalLM.from_pretrained(MODEL_OPTIONS[model_name])
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     current_model = current_model.to(device)
+     return f"Loaded model: {model_name}"
+
+ def generate(
+     prompt, model_choice, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+ ):
+     global current_model, current_tokenizer
+
+     if current_model is None or current_tokenizer is None:
+         # This function is a generator, so errors must be yielded, not returned.
+         yield "Please load a model first."
+         return
+
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2  # clamp to keep sampling numerically stable
+     top_p = float(top_p)
+
+     inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
+
+     generate_kwargs = dict(
+         **inputs,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         pad_token_id=current_tokenizer.eos_token_id,
+     )
+
+     # Stream decoded tokens as they are produced instead of waiting for the full generation.
+     streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs["streamer"] = streamer
+
+     # generate() blocks, so run it in a background thread while this generator consumes the streamer.
+     thread = Thread(target=current_model.generate, kwargs=generate_kwargs)
+     thread.start()
+
+     # Echo the prompt in blue, then append the completion as it streams in
+     output = "<span style='color: blue;'>" + prompt + "</span>"
+     for new_text in streamer:
+         if isinstance(new_text, torch.Tensor):
+             new_text = current_tokenizer.decode(new_text)
+         output += new_text
+         yield output
+
+     thread.join()
+
+ additional_inputs = [
+     gr.Slider(
+         label="Temperature",
+         value=0.9,
+         minimum=0.0,
+         maximum=1.0,
+         step=0.05,
+         interactive=True,
+         info="Higher values produce more diverse outputs",
+     ),
+     gr.Slider(
+         label="Max new tokens",
+         value=256,
+         minimum=0,
+         maximum=1024,
+         step=64,
+         interactive=True,
+         info="The maximum number of new tokens to generate",
+     ),
+     gr.Slider(
+         label="Top-p (nucleus sampling)",
+         value=0.90,
+         minimum=0.0,
+         maximum=1.0,
+         step=0.05,
+         interactive=True,
+         info="Higher values sample more low-probability tokens",
+     ),
+     gr.Slider(
+         label="Repetition penalty",
+         value=1.2,
+         minimum=1.0,
+         maximum=2.0,
+         step=0.05,
+         interactive=True,
+         info="Penalize repeated tokens",
+     ),
+ ]
+
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         """
+         # DCLM Text Completion Demo
+         This demo generates text with a DCLM model.
+         These models are trained to predict the next word in a sequence of text; they produce text completions and are not chatbots.
+
+         First select a model from the dropdown; it is loaded as soon as it is selected.
+         Then enter some text in the text box and click "Generate" to see the model's completion.
+         """
+     )
+
+     with gr.Row():
+         model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model")
+         model_status = gr.Textbox(label="Model Status")
+
+     # Loading happens on selection; the status box confirms which checkpoint is active.
+     model_dropdown.select(
+         load_model,
+         inputs=[model_dropdown],
+         outputs=[model_status],
+     )
+
+     text_input = gr.Textbox(lines=3, label="Input Text")
+     text_output = gr.HTML(label="Generated Text")
+
+     generate_button = gr.Button("Generate")
+
+     generate_button.click(
+         generate,
+         inputs=[text_input, model_dropdown, *additional_inputs],
+         outputs=[text_output],
+     )
+
+     # The sliders above were created outside the Blocks context, so render them here.
+     with gr.Accordion(label="Advanced Options", open=False):
+         for input_component in additional_inputs:
+             if not input_component.is_rendered:
+                 input_component.render()
+
+ demo.launch()
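For reference, the same checkpoints can also be queried without the Gradio UI. The sketch below mirrors the loading and generation calls in app.py in a non-streaming form; the prompt and sampling values are illustrative, and it assumes the same dependencies are installed:

import torch
from open_lm.hf import *  # as in app.py: registers the DCLM/open_lm architecture with transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("TRI-ML/DCLM-1B")
model = AutoModelForCausalLM.from_pretrained("TRI-ML/DCLM-1B")
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative completion prompt; any text works since these are base models.
inputs = tokenizer("Machine learning is", return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    temperature=0.9,
    top_p=0.95,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))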
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ git+https://github.com/mlfoundations/open_lm.git
+ gradio
+ transformers
+ torch
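To try the Space locally, the standard workflow should suffice: install the dependencies with pip install -r requirements.txt (the open_lm entry is pulled straight from GitHub, so git must be available), then start the demo with python app.py; demo.launch() serves the interface on a local port.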