from model import ExLlama, ExLlamaCache, ExLlamaConfig
from flask import Flask, request
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import os, glob

# Directory containing config.json, tokenizer.model and safetensors file for the model
model_directory = "/mnt/str/models/llama-7b-4bit/"

tokenizer_path = os.path.join(model_directory, "tokenizer.model")
model_config_path = os.path.join(model_directory, "config.json")
st_pattern = os.path.join(model_directory, "*.safetensors")
model_path = glob.glob(st_pattern)[0]
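# NOTE: glob may return several .safetensors files; this example simply takes the
# first match, so it assumes the weights live in a single safetensors file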

config = ExLlamaConfig(model_config_path)               # create config from config.json
config.model_path = model_path                          # supply path to model weights file

model = ExLlama(config)                                 # create ExLlama instance and load the weights
print(f"Model loaded: {model_path}")

tokenizer = ExLlamaTokenizer(tokenizer_path)            # create tokenizer from tokenizer model file
cache = ExLlamaCache(model)                             # create cache for inference
generator = ExLlamaGenerator(model, tokenizer, cache)   # create generator

# Flask app

app = Flask(__name__)


# Inference with settings equivalent to the "precise" preset from the /r/LocalLLaMA wiki

@app.route('/infer_precise', methods=['POST'])
def inferContextP():
    print(request.form)
    prompt = request.form.get('prompt')

    generator.settings.token_repetition_penalty_max = 1.176
    generator.settings.token_repetition_penalty_sustain = config.max_seq_len
    generator.settings.temperature = 0.7
    generator.settings.top_p = 0.1
    generator.settings.top_k = 40
    generator.settings.typical = 0.0    # Disabled

    outputs = generator.generate_simple(prompt, max_new_tokens = 200)
    return outputs


# Inference with settings equivalent to the "creative" preset from the /r/LocalLLaMA wiki

@app.route('/infer_creative', methods=['POST'])
def inferContextC():
    print(request.form)
    prompt = request.form.get('prompt')

    generator.settings.token_repetition_penalty_max = 1.1
    generator.settings.token_repetition_penalty_sustain = config.max_seq_len
    generator.settings.temperature = 0.72
    generator.settings.top_p = 0.73
    generator.settings.top_k = 0        # Disabled
    generator.settings.typical = 0.0    # Disabled

    outputs = generator.generate_simple(prompt, max_new_tokens = 200)
    return outputs


# Inference with settings equivalent to the "sphinx" preset from the /r/LocalLLaMA wiki

@app.route('/infer_sphinx', methods=['POST'])
def inferContextS():
    print(request.form)
    prompt = request.form.get('prompt')

    generator.settings.token_repetition_penalty_max = 1.15
    generator.settings.token_repetition_penalty_sustain = config.max_seq_len
    generator.settings.temperature = 1.99
    generator.settings.top_p = 0.18
    generator.settings.top_k = 30
    generator.settings.typical = 0.0    # Disabled

    outputs = generator.generate_simple(prompt, max_new_tokens = 200)
    return outputs


# Start Flask app

if __name__ == '__main__':
    from waitress import serve

    host = "0.0.0.0"
    port = 8004
    print(f"Starting server on address {host}:{port}")
    serve(app, host = host, port = port)
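
# Example client usage (a minimal sketch, assuming the server is reachable on
# localhost:8004 as configured above and that the `requests` package is installed;
# each endpoint expects a form field named "prompt" and returns the generated text):
#
#   import requests
#
#   response = requests.post(
#       "http://localhost:8004/infer_precise",
#       data = {"prompt": "Once upon a time,"},
#   )
#   print(response.text)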