File size: 3,893 Bytes
d6e0803
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110

import re
from enum import Enum


class ExtendedModelType(Enum):
    r"""Model identifiers: OpenAI models, a test stub, and open-source families.

    Each member's value is a model-name string: an OpenAI model id, the
    placeholder ``"stub"``, or a short open-source family tag that
    :meth:`validate_model_name` matches concrete model names against.

    Note:
        ``GPT_3_5_TURBO`` and ``GPT_3_5_TURBO_16K`` share the value
        ``"gpt-3.5-turbo-1106"``. Python's :class:`~enum.Enum` therefore makes
        ``GPT_3_5_TURBO_16K`` an *alias* of ``GPT_3_5_TURBO``. Both names
        resolve to the same member, whose ``name`` is ``"GPT_3_5_TURBO"``,
        and the alias is not yielded when iterating over the enum.
    """

    GPT_3_5_TURBO = "gpt-3.5-turbo-1106"
    # Same value as GPT_3_5_TURBO, so this is an Enum alias, not a new member.
    GPT_3_5_TURBO_16K = "gpt-3.5-turbo-1106"
    INSTRUCT_GPT = "text-davinci-003"
    GPT_3_5_TURBO_INSTRUCT = "gpt-3.5-turbo-instruct"
    GPT_3_5_TURBO_0613 = "gpt-3.5-turbo-0613"
    GPT_3_5_TURBO_16K_0613 = "gpt-3.5-turbo-16k-0613"
    GPT_4 = "gpt-4"
    GPT_4_32K = "gpt-4-32k"
    GPT_4_TURBO = "gpt-4-1106-preview"
    GPT_4_TURBO_VISION = "gpt-4-vision-preview"

    STUB = "stub"

    LLAMA_2 = "llama-2"
    VICUNA = "vicuna"
    VICUNA_16K = "vicuna-16k"

    @property
    def value_for_tiktoken(self) -> str:
        r"""Returns the model name to use with tiktoken.

        ``STUB`` is not a real model, so it borrows ``"gpt-3.5-turbo"``.
        Every other member, including the open-source ones, returns its own
        value unchanged.

        NOTE(review): values such as ``"llama-2"`` are presumably unknown to
        tiktoken. Confirm that callers handle that case.

        Returns:
            str: The tiktoken-facing model name.
        """
        return self.value if self is not ExtendedModelType.STUB else "gpt-3.5-turbo"

    @property
    def is_openai(self) -> bool:
        r"""Returns whether this type of models is an OpenAI-released model."""
        return self in {
            ExtendedModelType.GPT_3_5_TURBO,
            ExtendedModelType.GPT_3_5_TURBO_16K,
            ExtendedModelType.GPT_4,
            ExtendedModelType.GPT_4_32K,
            ExtendedModelType.GPT_4_TURBO,
            ExtendedModelType.GPT_4_TURBO_VISION,
            ExtendedModelType.GPT_3_5_TURBO_0613,
            ExtendedModelType.GPT_3_5_TURBO_16K_0613,
            ExtendedModelType.INSTRUCT_GPT,
            ExtendedModelType.GPT_3_5_TURBO_INSTRUCT,
        }

    @property
    def is_open_source(self) -> bool:
        r"""Returns whether this type of models is open-source."""
        return self in {
            ExtendedModelType.LLAMA_2,
            ExtendedModelType.VICUNA,
            ExtendedModelType.VICUNA_16K,
        }

    @property
    def token_limit(self) -> int:
        r"""Returns the maximum token limit for a given model.

        Returns:
            int: The maximum token limit for the given model.

        Raises:
            ValueError: If the member has no entry in the limit table.
        """
        limits = {
            # Also covers GPT_3_5_TURBO_16K, which is an alias of this member.
            ExtendedModelType.GPT_3_5_TURBO: 16385,
            ExtendedModelType.GPT_4: 8192,
            ExtendedModelType.GPT_4_32K: 32768,
            ExtendedModelType.GPT_4_TURBO: 128000,
            ExtendedModelType.GPT_4_TURBO_VISION: 128000,
            ExtendedModelType.STUB: 4096,
            ExtendedModelType.LLAMA_2: 4096,
            # reference: https://lmsys.org/blog/2023-03-30-vicuna/
            ExtendedModelType.VICUNA: 2048,
            ExtendedModelType.VICUNA_16K: 16384,
            ExtendedModelType.GPT_3_5_TURBO_0613: 4096,
            ExtendedModelType.GPT_3_5_TURBO_16K_0613: 16384,
            ExtendedModelType.INSTRUCT_GPT: 4096,
            ExtendedModelType.GPT_3_5_TURBO_INSTRUCT: 4096,
        }
        try:
            return limits[self]
        except KeyError:
            # Guards against a member being added without a limit entry.
            raise ValueError("Unknown model type") from None

    def validate_model_name(self, model_name: str) -> bool:
        r"""Checks whether the model type and the model name matches.

        Vicuna types require the whole (case-sensitive) name to match the
        ``vicuna-<N>b-v<X>.<Y>`` pattern, with a ``-16k`` suffix for
        ``VICUNA_16K``. All other types do a case-insensitive substring check
        against the member value. ``LLAMA_2`` also accepts ``"llama2"``.

        Args:
            model_name (str): The name of the model, e.g. "vicuna-7b-v1.5".
        Returns:
            bool: Whether the model type matches the model name.
        """
        if self is ExtendedModelType.VICUNA:
            # fullmatch, unlike match+'$', rejects a trailing newline.
            pattern = r'vicuna-\d+b-v\d+\.\d+'
            return re.fullmatch(pattern, model_name) is not None
        if self is ExtendedModelType.VICUNA_16K:
            pattern = r'vicuna-\d+b-v\d+\.\d+-16k'
            return re.fullmatch(pattern, model_name) is not None
        lowered = model_name.lower()
        if self is ExtendedModelType.LLAMA_2:
            return self.value in lowered or "llama2" in lowered
        return self.value in lowered