#[cfg(feature = "mkl")] extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; use std::io::Write; use std::path::PathBuf; use actix_web::{post, web, App, HttpResponse, HttpServer, Responder}; use serde::{Deserialize, Serialize}; use candle_transformers::models::quantized_t5 as t5; use anyhow::{Error as E, Result}; use candle_core::{Device, Tensor}; use candle_transformers::generation::LogitsProcessor; use clap::{Parser, ValueEnum}; use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType}; use tokenizers::Tokenizer; #[derive(Clone, Debug, Copy, ValueEnum)] enum Which { T5Small, FlanT5Small, FlanT5Base, FlanT5Large, FlanT5Xl, FlanT5Xxl, } #[derive(Parser, Debug, Clone)] #[command(author, version, about, long_about = None)] struct Args { /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, /// The model repository to use on the HuggingFace hub. #[arg(long)] model_id: Option, #[arg(long)] revision: Option, #[arg(long)] weight_file: Option, #[arg(long)] config_file: Option, // Enable/disable decoding. #[arg(long, default_value = "false")] disable_cache: bool, /// Use this prompt, otherwise compute sentence similarities. // #[arg(long)] // prompt: Option, /// The temperature used to generate samples. #[arg(long, default_value_t = 0.8)] temperature: f64, /// Nucleus sampling probability cutoff. #[arg(long)] top_p: Option, /// Penalty to be applied for repeating tokens, 1. means no penalty. #[arg(long, default_value_t = 1.1)] repeat_penalty: f32, /// The context size to consider for the repeat penalty. #[arg(long, default_value_t = 64)] repeat_last_n: usize, /// The model size to use. #[arg(long, default_value = "flan-t5-xl")] which: Which, } struct T5ModelBuilder { device: Device, config: t5::Config, weights_filename: PathBuf, } impl T5ModelBuilder { pub fn load(args: &Args) -> Result<(Self, Tokenizer)> { let device = Device::Cpu; let default_model = "deepfile/flan-t5-xl-gguf".to_string(); let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) { (Some(model_id), Some(revision)) => (model_id, revision), (Some(model_id), None) => (model_id, "main".to_string()), (None, Some(revision)) => (default_model, revision), (None, None) => (default_model, "main".to_string()), }; let repo = Repo::with_revision(model_id, RepoType::Model, revision); let api = Api::new()?; let api = api.repo(repo); let config_filename = match &args.config_file { Some(filename) => Self::get_local_or_remote_file(filename, &api)?, None => match args.which { Which::T5Small => api.get("config.json")?, Which::FlanT5Small => api.get("config-flan-t5-small.json")?, Which::FlanT5Base => api.get("config-flan-t5-base.json")?, Which::FlanT5Large => api.get("config-flan-t5-large.json")?, Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?, Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?, }, }; let tokenizer_filename = api.get("tokenizer.json")?; let weights_filename = match &args.weight_file { Some(filename) => Self::get_local_or_remote_file(filename, &api)?, None => match args.which { Which::T5Small => api.get("model.gguf")?, Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?, Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?, Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?, Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?, Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?, }, }; let config = std::fs::read_to_string(config_filename)?; let mut config: t5::Config = 
serde_json::from_str(&config)?; config.use_cache = !args.disable_cache; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; Ok(( Self { device, config, weights_filename, }, tokenizer, )) } pub fn build_model(&self) -> Result { let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?; Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) } fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result { let local_filename = std::path::PathBuf::from(filename); if local_filename.exists() { Ok(local_filename) } else { Ok(api.get(filename)?) } } } fn generate_answer(_prompt: String, args: &Args) -> Result { let mut generated_text = String::new(); let (_builder, mut _tokenizer) = T5ModelBuilder::load(&args)?; let device = &_builder.device; let _tokenizer = _tokenizer .with_padding(None) .with_truncation(None) .map_err(E::msg)?; let _tokens = _tokenizer .encode(_prompt, true) .map_err(E::msg)? .get_ids() .to_vec(); let input_token_ids = Tensor::new(&_tokens[..], device)?.unsqueeze(0)?; let mut model = _builder.build_model()?; let mut output_token_ids = [_builder.config.pad_token_id as u32].to_vec(); let temperature = 0.8f64; let mut logits_processor = LogitsProcessor::new(299792458, Some(temperature), None); let encoder_output = model.encode(&input_token_ids)?; let start = std::time::Instant::now(); for index in 0.. { if output_token_ids.len() > 512 { break; } let decoder_token_ids = if index == 0 || !_builder.config.use_cache { Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)? } else { let last_token = *output_token_ids.last().unwrap(); Tensor::new(&[last_token], device)?.unsqueeze(0)? }; let logits = model .decode(&decoder_token_ids, &encoder_output)? .squeeze(0)?; let logits = if args.repeat_penalty == 1. { logits } else { let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n); candle_transformers::utils::apply_repeat_penalty( &logits, args.repeat_penalty, &output_token_ids[start_at..], )? }; let next_token_id = logits_processor.sample(&logits)?; if next_token_id as usize == _builder.config.eos_token_id { break; } output_token_ids.push(next_token_id); if let Some(text) = _tokenizer.id_to_token(next_token_id) { let text = text.replace('▁', " ").replace("<0x0A>", "\n"); generated_text.push_str(&text); print!("{}", text); std::io::stdout().flush()?; } } let dt = start.elapsed(); println!( "\n{} tokens generated ({:.2} token/s)\n", output_token_ids.len(), output_token_ids.len() as f64 / dt.as_secs_f64(), ); Ok(generated_text) } // request struct #[derive(Deserialize)] struct Request { prompt: String, } #[derive(Serialize)] struct Response { answer: String, } #[post("/generate")] async fn generate(req_body: web::Json) -> impl Responder { let args = Args::parse(); let generated_answer = generate_answer(req_body.prompt.clone(), &args); HttpResponse::Ok().json(Response { answer: generated_answer.unwrap(), }) } #[actix_web::main] async fn main() -> std::io::Result<()> { println!("Starting server at: http://localhost:7000"); HttpServer::new(|| App::new().service(generate)) .bind("localhost:7000")? .run() .await }