duzx16 committed
Commit 2e1be30 · Parent(s): c949d03

Add support for loading quantized model

Files changed:
- configuration_chatglm.py  +3 -0
- modeling_chatglm.py  +30 -4
- quantization.py  +44 -30
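
This commit lets a checkpoint that already contains quantized weights be loaded directly: the bit width is recorded in the config as quantization_bit, and when it is non-zero the model quantizes itself with empty weight buffers before the state dict is loaded. A minimal loading sketch, assuming the usual transformers remote-code path; the checkpoint path below is a placeholder, not something defined by this commit:

from transformers import AutoConfig, AutoModel

# Placeholder directory: any ChatGLM checkpoint whose config.json sets
# "quantization_bit" to 4 or 8 (0 keeps the ordinary fp16 weights).
checkpoint = "path/to/quantized-chatglm"

config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
print(config.quantization_bit)

# With quantization_bit != 0, ChatGLMForConditionalGeneration.__init__ calls
# self.quantize(bits, empty_init=True): the fp16 Linear layers are replaced by
# QuantizedLinear modules with empty int8 buffers, which from_pretrained then
# fills from the already-quantized state dict.
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).half().cuda()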
 
    	
configuration_chatglm.py  CHANGED

@@ -70,6 +70,7 @@ class ChatGLMConfig(PretrainedConfig):
             max_sequence_length=2048,
             inner_hidden_size=16384,
             position_encoding_2d=True,
+            quantization_bit=0,
             pre_seq_len=None,
             prefix_projection=False,
             **kwargs
@@ -86,8 +87,10 @@ class ChatGLMConfig(PretrainedConfig):
         self.eos_token_id = eos_token_id
         self.pad_token_id = pad_token_id
         self.position_encoding_2d = position_encoding_2d
+        self.quantization_bit = quantization_bit
         self.pre_seq_len = pre_seq_len
         self.prefix_projection = prefix_projection
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
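
The new field is an ordinary config attribute; a short sketch of it in isolation, assuming configuration_chatglm.py from this repository is importable:

from configuration_chatglm import ChatGLMConfig

# quantization_bit defaults to 0 (no quantization). A pre-quantized checkpoint
# ships a config with 4 or 8, which makes the model quantize itself at
# construction time (see modeling_chatglm.py below).
config = ChatGLMConfig()
assert config.quantization_bit == 0

config_int4 = ChatGLMConfig(quantization_bit=4)
assert config_int4.quantization_bit == 4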
    	
modeling_chatglm.py  CHANGED

@@ -139,6 +139,7 @@ class PrefixEncoder(torch.nn.Module):
     Input shape: (batch-size, prefix-length)
     Output shape: (batch-size, prefix-length, 2*layers*hidden)
     """
+
     def __init__(self, config):
         super().__init__()
         self.prefix_projection = config.prefix_projection
@@ -216,6 +217,13 @@ class RotaryEmbedding(torch.nn.Module):
             self.cos_cached, self.sin_cached = cos_cached, sin_cached
         return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
 
+    def _apply(self, fn):
+        if self.cos_cached is not None:
+            self.cos_cached = fn(self.cos_cached)
+        if self.sin_cached is not None:
+            self.sin_cached = fn(self.sin_cached)
+        return super()._apply(fn)
+
 
 def rotate_half(x):
     x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
@@ -931,7 +939,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                     gmask=use_gmask
                 )
 
-
         # [seq_len, batch, hidden_size]
         hidden_states = inputs_embeds.transpose(0, 1)
 
@@ -999,7 +1006,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
 
 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: ChatGLMConfig):
         super().__init__(config)
 
         # self.hidden_size = config.hidden_size
@@ -1019,6 +1026,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             dtype=torch.half
         )
 
+        self.config = config
+
+        self.quantized = False
+
+        if self.config.quantization_bit:
+            self.quantize(self.config.quantization_bit, empty_init=True)
+
     def get_output_embeddings(self):
         return self.lm_head
 
@@ -1351,7 +1365,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 break
             yield input_ids
 
-    def quantize(self, bits: int):
+    def quantize(self, bits: int, empty_init=False, **kwargs):
+        if bits == 0:
+            return
+
         from .quantization import quantize
-        self.transformer = quantize(self.transformer, bits)
+
+        if self.quantized:
+            logger.info("Already quantized.")
+            return self
+
+        self.quantized = True
+
+        self.config.quantization_bit = bits
+
+        self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
         return self
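
Two notes on the modeling changes. quantize() is now a no-op for bits == 0 and refuses to run twice, so calling it from __init__ with the configured bit width stays safe even if model.quantize(...) is called again later. The RotaryEmbedding._apply override matters because cos_cached/sin_cached are plain tensor attributes rather than registered buffers, and .to(), .cuda() and .half() all funnel through nn.Module._apply, which only moves parameters and buffers. A toy sketch of the same pattern (hypothetical class name, not from this commit):

import torch

class CachedModule(torch.nn.Module):
    """Illustrative module with a plain-tensor cache, like RotaryEmbedding's cos/sin caches."""

    def __init__(self):
        super().__init__()
        self.cos_cached = torch.zeros(4)  # plain attribute, not a registered buffer

    def _apply(self, fn):
        # .half()/.cuda()/.to() all route through nn.Module._apply; forwarding fn
        # to the cache keeps it on the same dtype/device as the rest of the module.
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        return super()._apply(fn)

m = CachedModule().half()
print(m.cos_cached.dtype)  # torch.float16; without the override it would stay float32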
    	
quantization.py  CHANGED

@@ -5,9 +5,40 @@ import bz2
 import torch
 import base64
 import ctypes
+from transformers.utils import logging
 
 from typing import List
-from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
+from functools import partial
+
+logger = logging.get_logger(__name__)
+
+try:
+    from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
+
+    class Kernel:
+        def __init__(self, code: bytes, function_names: List[str]):
+            self.code = code
+            self._function_names = function_names
+            self._cmodule = LazyKernelCModule(self.code)
+
+            for name in self._function_names:
+                setattr(self, name, KernelFunction(self._cmodule, name))
+
            quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek
+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BP
XvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"
         
+
+    kernels = Kernel(
+        bz2.decompress(base64.b64decode(quantization_code)),
+        [
+            "int4WeightCompression",
+            "int4WeightExtractionFloat",
+            "int4WeightExtractionHalf",
+            "int8WeightExtractionFloat",
+            "int8WeightExtractionHalf",
+        ],
+    )
+except Exception as exception:
+    kernels = None
+    logger.warning("Failed to load cpm_kernels:" + str(exception))
 
 
 class W8A16Linear(torch.autograd.Function):
@@ -33,30 +64,6 @@ class W8A16Linear(torch.autograd.Function):
         return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
 
 
-class Kernel:
-    def __init__(self, code: bytes, function_names: List[str]):
-        self.code = code
-        self._function_names = function_names
-        self._cmodule = LazyKernelCModule(self.code)
-
-        for name in self._function_names:
-            setattr(self, name, KernelFunction(self._cmodule, name))
-
-
                quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8Dp
rDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV
23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"
         
-
-kernels = Kernel(
-    bz2.decompress(base64.b64decode(quantization_code)),
-    [
-        "int4WeightCompression",
-        "int4WeightExtractionFloat",
-        "int4WeightExtractionHalf",
-        "int8WeightExtractionFloat",
-        "int8WeightExtractionHalf",
-    ],
-)
-
-
 def compress_int4_weight(weight: torch.Tensor):  # (n, m)
     with torch.cuda.device(weight.device):
         n, m = weight.size(0), weight.size(1)
@@ -111,18 +118,18 @@ def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, sourc
 
 
 class QuantizedLinear(Linear):
-    def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, *args, **kwargs):
+    def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs):
         super(QuantizedLinear, self).__init__(*args, **kwargs)
         self.weight_bit_width = weight_bit_width
 
         shape = self.weight.shape
         del self.weight
 
-        if weight_tensor is None:
+        if weight_tensor is None or empty_init:
             self.weight = torch.empty(
                 shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
             )
-            self.weight_scale = torch.empty(shape[0], dtype=kwargs["
+            self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
         else:
             self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
             self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
@@ -131,7 +138,10 @@ class QuantizedLinear(Linear):
 
         self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
         self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
-
+        if bias_tensor is not None:
+            self.bias = Parameter(bias_tensor.to(kwargs["device"]), requires_grad=False)
+        else:
+            self.bias = None
 
     def forward(self, input):
         output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
@@ -140,7 +150,7 @@ class QuantizedLinear(Linear):
         return output
 
 
-def quantize(model, weight_bit_width):
+def quantize(model, weight_bit_width, empty_init=False, **kwargs):
     """Replace fp16 linear with quantized linear"""
 
     for layer in model.layers:
@@ -153,6 +163,7 @@ def quantize(model, weight_bit_width):
             bias=True,
             dtype=torch.half,
             device=layer.attention.query_key_value.weight.device,
+            empty_init=empty_init
         )
         layer.attention.dense = QuantizedLinear(
             weight_bit_width=weight_bit_width,
@@ -163,6 +174,7 @@ def quantize(model, weight_bit_width):
             bias=True,
             dtype=torch.half,
             device=layer.attention.dense.weight.device,
+            empty_init=empty_init
        )
         layer.mlp.dense_h_to_4h = QuantizedLinear(
             weight_bit_width=weight_bit_width,
@@ -173,6 +185,7 @@ def quantize(model, weight_bit_width):
             bias=True,
             dtype=torch.half,
             device=layer.mlp.dense_h_to_4h.weight.device,
+            empty_init=empty_init
         )
         layer.mlp.dense_4h_to_h = QuantizedLinear(
             weight_bit_width=weight_bit_width,
@@ -183,5 +196,6 @@ def quantize(model, weight_bit_width):
             bias=True,
             dtype=torch.half,
             device=layer.mlp.dense_4h_to_h.weight.device,
+            empty_init=empty_init
         )
     return model
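
On the quantization.py side, the cpm_kernels import and the compressed CUDA kernel blob now sit inside a try/except, so importing the module no longer fails on machines without cpm_kernels; kernels is simply set to None and a warning is logged. The new empty_init flag covers the load-from-quantized-checkpoint case: QuantizedLinear only allocates empty int8 weight and scale tensors of the right shape and lets load_state_dict fill them, instead of re-quantizing an fp16 weight that does not exist. For reference, a self-contained sketch of the per-row absmax scheme used in the non-empty_init branch (float32 input is used here so the demo runs on CPU; the real code operates on fp16 weights):

import torch

weight_bit_width = 8  # 4 is also supported; int4 values are packed two per byte by the CUDA kernels
weight = torch.randn(4, 16)

# Per output row: divide by the row's absolute maximum so values fit the signed integer range.
weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
weight_int8 = torch.round(weight / weight_scale[:, None]).to(torch.int8)

# Dequantize to check the round trip; the error stays within about half a quantization step.
recovered = weight_int8.half() * weight_scale[:, None]
print((weight - recovered.float()).abs().max())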