"""Benchmarking script to test the throughput of serving workers."""
import argparse
import json
import threading
import time

import requests

from fastchat.conversation import get_conv_template
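

# A minimal sketch, not part of FastChat: send_request() in main() reads the
# worker's reply as UTF-8 JSON objects, each terminated by b"\0", and takes the
# last one as the final result. parse_final_chunk() is a hypothetical helper
# that does the same from a raw response body (e.g. response.content). It
# skips empty pieces instead of relying on the fixed k[-2] index.
import json


def parse_final_chunk(body: bytes) -> dict:
    """Return the last JSON object in a NUL-delimited response body."""
    # Splitting on b"\0" yields an empty piece after the final terminator.
    chunks = [chunk for chunk in body.split(b"\0") if chunk]
    if not chunks:
        raise ValueError("response body contained no JSON chunks")
    return json.loads(chunks[-1].decode("utf-8"))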


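def main():
    # Resolve the controller address up front: send_request() also needs it
    # when --test-dispatch is set, even if --worker-address was given.
    controller_addr = args.controller_address
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        # Ask the controller to re-scan its workers, then list the models served.
        requests.post(controller_addr + "/refresh_all_workers")
        ret = requests.post(controller_addr + "/list_models")
        models = ret.json()["models"]
        models.sort()
        print(f"Models: {models}")

        # Look up a worker that serves the requested model.
        ret = requests.post(
            controller_addr + "/get_worker_address", json={"model": args.model_name}
        )
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        # An empty address means no worker is serving this model.
        print(f"No worker available for model {args.model_name!r}")
        return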

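    # Build one Vicuna-style prompt; every thread sends the same request.
    conv = get_conv_template("vicuna_v1.1")
    conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words")
    # Leave the assistant turn empty so the prompt ends with "ASSISTANT:".
    conv.append_message(conv.roles[1], None)
    prompt_template = conv.get_prompt()
    prompts = [prompt_template for _ in range(args.n_thread)]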

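    headers = {"User-Agent": "fastchat Client"}
    # One payload per thread. Temperature 0.0 selects greedy decoding, so every
    # thread generates the same text and does comparable work.
    ploads = [
        {
            "model": args.model_name,
            "prompt": prompt,
            "max_new_tokens": args.max_new_tokens,
            "temperature": 0.0,
            # "stop": conv.sep,
        }
        for prompt in prompts
    ]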

    def send_request(results, i):
        if args.test_dispatch:
            # Ask the controller for a worker on every request to exercise
            # its dispatch policy.
            ret = requests.post(
                controller_addr + "/get_worker_address", json={"model": args.model_name}
            )
            thread_worker_addr = ret.json()["address"]
        else:
            thread_worker_addr = worker_addr
        print(f"thread {i} goes to {thread_worker_addr}")
        # stream=False: wait for the whole body, then split it into the
        # NUL-delimited JSON chunks emitted by the streaming endpoint.
        response = requests.post(
            thread_worker_addr + "/worker_generate_stream",
            headers=headers,
            json=ploads[i],
            stream=False,
        )
        k = list(
            response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")
        )
        # The body ends with b"\0", so k[-1] is empty and k[-2] is the final chunk.
        last = json.loads(k[-2].decode("utf-8"))
        response_new_words = last["text"]
        error_code = last["error_code"]
        if error_code != 0:
            print(f"thread {i} failed with error code {error_code}")
        # The returned text includes the prompt, so subtract the prompt's words.
        results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" "))

    # use N threads to prompt the backend
    tik = time.time()
    threads = []
    results = [None] * args.n_thread
    for i in range(args.n_thread):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        # time.sleep(0.5)
        threads.append(t)

    for t in threads:
        t.join()

    # All threads have joined, so this is the wall-clock time for every
    # request to finish.
    time_seconds = time.time() - tik
    # A thread that raised an exception leaves its slot as None; count only
    # the requests that completed. Space-separated words approximate tokens.
    n_words = sum(r for r in results if r is not None)
    print(
        f"Time (Completion): {time_seconds:.2f} s, n threads: {args.n_thread}, "
        f"throughput: {n_words / time_seconds:.2f} words/s."
    )
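

# A minimal sketch, not FastChat code: the fan-out/join in main() above can
# also be written with the standard library's ThreadPoolExecutor.
# run_concurrently() is a hypothetical helper that calls fn(0) ... fn(n - 1)
# on n threads and returns their results in index order together with the
# elapsed wall-clock seconds. Unlike a bare threading.Thread, pool.map()
# re-raises an exception from a failed call instead of only printing it.
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Tuple


def run_concurrently(fn: Callable[[int], Any], n: int) -> Tuple[List[Any], float]:
    """Run fn(i) for every i in range(n) concurrently; return (results, seconds)."""
    if n < 1:
        raise ValueError("n must be at least 1")
    tik = time.time()
    # One worker per call so all n calls are in flight at once, like the
    # n_thread threads started in main().
    with ThreadPoolExecutor(max_workers=n) as pool:
        # map() yields results in submission order.
        results = list(pool.map(fn, range(n)))
    return results, time.time() - tik


# Hypothetical usage inside main(), where send_one(i) would return the word
# count for request i instead of writing into a shared list:
#     results, time_seconds = run_concurrently(send_one, args.n_thread)
#     throughput = sum(results) / time_seconds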


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
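    parser.add_argument(
        "--controller-address",
        type=str,
        default="http://localhost:21001",
        help="FastChat controller used to look up workers",
    )
    parser.add_argument(
        "--worker-address",
        type=str,
        help="send requests directly to this worker instead of asking the controller",
    )
    parser.add_argument("--model-name", type=str, default="vicuna")
    parser.add_argument(
        "--max-new-tokens", type=int, default=2048, help="tokens to generate per request"
    )
    parser.add_argument(
        "--n-thread", type=int, default=8, help="number of concurrent requests"
    )
    parser.add_argument(
        "--test-dispatch",
        action="store_true",
        help="ask the controller for a worker address on every request",
    )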
    args = parser.parse_args()

    main()