import numpy as np


def init(cfg):
    """Prepare the status-bar template and register the status-bar button handler.

    Reads from ``cfg``: ``chat_template``, ``model``, ``s_info``,
    ``session_lock``, ``text_format``, ``role_char``, ``role_usr``,
    ``btn_status_bar_list``, ``setting``, ``status_bar``, ``btn_concurrency``,
    ``btn_start``, ``btn_finish`` and ``btn_status_bar``.

    Side effects:
      * every entry of ``cfg['btn_status_bar_list']`` is mutated in place:
        ``key``/``desc`` are formatted, ``key_t`` (and ``desc_t`` when ``desc``
        is non-empty) hold their token ids, and every ``combine`` part gets
        ``prefix_t`` (when it has a prefix) and a value generator ``fn``
        (for the types ``int``, ``set`` and ``str``);
      * ``cfg['btn_status_bar_fn']`` is set to the event kwargs of the handler;
      * the ``click -> success -> success`` chain is attached to
        ``cfg['btn_status_bar']`` (presumably a Gradio button -- confirm).

    An empty ``btn_status_bar_list``, or entries without any ``combine`` part,
    are accepted; the previous implementation crashed on them while deleting
    loop variables that had never been bound.
    """
    chat_template = cfg['chat_template']
    model = cfg['model']
    s_info = cfg['s_info']
    lock = cfg['session_lock']
    text_format = cfg['text_format']

    # ========== tokenizer helper ==========
    def str_tokenize(s):
        """Return the token ids of ``s``, tokenized with ``chat_template.nl`` in front.

        The text is tokenized with the ``chat_template.nl`` prefix and, when the
        first resulting token is one of ``chat_template.onenl``, that token is
        dropped again (presumably the stand-alone newline token -- confirm).
        """
        tokens = model.tokenize((chat_template.nl + s).encode('utf-8'),
                                add_bos=False, special=False)
        # Guard against an empty token list before looking at the first token.
        if tokens and tokens[0] in chat_template.onenl:
            return tokens[1:]
        return tokens

    # ========== constructor: logits mask ==========
    def btn_status_bar_fn_mask():
        """Return a float32 vector of ``-inf``, one entry per ``model.scores`` column."""
        vocab_size = model.scores.shape[-1]
        return np.full((vocab_size,), -np.inf, dtype=np.single)

    # ========== constructor: number ==========
    def btn_status_bar_fn_int(unit: str):
        """Build a generator for a run of digits, optionally followed by ``unit``.

        Only the ten digit tokens, the ``chat_template.eos`` tokens and (when
        ``unit`` is non-empty) the first token of ``unit`` are left unmasked.
        """
        digit_tokens = str_tokenize('0123456789')
        # Sanity check: the mask below relies on one token per digit.
        assert len(digit_tokens) == 10
        fn_int_mask = btn_status_bar_fn_mask()
        fn_int_mask[chat_template.eos] = 0
        fn_int_mask[digit_tokens] = 0
        if unit:
            unit_t = str_tokenize(unit)
            fn_int_mask[unit_t[0]] = 0

        def logits_processor(_input_ids, logits):
            # Adding -inf removes every token that was not whitelisted above.
            return logits + fn_int_mask

        def inner(eval_t, sample_t):
            """Sample digits until a stop token appears; return the decoded text."""
            generated = []
            while True:
                token = sample_t(logits_processor)
                # ========== anything that is not a digit ends the number ==========
                if token in chat_template.eos:
                    break
                if unit and token == unit_t[0]:
                    break
                # ========== a digit: keep it and continue ==========
                generated.append(token)
                eval_t([token])

            if unit:
                # The unit is always appended in full, whichever token stopped the loop.
                eval_t(unit_t)
                generated.extend(unit_t)

            return model.str_detokenize(generated)

        return inner

    # ========== constructor: set ==========
    def btn_status_bar_fn_set(value):
        """Build a generator that picks exactly one of the strings in ``value``.

        The choice is made from the first token of each option only.
        NOTE(review): options sharing the same first token collide; the one
        listed last replaces the earlier ones.
        """
        value_t = {}
        for option in value:
            option_tokens = str_tokenize(option)
            value_t[option_tokens[0]] = (option_tokens, option)
        fn_set_mask = btn_status_bar_fn_mask()
        fn_set_mask[list(value_t)] = 0

        def logits_processor(_input_ids, logits):
            return logits + fn_set_mask

        def inner(eval_t, sample_t):
            """Sample one first-token, feed the whole option, return its text."""
            option_tokens, option = value_t[sample_t(logits_processor)]
            eval_t(option_tokens)
            return option

        return inner

    # ========== constructor: free text ==========
    def btn_status_bar_fn_str():
        """Build a generator that samples unconstrained text up to a line end."""

        def inner(eval_t, sample_t):
            """Sample until an eos token or a trailing CR/LF; return the stripped text."""
            generated = []
            text = ''
            while True:
                token = sample_t(None)
                if token in chat_template.eos:
                    break
                generated.append(token)
                text = model.str_detokenize(generated)
                if text.endswith('\n') or text.endswith('\r'):
                    # The line-ending token is deliberately not evaluated.
                    break
                # ========== continue ==========
                eval_t([token])
            return text.strip()

        return inner

    # ========== preprocessing ==========
    def _preprocess(items):
        """Format and tokenize every status-bar entry in place.

        Kept in its own function so the loop variables cannot leak into the
        closures of ``init`` (the original deleted them by hand, which failed
        whenever a loop body had never run).
        """
        # ========== key / desc ==========
        for item in items:
            item['key'] = text_format(item['key'],
                                      char=cfg['role_char'].value,
                                      user=cfg['role_usr'].value)
            item['key_t'] = str_tokenize(item['key'])
            item['desc'] = text_format(item['desc'],
                                       char=cfg['role_char'].value,
                                       user=cfg['role_usr'].value)
            if item['desc']:
                item['desc_t'] = str_tokenize(item['desc'])

        # ========== value parts ==========
        for item in items:
            for part in item['combine']:
                if part['prefix']:
                    part['prefix_t'] = str_tokenize(part['prefix'])

                # Any other type gets no 'fn' assigned here.
                if part['type'] == 'int':
                    part['fn'] = btn_status_bar_fn_int(part['unit'])
                elif part['type'] == 'set':
                    part['fn'] = btn_status_bar_fn_set(part['value'])
                elif part['type'] == 'str':
                    part['fn'] = btn_status_bar_fn_str()

        # ========== separator in front of every key except the first ==========
        for index, item in enumerate(items):
            if index == 0:
                continue
            item['key_t'] = chat_template.im_end_nl[-1:] + item['key_t']

    _preprocess(cfg['btn_status_bar_list'])

    # ========== status bar output ==========
    def btn_status_bar(_n_keep, _n_discard,
                       _temperature, _repeat_penalty, _frequency_penalty,
                       _presence_penalty, _repeat_last_n, _top_k,
                       _top_p, _min_p, _typical_p,
                       _tfs_z, _mirostat_mode, _mirostat_eta,
                       _mirostat_tau, _usr, _char,
                       _rag, _max_tokens):
        """Generate the status bar, yielding ``(rows, model.venv_info)`` as it grows.

        ``rows`` is a list of ``[key, value]`` pairs. The whole generation runs
        while holding the session lock.

        Raises:
            RuntimeError: if ``cfg['session_active']`` is falsy.
        """
        with lock:
            if not cfg['session_active']:
                raise RuntimeError('session is not active')
            if cfg['btn_stop_status']:
                yield [], model.venv_info
                return

            # ========== eval / sample bound to the current settings ==========
            def eval_t(tokens):
                return model.eval_t(
                    tokens=tokens,
                    n_keep=_n_keep,
                    n_discard=_n_discard,
                    im_start=chat_template.im_start_token
                )

            def sample_t(logits_processor):
                return model.sample_t(
                    top_k=_top_k,
                    top_p=_top_p,
                    min_p=_min_p,
                    typical_p=_typical_p,
                    temp=_temperature,
                    repeat_penalty=_repeat_penalty,
                    repeat_last_n=_repeat_last_n,
                    frequency_penalty=_frequency_penalty,
                    presence_penalty=_presence_penalty,
                    tfs_z=_tfs_z,
                    mirostat_mode=_mirostat_mode,
                    mirostat_tau=_mirostat_tau,
                    mirostat_eta=_mirostat_eta,
                    logits_processor=logits_processor
                )

            # ========== initialise the output template ==========
            model.venv_create('status')  # create the isolated environment
            eval_t(chat_template('状态'))  # start marker
            # ========== streaming output ==========
            df = []
            for _x in cfg['btn_status_bar_list']:
                # ========== attribute ==========
                row = [_x['key'], '']
                df.append(row)
                eval_t(_x['key_t'])
                if _x['desc']:
                    eval_t(_x['desc_t'])
                yield df, model.venv_info
                # ========== value ==========
                for _y in _x['combine']:
                    if _y['prefix']:
                        if row[1]:
                            row[1] += _y['prefix']
                        else:
                            # No leading ':' is shown at the start of a cell.
                            row[1] += _y['prefix'].lstrip(':')
                        eval_t(_y['prefix_t'])
                    row[1] += _y['fn'](eval_t, sample_t)
                    yield df, model.venv_info
            eval_t(chat_template.im_end_nl)  # end marker
            # ========== clean up the previously generated status bar ==========
            model.venv_remove('status', keep_last=1)
            yield df, model.venv_info

    cfg['btn_status_bar_fn'] = {
        'fn': btn_status_bar,
        'inputs': cfg['setting'],
        'outputs': [cfg['status_bar'], s_info]
    }
    cfg['btn_status_bar_fn'].update(cfg['btn_concurrency'])

    cfg['btn_status_bar'].click(
        **cfg['btn_start']
    ).success(
        **cfg['btn_status_bar_fn']
    ).success(
        **cfg['btn_finish']
    )