Upload 7 files
Browse files
- Dockerfile +17 -0
- app.py +46 -0
- config.py +40 -0
- http_api.py +59 -0
- requirements.txt +5 -0
- utils.py +11 -0
- websocket_api.py +90 -0
Dockerfile
ADDED
@@ -0,0 +1,17 @@
# Step 1: Choose a base image
FROM python:3.9

# Step 2: Set the working directory inside the container
WORKDIR /app

# Step 3: Copy the application files to the container
COPY . /app

# Step 4: Install the required dependencies
RUN pip install -r requirements.txt

# Step 5: Expose the necessary port
EXPOSE 5000

# Step 6: Define the entry point command
CMD ["flask", "run", "--host=0.0.0.0", "--port=5000"]
app.py
ADDED
@@ -0,0 +1,46 @@
import hivemind
from flask import Flask
from flask_cors import CORS
from flask_sock import Sock
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM

import config


logger = hivemind.get_logger(__file__)

models = {}
for model_info in config.MODELS:
    logger.info(f"Loading tokenizer for {model_info.repo}")
    # We set use_fast=False since LlamaTokenizerFast takes a long time to init
    tokenizer = AutoTokenizer.from_pretrained(model_info.repo, add_bos_token=False, use_fast=False)

    logger.info(f"Loading model {model_info.repo} with adapter {model_info.adapter} and dtype {config.TORCH_DTYPE}")
    model = AutoDistributedModelForCausalLM.from_pretrained(
        model_info.repo,
        active_adapter=model_info.adapter,
        torch_dtype=config.TORCH_DTYPE,
        initial_peers=config.INITIAL_PEERS,
        max_retries=3,
    )
    model = model.to(config.DEVICE)

    model_name = model_info.adapter if model_info.adapter is not None else model_info.repo
    models[model_name] = model, tokenizer

logger.info("Starting Flask app")
app = Flask(__name__)
CORS(app)
app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25}
sock = Sock(app)


@app.route("/")
def main_page():
    return app.send_static_file("index.html")


import http_api
import websocket_api
config.py
ADDED
@@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Optional

import torch

from cpufeature import CPUFeature
from petals.constants import PUBLIC_INITIAL_PEERS


@dataclass
class ModelInfo:
    repo: str
    adapter: Optional[str] = None


MODELS = [
    ModelInfo(repo="meta-llama/Llama-2-70b-hf"),
    ModelInfo(repo="meta-llama/Llama-2-70b-chat-hf"),
    ModelInfo(repo="enoch/llama-65b-hf"),
    ModelInfo(repo="enoch/llama-65b-hf", adapter="timdettmers/guanaco-65b"),
    # ModelInfo(repo="bigscience/bloom"),
    ModelInfo(repo="bigscience/bloomz"),
]
DEFAULT_MODEL_NAME = "enoch/llama-65b-hf"

INITIAL_PEERS = PUBLIC_INITIAL_PEERS
# Set this to a list of multiaddrs to connect to a private swarm instead of the public one, for example:
# INITIAL_PEERS = ['/ip4/10.1.2.3/tcp/31234/p2p/QmcXhze98AcgGQDDYna23s4Jho96n8wkwLJv78vxtFNq44']

DEVICE = "cpu"

if DEVICE == "cuda":
    TORCH_DTYPE = "auto"
elif CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"]:
    TORCH_DTYPE = torch.bfloat16
else:
    TORCH_DTYPE = torch.float32  # You can use bfloat16 in this case too, but it will be slow

STEP_TIMEOUT = 5 * 60
MAX_SESSIONS = 50  # Has effect only for API v1 (HTTP-based)
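For reference, a sketch of how the `models` registry built in app.py ends up keyed, derived from the MODELS list above and the adapter-or-repo naming rule (this list is illustrative and not part of the upload):

# Keys of the `models` dict in app.py: the adapter name when one is set, the repo name otherwise.
expected_model_names = [
    "meta-llama/Llama-2-70b-hf",
    "meta-llama/Llama-2-70b-chat-hf",
    "enoch/llama-65b-hf",
    "timdettmers/guanaco-65b",  # the Guanaco adapter loaded on top of enoch/llama-65b-hf
    "bigscience/bloomz",
]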
http_api.py
ADDED
@@ -0,0 +1,59 @@
from traceback import format_exc

import hivemind
from flask import jsonify, request

import config
from app import app, models
from utils import safe_decode

logger = hivemind.get_logger(__file__)


@app.post("/api/v1/generate")
def http_api_generate():
    try:
        model_name = get_typed_arg("model", str, config.DEFAULT_MODEL_NAME)
        inputs = request.values.get("inputs")
        do_sample = get_typed_arg("do_sample", int, 0)
        temperature = get_typed_arg("temperature", float, 1.0)
        top_k = get_typed_arg("top_k", int)
        top_p = get_typed_arg("top_p", float)
        max_length = get_typed_arg("max_length", int)
        max_new_tokens = get_typed_arg("max_new_tokens", int)
        session_id = request.values.get("session_id")
        logger.info(f"generate(), model={repr(model_name)}, inputs={repr(inputs)}")

        if session_id is not None:
            raise RuntimeError(
                "Reusing inference sessions was removed from HTTP API, please use WebSocket API instead"
            )

        model, tokenizer = models[model_name]

        if inputs is not None:
            inputs = tokenizer(inputs, return_tensors="pt")["input_ids"].to(config.DEVICE)
            n_input_tokens = inputs.shape[1]
        else:
            n_input_tokens = 0

        outputs = model.generate(
            inputs=inputs,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_length=max_length,
            max_new_tokens=max_new_tokens,
        )
        outputs = safe_decode(tokenizer, outputs[0, n_input_tokens:])
        logger.info(f"generate(), outputs={repr(outputs)}")

        return jsonify(ok=True, outputs=outputs)
    except Exception:
        return jsonify(ok=False, traceback=format_exc())


def get_typed_arg(name, expected_type, default=None):
    value = request.values.get(name)
    return expected_type(value) if value is not None else default
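A minimal client sketch for the HTTP endpoint above. The base URL assumes the app is served on port 5000 as in the Dockerfile, and the `requests` package is a client-side assumption (it is not in requirements.txt):

# Sketch of a client for POST /api/v1/generate (form-encoded parameters, JSON response).
import requests

response = requests.post(
    "http://localhost:5000/api/v1/generate",
    data={
        "model": "enoch/llama-65b-hf",  # optional; defaults to config.DEFAULT_MODEL_NAME
        "inputs": "A cat sat on",
        "do_sample": 1,
        "temperature": 0.75,
        "top_p": 0.9,
        "max_new_tokens": 32,
    },
)
result = response.json()
if result["ok"]:
    print(result["outputs"])
else:
    print(result["traceback"])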
requirements.txt
ADDED
@@ -0,0 +1,5 @@
git+https://github.com/bigscience-workshop/petals
Flask
flask-sock
flask-cors
gunicorn[gthread]
utils.py
ADDED
@@ -0,0 +1,11 @@
import torch
from transformers import PreTrainedTokenizerBase


def safe_decode(tokenizer: PreTrainedTokenizerBase, outputs: torch.Tensor):
    # Workaround to make SentencePiece .decode() keep leading spaces in a token
    fake_token = tokenizer("^")["input_ids"][0]
    result = tokenizer.decode([fake_token] + outputs.tolist())

    # We use .lstrip() since SentencePiece may add leading spaces, e.g. if the outputs are "</s>"
    return result.lstrip()[1:]
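A quick illustration of what safe_decode works around, assuming a LLaMA-style SentencePiece tokenizer loaded the same way as in app.py (the comments describe the expected behavior, not captured output):

# Plain .decode() tends to drop the leading space of the first token; safe_decode keeps it.
from transformers import AutoTokenizer
from utils import safe_decode

tokenizer = AutoTokenizer.from_pretrained("enoch/llama-65b-hf", add_bos_token=False, use_fast=False)
ids = tokenizer(" world", return_tensors="pt")["input_ids"][0]
print(repr(tokenizer.decode(ids)))        # leading space may be dropped: 'world'
print(repr(safe_decode(tokenizer, ids)))  # expected: ' world' with the space preserved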
websocket_api.py
ADDED
@@ -0,0 +1,90 @@
import json
from traceback import format_exc

import flask_sock
import hivemind
import torch

import config
from app import sock, models
from utils import safe_decode

logger = hivemind.get_logger(__file__)


@sock.route("/api/v2/generate")
def ws_api_generate(ws):
    try:
        request = json.loads(ws.receive(timeout=config.STEP_TIMEOUT))
        assert request["type"] == "open_inference_session"
        model_name = request.get("model")
        if model_name is None:
            model_name = config.DEFAULT_MODEL_NAME
        logger.info(f"ws.generate.open(), model={repr(model_name)}, max_length={repr(request['max_length'])}")

        model, tokenizer = models[model_name]

        with model.inference_session(max_length=request["max_length"]) as session:
            ws.send(json.dumps({"ok": True}))

            while True:
                request = json.loads(ws.receive(timeout=config.STEP_TIMEOUT))
                assert request["type"] == "generate"
                inputs = request.get("inputs")
                logger.info(f"ws.generate.step(), inputs={repr(inputs)}")

                if inputs is not None:
                    inputs = tokenizer(inputs, return_tensors="pt")["input_ids"].to(config.DEVICE)
                    n_input_tokens = inputs.shape[1]
                else:
                    n_input_tokens = 0

                stop_sequence = request.get("stop_sequence")
                extra_stop_sequences = request.get("extra_stop_sequences")
                if extra_stop_sequences is not None:
                    cont_token = tokenizer(stop_sequence, return_tensors="pt")["input_ids"].to(config.DEVICE)
                    assert cont_token.shape == (1, 1), \
                        "extra_stop_sequences require stop_sequence length to be exactly 1 token"

                all_outputs = ''
                delta_q = []
                stop = False
                while not stop:
                    outputs = model.generate(
                        inputs=inputs,
                        do_sample=request.get("do_sample", False),
                        temperature=request.get("temperature", 1.0),
                        top_k=request.get("top_k"),
                        top_p=request.get("top_p"),
                        max_length=request.get("max_length"),
                        max_new_tokens=request.get("max_new_tokens"),
                        session=session,
                    )
                    delta = outputs[0, n_input_tokens:].tolist()
                    outputs = safe_decode(tokenizer, torch.Tensor(delta_q + delta))
                    inputs = None  # Inputs are passed only for the 1st token of the bot's response
                    n_input_tokens = 0
                    combined = all_outputs + outputs
                    stop = stop_sequence is None or combined.endswith(stop_sequence)
                    if extra_stop_sequences is not None:
                        for seq in extra_stop_sequences:
                            if combined.endswith(seq):
                                stop = True
                                session.last_token_id = cont_token
                    if not stop and outputs[-10:].find(u'\ufffd') > -1:
                        # If there's a replacement character, keep getting more tokens
                        # until we can decode properly
                        delta_q = delta_q + delta
                        logger.info(f"ws.generate.append_retry(), all_outputs={repr(combined)}")
                    else:
                        all_outputs = combined
                        delta_q = []
                        logger.info(f"ws.generate.step(), all_outputs={repr(all_outputs)}, stop={stop}")
                        ws.send(json.dumps({"ok": True, "outputs": outputs, "stop": stop}))
    except flask_sock.ConnectionClosed:
        pass
    except Exception:
        logger.warning("ws.generate failed:", exc_info=True)
        ws.send(json.dumps({"ok": False, "traceback": format_exc()}))
    finally:
        logger.info(f"ws.generate.close()")
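A minimal client sketch for the WebSocket protocol above. The URL assumes the app is reachable at localhost:5000, and the `websocket-client` package is a client-side assumption (not part of requirements.txt):

# Sketch: open an inference session, then stream tokens until the server reports stop=True.
import json
from websocket import create_connection

ws = create_connection("ws://localhost:5000/api/v2/generate")
try:
    # Step 1: open an inference session; max_length bounds the whole session.
    ws.send(json.dumps({"type": "open_inference_session", "max_length": 512}))
    assert json.loads(ws.recv())["ok"]

    # Step 2: request generation; with a stop_sequence set, the server keeps
    # sending chunks (max_new_tokens each) until the sequence appears.
    ws.send(json.dumps({
        "type": "generate",
        "inputs": "A cat sat on",
        "max_new_tokens": 1,
        "stop_sequence": "\n",
    }))
    while True:
        reply = json.loads(ws.recv())
        assert reply["ok"], reply.get("traceback")
        print(reply["outputs"], end="", flush=True)
        if reply["stop"]:
            break
finally:
    ws.close()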