from model import ExLlama, ExLlamaCache, ExLlamaConfig
from flask import Flask, request
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import os, glob

# Directory containing config.json, tokenizer.model and the quantized
# .safetensors weights file
model_directory = "/mnt/str/models/llama-7b-4bit/"

# Locate the files the loader needs inside that directory
tokenizer_path = os.path.join(model_directory, "tokenizer.model")
model_config_path = os.path.join(model_directory, "config.json")
st_pattern = os.path.join(model_directory, "*.safetensors")
model_path = glob.glob(st_pattern)[0]   # assumes exactly one .safetensors file

# Create the config from config.json and point it at the model weights
config = ExLlamaConfig(model_config_path)
config.model_path = model_path

# Create the ExLlama instance and load the weights
model = ExLlama(config)
print(f"Model loaded: {model_path}")

tokenizer = ExLlamaTokenizer(tokenizer_path)            # tokenizer from tokenizer.model
cache = ExLlamaCache(model)                             # cache for inference state
generator = ExLlamaGenerator(model, tokenizer, cache)   # generator ties it all together
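
# Optional smoke test (not part of the original flow): uncomment to confirm
# generation works before starting the server.
# print(generator.generate_simple("Once upon a time,", max_new_tokens = 20))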

# Flask app exposing one endpoint per sampling preset

app = Flask(__name__)


# Conservative "precise" preset: a tight nucleus (top_p = 0.1) keeps
# sampling close to greedy
@app.route('/infer_precise', methods=['POST'])
def inferContextP():
    print(request.form)
    prompt = request.form.get('prompt')

    generator.settings.token_repetition_penalty_max = 1.176
    generator.settings.token_repetition_penalty_sustain = config.max_seq_len
    generator.settings.temperature = 0.7
    generator.settings.top_p = 0.1
    generator.settings.top_k = 40
    generator.settings.typical = 0.0    # typical sampling disabled

    outputs = generator.generate_simple(prompt, max_new_tokens = 200)
    return outputs


# "Creative" preset: wider nucleus (top_p = 0.73) and no top_k cutoff
@app.route('/infer_creative', methods=['POST'])
def inferContextC():
    print(request.form)
    prompt = request.form.get('prompt')

    generator.settings.token_repetition_penalty_max = 1.1
    generator.settings.token_repetition_penalty_sustain = config.max_seq_len
    generator.settings.temperature = 0.72
    generator.settings.top_p = 0.73
    generator.settings.top_k = 0        # top_k disabled
    generator.settings.typical = 0.0    # typical sampling disabled

    outputs = generator.generate_simple(prompt, max_new_tokens = 200)
    return outputs


# "Sphinx" preset: very high temperature reined in by a tight nucleus
@app.route('/infer_sphinx', methods=['POST'])
def inferContextS():
    print(request.form)
    prompt = request.form.get('prompt')

    generator.settings.token_repetition_penalty_max = 1.15
    generator.settings.token_repetition_penalty_sustain = config.max_seq_len
    generator.settings.temperature = 1.99
    generator.settings.top_p = 0.18
    generator.settings.top_k = 30
    generator.settings.typical = 0.0    # typical sampling disabled

    outputs = generator.generate_simple(prompt, max_new_tokens = 200)
    return outputs
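
# The three handlers above differ only in their sampler settings. A possible
# consolidation (a sketch, not part of the original example) is a helper that
# copies a preset dict onto the shared settings object:

def apply_settings(preset):
    # e.g. apply_settings({"temperature": 0.7, "top_p": 0.1, "top_k": 40})
    for name, value in preset.items():
        setattr(generator.settings, name, value)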

host = "0.0.0.0"
port = 8004
print(f"Starting server on address {host}:{port}")

# Serve with waitress, a production WSGI server, rather than Flask's
# built-in development server
if __name__ == '__main__':
    from waitress import serve
    serve(app, host = host, port = port)
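
# Example client call against a running server (assumes the third-party
# "requests" package; each endpoint reads a form-encoded "prompt" field
# and returns the generated text):
#
#   import requests
#   r = requests.post("http://localhost:8004/infer_precise",
#                     data = {"prompt": "Once upon a time"})
#   print(r.text)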