Delete main.rs
main.rs
DELETED
@@ -1,252 +0,0 @@
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
-
-#[cfg(feature = "accelerate")]
-extern crate accelerate_src;
-use std::io::Write;
-use std::path::PathBuf;
-
-use actix_web::{post, web, App, HttpResponse, HttpServer, Responder};
-use serde::{Deserialize, Serialize};
-
-use candle_transformers::models::quantized_t5 as t5;
-
-use anyhow::{Error as E, Result};
-use candle_core::{Device, Tensor};
-use candle_transformers::generation::LogitsProcessor;
-use clap::{Parser, ValueEnum};
-use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType};
-use tokenizers::Tokenizer;
-
-#[derive(Clone, Debug, Copy, ValueEnum)]
-enum Which {
-    T5Small,
-    FlanT5Small,
-    FlanT5Base,
-    FlanT5Large,
-    FlanT5Xl,
-    FlanT5Xxl,
-}
-
-#[derive(Parser, Debug, Clone)]
-#[command(author, version, about, long_about = None)]
-
-struct Args {
-    /// Enable tracing (generates a trace-timestamp.json file).
-    #[arg(long)]
-    tracing: bool,
-
-    /// The model repository to use on the HuggingFace hub.
-    #[arg(long)]
-    model_id: Option<String>,
-
-    #[arg(long)]
-    revision: Option<String>,
-
-    #[arg(long)]
-    weight_file: Option<String>,
-
-    #[arg(long)]
-    config_file: Option<String>,
-
-    // Enable/disable decoding.
-    #[arg(long, default_value = "false")]
-    disable_cache: bool,
-
-    /// Use this prompt, otherwise compute sentence similarities.
-    // #[arg(long)]
-    // prompt: Option<String>,
-
-    /// The temperature used to generate samples.
-    #[arg(long, default_value_t = 0.8)]
-    temperature: f64,
-
-    /// Nucleus sampling probability cutoff.
-    #[arg(long)]
-    top_p: Option<f64>,
-
-    /// Penalty to be applied for repeating tokens, 1. means no penalty.
-    #[arg(long, default_value_t = 1.1)]
-    repeat_penalty: f32,
-
-    /// The context size to consider for the repeat penalty.
-    #[arg(long, default_value_t = 64)]
-    repeat_last_n: usize,
-
-    /// The model size to use.
-    #[arg(long, default_value = "flan-t5-xl")]
-    which: Which,
-}
-
-struct T5ModelBuilder {
-    device: Device,
-    config: t5::Config,
-    weights_filename: PathBuf,
-}
-
-impl T5ModelBuilder {
-    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
-        let device = Device::Cpu;
-        let default_model = "deepfile/flan-t5-xl-gguf".to_string();
-        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
-            (Some(model_id), Some(revision)) => (model_id, revision),
-            (Some(model_id), None) => (model_id, "main".to_string()),
-            (None, Some(revision)) => (default_model, revision),
-            (None, None) => (default_model, "main".to_string()),
-        };
-
-        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let api = Api::new()?;
-        let api = api.repo(repo);
-        let config_filename = match &args.config_file {
-            Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
-            None => match args.which {
-                Which::T5Small => api.get("config.json")?,
-                Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
-                Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
-                Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
-                Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
-                Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
-            },
-        };
-        let tokenizer_filename = api.get("tokenizer.json")?;
-        let weights_filename = match &args.weight_file {
-            Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
-            None => match args.which {
-                Which::T5Small => api.get("model.gguf")?,
-                Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
-                Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
-                Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
-                Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
-                Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
-            },
-        };
-
-        let config = std::fs::read_to_string(config_filename)?;
-        let mut config: t5::Config = serde_json::from_str(&config)?;
-        config.use_cache = !args.disable_cache;
-        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-        Ok((
-            Self {
-                device,
-                config,
-                weights_filename,
-            },
-            tokenizer,
-        ))
-    }
-
-    pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
-        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
-        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
-    }
-
-    fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result<PathBuf> {
-        let local_filename = std::path::PathBuf::from(filename);
-        if local_filename.exists() {
-            Ok(local_filename)
-        } else {
-            Ok(api.get(filename)?)
-        }
-    }
-}
-fn generate_answer(_prompt: String, args: &Args) -> Result<String> {
-
-    let mut generated_text = String::new();
-
-    let (_builder, mut _tokenizer) = T5ModelBuilder::load(&args)?;
-    let device = &_builder.device;
-    let _tokenizer = _tokenizer
-        .with_padding(None)
-        .with_truncation(None)
-        .map_err(E::msg)?;
-    let _tokens = _tokenizer
-        .encode(_prompt, true)
-        .map_err(E::msg)?
-        .get_ids()
-        .to_vec();
-    let input_token_ids = Tensor::new(&_tokens[..], device)?.unsqueeze(0)?;
-    let mut model = _builder.build_model()?;
-    let mut output_token_ids = [_builder.config.pad_token_id as u32].to_vec();
-    let temperature = 0.8f64;
-
-    let mut logits_processor = LogitsProcessor::new(299792458, Some(temperature), None);
-    let encoder_output = model.encode(&input_token_ids)?;
-
-    let start = std::time::Instant::now();
-
-    for index in 0.. {
-
-        if output_token_ids.len() > 512 {
-            break;
-        }
-        let decoder_token_ids = if index == 0 || !_builder.config.use_cache {
-            Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
-        } else {
-            let last_token = *output_token_ids.last().unwrap();
-            Tensor::new(&[last_token], device)?.unsqueeze(0)?
-        };
-        let logits = model
-            .decode(&decoder_token_ids, &encoder_output)?
-            .squeeze(0)?;
-        let logits = if args.repeat_penalty == 1. {
-            logits
-        } else {
-            let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
-            candle_transformers::utils::apply_repeat_penalty(
-                &logits,
-                args.repeat_penalty,
-                &output_token_ids[start_at..],
-            )?
-        };
-
-        let next_token_id = logits_processor.sample(&logits)?;
-        if next_token_id as usize == _builder.config.eos_token_id {
-            break;
-        }
-        output_token_ids.push(next_token_id);
-        if let Some(text) = _tokenizer.id_to_token(next_token_id) {
-            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-            generated_text.push_str(&text);
-            print!("{}", text);
-            std::io::stdout().flush()?;
-        }
-    }
-    let dt = start.elapsed();
-    println!(
-        "\n{} tokens generated ({:.2} token/s)\n",
-        output_token_ids.len(),
-        output_token_ids.len() as f64 / dt.as_secs_f64(),
-    );
-
-    Ok(generated_text)
-}
-
-// request struct
-#[derive(Deserialize)]
-struct Request {
-    prompt: String,
-}
-
-#[derive(Serialize)]
-struct Response {
-    answer: String,
-}
-
-#[post("/generate")]
-async fn generate(req_body: web::Json<Request>) -> impl Responder {
-    let args = Args::parse();
-    let generated_answer = generate_answer(req_body.prompt.clone(), &args);
-    HttpResponse::Ok().json(Response {
-        answer: generated_answer.unwrap(),
-    })
-}
-
-#[actix_web::main]
-async fn main() -> std::io::Result<()> {
-    println!("Starting server at: http://localhost:7000");
-    HttpServer::new(|| App::new().service(generate))
-        .bind("localhost:7000")?
-        .run()
-        .await
-}
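For reference, the file removed here wrapped candle's quantized T5 example in a small Actix-web server: a single POST /generate endpoint accepted a JSON body of the form {"prompt": "..."} and replied with {"answer": "..."}. A minimal client sketch follows; it assumes the deleted server is running on localhost:7000 as in its main(), and that the reqwest (with the "json" feature), tokio, and serde_json crates are available. None of these client-side dependencies are part of this repository.

// Hypothetical client sketch; not part of this repository.
// Assumes the deleted server is listening on localhost:7000 and that
// reqwest (with the "json" feature), tokio, and serde_json are in Cargo.toml.
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    // POST a prompt to the /generate endpoint defined in the deleted main.rs.
    let resp = client
        .post("http://localhost:7000/generate")
        .json(&json!({ "prompt": "Translate to German: Hello, world!" }))
        .send()
        .await?;
    // The deleted handler returned a JSON body of the form {"answer": "..."}.
    let body: serde_json::Value = resp.json().await?;
    println!("answer: {}", body["answer"]);
    Ok(())
}

One detail of the deleted handler worth noting: generate_answer() re-parses the CLI arguments, downloads and loads the model, and decodes synchronously inside an async Actix handler, and its result is unwrap()ed, so a generation failure panics the worker. A common way to keep the async executor responsive is to offload the synchronous work with web::block and map failures to an error response. The following is only a sketch of that pattern against the deleted types (Request, Response, Args, generate_answer), not code from this repository.

// Sketch of a non-blocking variant of the deleted /generate handler.
// Reuses the Request/Response/Args types and generate_answer() from main.rs;
// assumes actix-web 4, where web::block runs a closure on a blocking thread pool.
#[post("/generate")]
async fn generate(req_body: web::Json<Request>) -> impl Responder {
    let args = Args::parse();
    let prompt = req_body.prompt.clone();
    // Run the synchronous model load and decoding off the async executor.
    let generated = web::block(move || generate_answer(prompt, &args)).await;
    match generated {
        Ok(Ok(answer)) => HttpResponse::Ok().json(Response { answer }),
        // Either the blocking task or the generation itself failed;
        // return a 500 instead of panicking like the original unwrap().
        _ => HttpResponse::InternalServerError().finish(),
    }
}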