from model import ExLlama, ExLlamaCache, ExLlamaConfig
from flask import Flask, request
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import os, glob
# Directory containing config.json, tokenizer.model and safetensors file for the model
model_directory = "/mnt/str/models/llama-7b-4bit/"
tokenizer_path = os.path.join(model_directory, "tokenizer.model")
model_config_path = os.path.join(model_directory, "config.json")
st_pattern = os.path.join(model_directory, "*.safetensors")
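# glob may match several files; this example simply loads the first .safetensors file found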
model_path = glob.glob(st_pattern)[0]
config = ExLlamaConfig(model_config_path) # create config from config.json
config.model_path = model_path # supply path to model weights file
model = ExLlama(config) # create ExLlama instance and load the weights
print(f"Model loaded: {model_path}")
tokenizer = ExLlamaTokenizer(tokenizer_path) # create tokenizer from tokenizer model file
cache = ExLlamaCache(model) # create cache for inference
generator = ExLlamaGenerator(model, tokenizer, cache) # create generator
# Flask app
app = Flask(__name__)
# Inference with settings equivalent to the "precise" preset from the /r/LocalLLaMA wiki
@app.route('/infer_precise', methods=['POST'])
def inferContextP():
    print(request.form)
    prompt = request.form.get('prompt')

    generator.settings.token_repetition_penalty_max = 1.176
    generator.settings.token_repetition_penalty_sustain = config.max_seq_len
    generator.settings.temperature = 0.7
    generator.settings.top_p = 0.1
    generator.settings.top_k = 40
    generator.settings.typical = 0.0    # Disabled

    outputs = generator.generate_simple(prompt, max_new_tokens = 200)
    return outputs

# Inference with settings equivalent to the "creative" preset from the /r/LocalLLaMA wiki
@app.route('/infer_creative', methods=['POST'])
def inferContextC():
    print(request.form)
    prompt = request.form.get('prompt')

    generator.settings.token_repetition_penalty_max = 1.1
    generator.settings.token_repetition_penalty_sustain = config.max_seq_len
    generator.settings.temperature = 0.72
    generator.settings.top_p = 0.73
    generator.settings.top_k = 0    # Disabled
    generator.settings.typical = 0.0    # Disabled

    outputs = generator.generate_simple(prompt, max_new_tokens = 200)
    return outputs

# Inference with settings equivalent to the "sphinx" preset from the /r/LocalLLaMA wiki
@app.route('/infer_sphinx', methods=['POST'])
def inferContextS():
    print(request.form)
    prompt = request.form.get('prompt')

    generator.settings.token_repetition_penalty_max = 1.15
    generator.settings.token_repetition_penalty_sustain = config.max_seq_len
    generator.settings.temperature = 1.99
    generator.settings.top_p = 0.18
    generator.settings.top_k = 30
    generator.settings.typical = 0.0    # Disabled

    outputs = generator.generate_simple(prompt, max_new_tokens = 200)
    return outputs

# Start Flask app
host = "0.0.0.0"
port = 8004
print(f"Starting server on address {host}:{port}")
if __name__ == '__main__':
    from waitress import serve
    serve(app, host = host, port = port)
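
# Example client call (a sketch, assuming the server is reachable at http://localhost:8004
# and that the prompt is sent as form data, as the routes above expect):
#
#   curl -X POST -d "prompt=Once upon a time" http://localhost:8004/infer_precise
#
# or equivalently from Python with the requests library:
#
#   import requests
#   response = requests.post("http://localhost:8004/infer_creative",
#                            data = {"prompt": "Once upon a time"})
#   print(response.text)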