lavanjv commited on
Commit
b3cc940
1 Parent(s): f3f87ca

Upload 7 files

Browse files
Files changed (7) hide show
  1. Dockerfile +17 -0
  2. app.py +46 -0
  3. config.py +40 -0
  4. http_api.py +59 -0
  5. requirements.txt +5 -0
  6. utils.py +11 -0
  7. websocket_api.py +90 -0
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Step 1: Choose a base image
2
+ FROM python:3.9
3
+
4
+ # Step 2: Set the working directory inside the container
5
+ WORKDIR /app
6
+
7
+ # Step 3: Copy the application files to the container
8
+ COPY . /app
9
+
10
+ # Step 4: Install the required dependencies
11
+ RUN pip install -r requirements.txt
12
+
13
+ # Step 5: Expose the necessary port
14
+ EXPOSE 5000
15
+
16
+ # Step 6: Define the entry point command
17
+ CMD ["flask", "run", "--host=0.0.0.0", "--port=5000"]
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hivemind
2
+ from flask import Flask
3
+ from flask_cors import CORS
4
+ from flask_sock import Sock
5
+ from transformers import AutoTokenizer
6
+
7
+ from petals import AutoDistributedModelForCausalLM
8
+
9
+ import config
10
+
11
+
12
+ logger = hivemind.get_logger(__file__)
13
+
14
+ models = {}
15
+ for model_info in config.MODELS:
16
+ logger.info(f"Loading tokenizer for {model_info.repo}")
17
+ tokenizer = AutoTokenizer.from_pretrained(model_info.repo, add_bos_token=False, use_fast=False)
18
+
19
+ logger.info(f"Loading model {model_info.repo} with adapter {model_info.adapter} and dtype {config.TORCH_DTYPE}")
20
+ # We set use_fast=False since LlamaTokenizerFast takes a long time to init
21
+ model = AutoDistributedModelForCausalLM.from_pretrained(
22
+ model_info.repo,
23
+ active_adapter=model_info.adapter,
24
+ torch_dtype=config.TORCH_DTYPE,
25
+ initial_peers=config.INITIAL_PEERS,
26
+ max_retries=3,
27
+ )
28
+ model = model.to(config.DEVICE)
29
+
30
+ model_name = model_info.adapter if model_info.adapter is not None else model_info.repo
31
+ models[model_name] = model, tokenizer
32
+
33
+ logger.info("Starting Flask app")
34
+ app = Flask(__name__)
35
+ CORS(app)
36
+ app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25}
37
+ sock = Sock(app)
38
+
39
+
40
+ @app.route("/")
41
+ def main_page():
42
+ return app.send_static_file("index.html")
43
+
44
+
45
+ import http_api
46
+ import websocket_api
config.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+
6
+ from cpufeature import CPUFeature
7
+ from petals.constants import PUBLIC_INITIAL_PEERS
8
+
9
+
10
+ @dataclass
11
+ class ModelInfo:
12
+ repo: str
13
+ adapter: Optional[str] = None
14
+
15
+
16
+ MODELS = [
17
+ ModelInfo(repo="meta-llama/Llama-2-70b-hf"),
18
+ ModelInfo(repo="meta-llama/Llama-2-70b-chat-hf"),
19
+ ModelInfo(repo="enoch/llama-65b-hf"),
20
+ ModelInfo(repo="enoch/llama-65b-hf", adapter="timdettmers/guanaco-65b"),
21
+ # ModelInfo(repo="bigscience/bloom"),
22
+ ModelInfo(repo="bigscience/bloomz"),
23
+ ]
24
+ DEFAULT_MODEL_NAME = "enoch/llama-65b-hf"
25
+
26
+ INITIAL_PEERS = PUBLIC_INITIAL_PEERS
27
+ # Set this to a list of multiaddrs to connect to a private swarm instead of the public one, for example:
28
+ # INITIAL_PEERS = ['/ip4/10.1.2.3/tcp/31234/p2p/QmcXhze98AcgGQDDYna23s4Jho96n8wkwLJv78vxtFNq44']
29
+
30
+ DEVICE = "cpu"
31
+
32
+ if DEVICE == "cuda":
33
+ TORCH_DTYPE = "auto"
34
+ elif CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"]:
35
+ TORCH_DTYPE = torch.bfloat16
36
+ else:
37
+ TORCH_DTYPE = torch.float32 # You can use bfloat16 in this case too, but it will be slow
38
+
39
+ STEP_TIMEOUT = 5 * 60
40
+ MAX_SESSIONS = 50 # Has effect only for API v1 (HTTP-based)
http_api.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from traceback import format_exc
2
+
3
+ import hivemind
4
+ from flask import jsonify, request
5
+
6
+ import config
7
+ from app import app, models
8
+ from utils import safe_decode
9
+
10
+ logger = hivemind.get_logger(__file__)
11
+
12
+
13
+ @app.post("/api/v1/generate")
14
+ def http_api_generate():
15
+ try:
16
+ model_name = get_typed_arg("model", str, config.DEFAULT_MODEL_NAME)
17
+ inputs = request.values.get("inputs")
18
+ do_sample = get_typed_arg("do_sample", int, 0)
19
+ temperature = get_typed_arg("temperature", float, 1.0)
20
+ top_k = get_typed_arg("top_k", int)
21
+ top_p = get_typed_arg("top_p", float)
22
+ max_length = get_typed_arg("max_length", int)
23
+ max_new_tokens = get_typed_arg("max_new_tokens", int)
24
+ session_id = request.values.get("session_id")
25
+ logger.info(f"generate(), model={repr(model_name)}, inputs={repr(inputs)}")
26
+
27
+ if session_id is not None:
28
+ raise RuntimeError(
29
+ "Reusing inference sessions was removed from HTTP API, please use WebSocket API instead"
30
+ )
31
+
32
+ model, tokenizer = models[model_name]
33
+
34
+ if inputs is not None:
35
+ inputs = tokenizer(inputs, return_tensors="pt")["input_ids"].to(config.DEVICE)
36
+ n_input_tokens = inputs.shape[1]
37
+ else:
38
+ n_input_tokens = 0
39
+
40
+ outputs = model.generate(
41
+ inputs=inputs,
42
+ do_sample=do_sample,
43
+ temperature=temperature,
44
+ top_k=top_k,
45
+ top_p=top_p,
46
+ max_length=max_length,
47
+ max_new_tokens=max_new_tokens,
48
+ )
49
+ outputs = safe_decode(tokenizer, outputs[0, n_input_tokens:])
50
+ logger.info(f"generate(), outputs={repr(outputs)}")
51
+
52
+ return jsonify(ok=True, outputs=outputs)
53
+ except Exception:
54
+ return jsonify(ok=False, traceback=format_exc())
55
+
56
+
57
+ def get_typed_arg(name, expected_type, default=None):
58
+ value = request.values.get(name)
59
+ return expected_type(value) if value is not None else default
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ git+https://github.com/bigscience-workshop/petals
2
+ Flask
3
+ flask-sock
4
+ flask-cors
5
+ gunicorn[gthread]
utils.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import PreTrainedTokenizerBase
3
+
4
+
5
+ def safe_decode(tokenizer: PreTrainedTokenizerBase, outputs: torch.Tensor):
6
+ # Workaround to make SentencePiece .decode() keep leading spaces in a token
7
+ fake_token = tokenizer("^")["input_ids"][0]
8
+ result = tokenizer.decode([fake_token] + outputs.tolist())
9
+
10
+ # We use .lstrip() since SentencePiece may add leading spaces, e.g. if the outputs are "</s>"
11
+ return result.lstrip()[1:]
websocket_api.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from traceback import format_exc
3
+
4
+ import flask_sock
5
+ import hivemind
6
+ import torch
7
+
8
+ import config
9
+ from app import sock, models
10
+ from utils import safe_decode
11
+
12
+ logger = hivemind.get_logger(__file__)
13
+
14
+
15
+ @sock.route("/api/v2/generate")
16
+ def ws_api_generate(ws):
17
+ try:
18
+ request = json.loads(ws.receive(timeout=config.STEP_TIMEOUT))
19
+ assert request["type"] == "open_inference_session"
20
+ model_name = request.get("model")
21
+ if model_name is None:
22
+ model_name = config.DEFAULT_MODEL_NAME
23
+ logger.info(f"ws.generate.open(), model={repr(model_name)}, max_length={repr(request['max_length'])}")
24
+
25
+ model, tokenizer = models[model_name]
26
+
27
+ with model.inference_session(max_length=request["max_length"]) as session:
28
+ ws.send(json.dumps({"ok": True}))
29
+
30
+ while True:
31
+ request = json.loads(ws.receive(timeout=config.STEP_TIMEOUT))
32
+ assert request["type"] == "generate"
33
+ inputs = request.get("inputs")
34
+ logger.info(f"ws.generate.step(), inputs={repr(inputs)}")
35
+
36
+ if inputs is not None:
37
+ inputs = tokenizer(inputs, return_tensors="pt")["input_ids"].to(config.DEVICE)
38
+ n_input_tokens = inputs.shape[1]
39
+ else:
40
+ n_input_tokens = 0
41
+
42
+ stop_sequence = request.get("stop_sequence")
43
+ extra_stop_sequences = request.get("extra_stop_sequences")
44
+ if extra_stop_sequences is not None:
45
+ cont_token = tokenizer(stop_sequence, return_tensors="pt")["input_ids"].to(config.DEVICE)
46
+ assert cont_token.shape == (1, 1), \
47
+ "extra_stop_sequences require stop_sequence length to be exactly 1 token"
48
+
49
+ all_outputs = ''
50
+ delta_q = []
51
+ stop = False
52
+ while not stop:
53
+ outputs = model.generate(
54
+ inputs=inputs,
55
+ do_sample=request.get("do_sample", False),
56
+ temperature=request.get("temperature", 1.0),
57
+ top_k=request.get("top_k"),
58
+ top_p=request.get("top_p"),
59
+ max_length=request.get("max_length"),
60
+ max_new_tokens=request.get("max_new_tokens"),
61
+ session=session,
62
+ )
63
+ delta = outputs[0, n_input_tokens:].tolist()
64
+ outputs = safe_decode(tokenizer, torch.Tensor(delta_q + delta))
65
+ inputs = None # Inputs are passed only for the 1st token of the bot's response
66
+ n_input_tokens = 0
67
+ combined = all_outputs + outputs
68
+ stop = stop_sequence is None or combined.endswith(stop_sequence)
69
+ if extra_stop_sequences is not None:
70
+ for seq in extra_stop_sequences:
71
+ if combined.endswith(seq):
72
+ stop = True
73
+ session.last_token_id = cont_token
74
+ if not stop and outputs[-10:].find(u'\ufffd') > -1:
75
+ # If there's a replacement character, keep getting more tokens
76
+ # until we can decode properly
77
+ delta_q = delta_q + delta
78
+ logger.info(f"ws.generate.append_retry(), all_outputs={repr(combined)}")
79
+ else:
80
+ all_outputs = combined
81
+ delta_q = []
82
+ logger.info(f"ws.generate.step(), all_outputs={repr(all_outputs)}, stop={stop}")
83
+ ws.send(json.dumps({"ok": True, "outputs": outputs, "stop": stop}))
84
+ except flask_sock.ConnectionClosed:
85
+ pass
86
+ except Exception:
87
+ logger.warning("ws.generate failed:", exc_info=True)
88
+ ws.send(json.dumps({"ok": False, "traceback": format_exc()}))
89
+ finally:
90
+ logger.info(f"ws.generate.close()")