bayang commited on
Commit
a62520a
1 Parent(s): 932850e

Upload main.rs

Browse files
Files changed (1) hide show
  1. main.rs +252 -0
main.rs ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #[cfg(feature = "mkl")]
2
+ extern crate intel_mkl_src;
3
+
4
+ #[cfg(feature = "accelerate")]
5
+ extern crate accelerate_src;
6
+ use std::io::Write;
7
+ use std::path::PathBuf;
8
+
9
+ use actix_web::{post, web, App, HttpResponse, HttpServer, Responder};
10
+ use serde::{Deserialize, Serialize};
11
+
12
+ use candle_transformers::models::quantized_t5 as t5;
13
+
14
+ use anyhow::{Error as E, Result};
15
+ use candle_core::{Device, Tensor};
16
+ use candle_transformers::generation::LogitsProcessor;
17
+ use clap::{Parser, ValueEnum};
18
+ use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType};
19
+ use tokenizers::Tokenizer;
20
+
21
+ #[derive(Clone, Debug, Copy, ValueEnum)]
22
+ enum Which {
23
+ T5Small,
24
+ FlanT5Small,
25
+ FlanT5Base,
26
+ FlanT5Large,
27
+ FlanT5Xl,
28
+ FlanT5Xxl,
29
+ }
30
+
31
+ #[derive(Parser, Debug, Clone)]
32
+ #[command(author, version, about, long_about = None)]
33
+
34
+ struct Args {
35
+ /// Enable tracing (generates a trace-timestamp.json file).
36
+ #[arg(long)]
37
+ tracing: bool,
38
+
39
+ /// The model repository to use on the HuggingFace hub.
40
+ #[arg(long)]
41
+ model_id: Option<String>,
42
+
43
+ #[arg(long)]
44
+ revision: Option<String>,
45
+
46
+ #[arg(long)]
47
+ weight_file: Option<String>,
48
+
49
+ #[arg(long)]
50
+ config_file: Option<String>,
51
+
52
+ // Enable/disable decoding.
53
+ #[arg(long, default_value = "false")]
54
+ disable_cache: bool,
55
+
56
+ /// Use this prompt, otherwise compute sentence similarities.
57
+ // #[arg(long)]
58
+ // prompt: Option<String>,
59
+
60
+ /// The temperature used to generate samples.
61
+ #[arg(long, default_value_t = 0.8)]
62
+ temperature: f64,
63
+
64
+ /// Nucleus sampling probability cutoff.
65
+ #[arg(long)]
66
+ top_p: Option<f64>,
67
+
68
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
69
+ #[arg(long, default_value_t = 1.1)]
70
+ repeat_penalty: f32,
71
+
72
+ /// The context size to consider for the repeat penalty.
73
+ #[arg(long, default_value_t = 64)]
74
+ repeat_last_n: usize,
75
+
76
+ /// The model size to use.
77
+ #[arg(long, default_value = "flan-t5-xl")]
78
+ which: Which,
79
+ }
80
+
81
+ struct T5ModelBuilder {
82
+ device: Device,
83
+ config: t5::Config,
84
+ weights_filename: PathBuf,
85
+ }
86
+
87
+ impl T5ModelBuilder {
88
+ pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
89
+ let device = Device::Cpu;
90
+ let default_model = "deepfile/flan-t5-xl-gguf".to_string();
91
+ let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
92
+ (Some(model_id), Some(revision)) => (model_id, revision),
93
+ (Some(model_id), None) => (model_id, "main".to_string()),
94
+ (None, Some(revision)) => (default_model, revision),
95
+ (None, None) => (default_model, "main".to_string()),
96
+ };
97
+
98
+ let repo = Repo::with_revision(model_id, RepoType::Model, revision);
99
+ let api = Api::new()?;
100
+ let api = api.repo(repo);
101
+ let config_filename = match &args.config_file {
102
+ Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
103
+ None => match args.which {
104
+ Which::T5Small => api.get("config.json")?,
105
+ Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
106
+ Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
107
+ Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
108
+ Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
109
+ Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
110
+ },
111
+ };
112
+ let tokenizer_filename = api.get("tokenizer.json")?;
113
+ let weights_filename = match &args.weight_file {
114
+ Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
115
+ None => match args.which {
116
+ Which::T5Small => api.get("model.gguf")?,
117
+ Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
118
+ Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
119
+ Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
120
+ Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
121
+ Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
122
+ },
123
+ };
124
+
125
+ let config = std::fs::read_to_string(config_filename)?;
126
+ let mut config: t5::Config = serde_json::from_str(&config)?;
127
+ config.use_cache = !args.disable_cache;
128
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
129
+ Ok((
130
+ Self {
131
+ device,
132
+ config,
133
+ weights_filename,
134
+ },
135
+ tokenizer,
136
+ ))
137
+ }
138
+
139
+ pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
140
+ let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
141
+ Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
142
+ }
143
+
144
+ fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result<PathBuf> {
145
+ let local_filename = std::path::PathBuf::from(filename);
146
+ if local_filename.exists() {
147
+ Ok(local_filename)
148
+ } else {
149
+ Ok(api.get(filename)?)
150
+ }
151
+ }
152
+ }
153
+ fn generate_answer(_prompt: String, args: &Args) -> Result<String> {
154
+
155
+ let mut generated_text = String::new();
156
+
157
+ let (_builder, mut _tokenizer) = T5ModelBuilder::load(&args)?;
158
+ let device = &_builder.device;
159
+ let _tokenizer = _tokenizer
160
+ .with_padding(None)
161
+ .with_truncation(None)
162
+ .map_err(E::msg)?;
163
+ let _tokens = _tokenizer
164
+ .encode(_prompt, true)
165
+ .map_err(E::msg)?
166
+ .get_ids()
167
+ .to_vec();
168
+ let input_token_ids = Tensor::new(&_tokens[..], device)?.unsqueeze(0)?;
169
+ let mut model = _builder.build_model()?;
170
+ let mut output_token_ids = [_builder.config.pad_token_id as u32].to_vec();
171
+ let temperature = 0.8f64;
172
+
173
+ let mut logits_processor = LogitsProcessor::new(299792458, Some(temperature), None);
174
+ let encoder_output = model.encode(&input_token_ids)?;
175
+
176
+ let start = std::time::Instant::now();
177
+
178
+ for index in 0.. {
179
+
180
+ if output_token_ids.len() > 512 {
181
+ break;
182
+ }
183
+ let decoder_token_ids = if index == 0 || !_builder.config.use_cache {
184
+ Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
185
+ } else {
186
+ let last_token = *output_token_ids.last().unwrap();
187
+ Tensor::new(&[last_token], device)?.unsqueeze(0)?
188
+ };
189
+ let logits = model
190
+ .decode(&decoder_token_ids, &encoder_output)?
191
+ .squeeze(0)?;
192
+ let logits = if args.repeat_penalty == 1. {
193
+ logits
194
+ } else {
195
+ let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
196
+ candle_transformers::utils::apply_repeat_penalty(
197
+ &logits,
198
+ args.repeat_penalty,
199
+ &output_token_ids[start_at..],
200
+ )?
201
+ };
202
+
203
+ let next_token_id = logits_processor.sample(&logits)?;
204
+ if next_token_id as usize == _builder.config.eos_token_id {
205
+ break;
206
+ }
207
+ output_token_ids.push(next_token_id);
208
+ if let Some(text) = _tokenizer.id_to_token(next_token_id) {
209
+ let text = text.replace('▁', " ").replace("<0x0A>", "\n");
210
+ generated_text.push_str(&text);
211
+ print!("{}", text);
212
+ std::io::stdout().flush()?;
213
+ }
214
+ }
215
+ let dt = start.elapsed();
216
+ println!(
217
+ "\n{} tokens generated ({:.2} token/s)\n",
218
+ output_token_ids.len(),
219
+ output_token_ids.len() as f64 / dt.as_secs_f64(),
220
+ );
221
+
222
+ Ok(generated_text)
223
+ }
224
+
225
+ // request struct
226
+ #[derive(Deserialize)]
227
+ struct Request {
228
+ prompt: String,
229
+ }
230
+
231
+ #[derive(Serialize)]
232
+ struct Response {
233
+ answer: String,
234
+ }
235
+
236
+ #[post("/generate")]
237
+ async fn generate(req_body: web::Json<Request>) -> impl Responder {
238
+ let args = Args::parse();
239
+ let generated_answer = generate_answer(req_body.prompt.clone(), &args);
240
+ HttpResponse::Ok().json(Response {
241
+ answer: generated_answer.unwrap(),
242
+ })
243
+ }
244
+
245
+ #[actix_web::main]
246
+ async fn main() -> std::io::Result<()> {
247
+ println!("Starting server at: http://localhost:7000");
248
+ HttpServer::new(|| App::new().service(generate))
249
+ .bind("localhost:7000")?
250
+ .run()
251
+ .await
252
+ }