Text2Text Generation
GGUF
bayang committed on
Commit ae3f19e
1 Parent(s): a62520a

Delete main.rs

Files changed (1)
  1. main.rs +0 -252
main.rs DELETED
@@ -1,252 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use std::io::Write;
use std::path::PathBuf;

use actix_web::{post, web, App, HttpResponse, HttpServer, Responder};
use serde::{Deserialize, Serialize};

use candle_transformers::models::quantized_t5 as t5;

use anyhow::{Error as E, Result};
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType};
use tokenizers::Tokenizer;

/// The T5 model variants that can be served.
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    T5Small,
    FlanT5Small,
    FlanT5Base,
    FlanT5Large,
    FlanT5Xl,
    FlanT5Xxl,
}

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model repository to use on the HuggingFace hub.
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    /// Disable the key-value cache used during decoding.
    #[arg(long, default_value = "false")]
    disable_cache: bool,

    // The prompt is supplied per request through the HTTP API, not as a flag.
    // prompt: Option<String>,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "flan-t5-xl")]
    which: Which,
}

struct T5ModelBuilder {
    device: Device,
    config: t5::Config,
    weights_filename: PathBuf,
}

impl T5ModelBuilder {
    /// Resolve the config, tokenizer, and GGUF weights, locally or from the hub.
    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
        let device = Device::Cpu;
        let default_model = "deepfile/flan-t5-xl-gguf".to_string();
        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, "main".to_string()),
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let api = Api::new()?;
        let api = api.repo(repo);
        let config_filename = match &args.config_file {
            Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
            None => match args.which {
                Which::T5Small => api.get("config.json")?,
                Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
                Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
                Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
                Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
                Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
            },
        };
        let tokenizer_filename = api.get("tokenizer.json")?;
        let weights_filename = match &args.weight_file {
            Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
            None => match args.which {
                Which::T5Small => api.get("model.gguf")?,
                Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
                Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
                Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
                Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
                Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
            },
        };

        let config = std::fs::read_to_string(config_filename)?;
        let mut config: t5::Config = serde_json::from_str(&config)?;
        config.use_cache = !args.disable_cache;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        Ok((
            Self {
                device,
                config,
                weights_filename,
            },
            tokenizer,
        ))
    }

    pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
    }

    fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result<PathBuf> {
        let local_filename = std::path::PathBuf::from(filename);
        if local_filename.exists() {
            Ok(local_filename)
        } else {
            Ok(api.get(filename)?)
        }
    }
}

/// Run sampled decoding for a single prompt and return the generated text.
fn generate_answer(prompt: String, args: &Args) -> Result<String> {
    let mut generated_text = String::new();

    let (builder, mut tokenizer) = T5ModelBuilder::load(args)?;
    let device = &builder.device;
    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    let tokens = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    let mut model = builder.build_model()?;
    // T5 uses the pad token as the decoder start token.
    let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();

    let mut logits_processor =
        LogitsProcessor::new(299792458, Some(args.temperature), args.top_p);
    let encoder_output = model.encode(&input_token_ids)?;

    let start = std::time::Instant::now();

    for index in 0.. {
        if output_token_ids.len() > 512 {
            break;
        }
        // With the KV cache enabled, only the last token is fed to the decoder.
        let decoder_token_ids = if index == 0 || !builder.config.use_cache {
            Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
        } else {
            let last_token = *output_token_ids.last().unwrap();
            Tensor::new(&[last_token], device)?.unsqueeze(0)?
        };
        let logits = model
            .decode(&decoder_token_ids, &encoder_output)?
            .squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &output_token_ids[start_at..],
            )?
        };

        let next_token_id = logits_processor.sample(&logits)?;
        if next_token_id as usize == builder.config.eos_token_id {
            break;
        }
        output_token_ids.push(next_token_id);
        if let Some(text) = tokenizer.id_to_token(next_token_id) {
            // SentencePiece marks word boundaries with '▁' and newlines as <0x0A>.
            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
            generated_text.push_str(&text);
            print!("{}", text);
            std::io::stdout().flush()?;
        }
    }
    let dt = start.elapsed();
    println!(
        "\n{} tokens generated ({:.2} token/s)\n",
        output_token_ids.len(),
        output_token_ids.len() as f64 / dt.as_secs_f64(),
    );

    Ok(generated_text)
}

// Request/response payloads for the HTTP API.
#[derive(Deserialize)]
struct Request {
    prompt: String,
}

#[derive(Serialize)]
struct Response {
    answer: String,
}

#[post("/generate")]
async fn generate(req_body: web::Json<Request>) -> impl Responder {
    let args = Args::parse();
    // Note: the model is reloaded on every request; failures become a 500
    // rather than panicking the worker.
    match generate_answer(req_body.prompt.clone(), &args) {
        Ok(answer) => HttpResponse::Ok().json(Response { answer }),
        Err(e) => HttpResponse::InternalServerError().body(e.to_string()),
    }
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    println!("Starting server at: http://localhost:7000");
    HttpServer::new(|| App::new().service(generate))
        .bind("localhost:7000")?
        .run()
        .await
}
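The deleted server exposed a single POST /generate route that accepts {"prompt": "..."} and returns {"answer": "..."}. As a sketch of how a caller might exercise it, here is a hypothetical Rust client; the reqwest and serde_json dependencies, the prompt text, and the output handling are illustrative assumptions, not part of this commit:

// Hypothetical client for the /generate endpoint above (not part of this repo).
// Assumes reqwest with the "blocking" and "json" features, plus serde_json.
use serde_json::json;

fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::blocking::Client::new();
    // Any flan-t5-style instruction works here; this prompt is only an example.
    let resp: serde_json::Value = client
        .post("http://localhost:7000/generate")
        .json(&json!({ "prompt": "translate English to German: How are you?" }))
        .send()?
        .json()?;
    println!("{}", resp["answer"].as_str().unwrap_or_default());
    Ok(())
}

Since the handler rebuilds the quantized model for each request, a real client would want a generous request timeout.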