"""Back-of-the-envelope memory estimates for transformer training and inference."""

from dataclasses import dataclass, fields
from typing import Optional

# Bytes needed to store a single value at each supported numeric precision.
PRECISION_TO_BYTES = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1, "int4": 0.5}


@dataclass
class ModelConfig:
    """Architecture and precision settings of the model being estimated."""

    model_size: float  # overwrite_with_hf_config stores this in billions of parameters
    hidden_size: int
    sequence_length: int
    total_sequence_length: int  # for inference = prompt + output tokens
    num_layers: int
    num_heads: int
    mixed_precision: bool = False
    precision: str = "bf16"
    repo_id: Optional[str] = None

    def overwrite_with_hf_config(self, config: dict) -> None:
        """Replace the architecture fields with values read from *config*.

        Args:
            config: Dictionary with the keys ``hidden_size``,
                ``max_position_embeddings``, ``num_hidden_layers``,
                ``num_attention_heads`` plus everything
                ``get_model_size_from_config`` needs.

        Raises:
            KeyError: If a required key is missing from *config*.
        """
        self.model_size = round(get_model_size_from_config(config) / 10**9, 2)
        self.hidden_size = config["hidden_size"]
        self.sequence_length = config["max_position_embeddings"]
        # A value of 0 means "not set": fall back to the model's context size.
        if self.total_sequence_length == 0:
            self.total_sequence_length = self.sequence_length
        self.num_layers = config["num_hidden_layers"]
        self.num_heads = config["num_attention_heads"]


@dataclass
class TrainingConfig:
    """Settings of the training (or inference) run being estimated."""

    micro_batch_size: int
    num_gpus: int
    optimizer: str
    zero_stage: int
    qlora: bool = False
    gradient_checkpointing: bool = False
    train: bool = True  # False for inference


def filter_params_for_dataclass(dataclass_type, params):
    """Return the entries of *params* whose keys are fields of *dataclass_type*.

    Keys that are not dataclass fields are dropped; fields missing from
    *params* are simply absent from the result.
    """
    return {field.name: params[field.name] for field in fields(dataclass_type) if field.name in params}


def _transformer_layer_params(hidden_size, intermediate_size, num_attention_heads, num_key_value_heads):
    """Count the parameters of one transformer layer (gated MLP, bias-free projections)."""
    # Two per-layer norms (input and post-attention), one weight vector each.
    layernorm_params = 2 * hidden_size
    # Gate, up and down projections of the MLP.
    mlp_params = 3 * hidden_size * intermediate_size
    # Query and output projections are square.
    q_and_o_proj_params = 2 * hidden_size * hidden_size
    # Key/value projections map hidden_size -> num_key_value_heads * head_dim,
    # which is narrower than hidden_size under grouped- or multi-query attention.
    head_dim = hidden_size // num_attention_heads
    k_and_v_proj_params = 2 * num_key_value_heads * head_dim * hidden_size
    return layernorm_params + mlp_params + q_and_o_proj_params + k_and_v_proj_params


def get_model_size_from_config(config: dict) -> int:
    """Estimate the total parameter count of a model from its config dictionary.

    The estimate is the input embedding, ``num_hidden_layers`` identical
    transformer layers and an output projection of the same size as the
    embedding.

    Args:
        config: Dictionary with the keys ``vocab_size``, ``hidden_size``,
            ``intermediate_size``, ``num_hidden_layers`` and
            ``num_attention_heads``. ``num_key_value_heads`` is optional; when
            absent or ``None`` it defaults to ``num_attention_heads``.

    Returns:
        The estimated number of parameters.

    Raises:
        KeyError: If a required key is missing from *config*.
    """
    hidden_size = config["hidden_size"]
    num_attention_heads = config["num_attention_heads"]
    num_key_value_heads = config.get("num_key_value_heads")
    if num_key_value_heads is None:
        num_key_value_heads = num_attention_heads

    embedding_params = config["vocab_size"] * hidden_size
    single_layer_params = _transformer_layer_params(
        hidden_size,
        config["intermediate_size"],
        num_attention_heads,
        num_key_value_heads,
    )
    total_transformer_params = config["num_hidden_layers"] * single_layer_params
    output_params = config["vocab_size"] * hidden_size

    return embedding_params + total_transformer_params + output_params


def model_memory(parameters, precision="bf16", mixed_precision=False):
    """Return the bytes needed to hold *parameters* model weights.

    With ``mixed_precision`` the result counts an fp32 copy plus an fp16 copy
    of every weight and *precision* is ignored.

    Raises:
        KeyError: If *precision* is not a key of ``PRECISION_TO_BYTES``.
    """
    if mixed_precision:
        return parameters * (PRECISION_TO_BYTES["fp32"] + PRECISION_TO_BYTES["fp16"])
    return parameters * PRECISION_TO_BYTES[precision]


def gradients_memory(parameters, precision="fp32"):
    """Return the bytes needed to hold one gradient value per parameter.

    Raises:
        KeyError: If *precision* is not a key of ``PRECISION_TO_BYTES``.
    """
    return parameters * PRECISION_TO_BYTES[precision]


def optimizer_memory(parameters, optimizer="adamw", precision="fp32"):
    """Return the bytes of optimizer state for *parameters* model weights.

    Each optimizer has a multiplier giving how many *precision*-sized values
    it keeps per model parameter.

    Raises:
        KeyError: If *optimizer* or *precision* is not supported.
    """
    optimizer_choices = {
        "adam": 3,  # Adam: copies of the parameters, momentum and variance -> 4 + 4 + 4 = 12 bytes per parameter at fp32
        "adamw": 3,  # AdamW: same as Adam
        "sgd": 2,  # SGD: optimizer parameters and gradients -> 4 + 4 = 8 bytes per parameter at fp32
        "adamw_8bit": 1.5,  # 8-bit AdamW: same states as Adam -> 2 + 2 + 2 = 6 bytes per parameter
    }
    return optimizer_choices[optimizer] * parameters * PRECISION_TO_BYTES[precision]


def _activation_bytes_per_layer(sequence_length, micro_batch_size, hidden_size, num_heads, precision):
    """Return the unrounded activation bytes of one transformer layer."""
    bytes_per_value = PRECISION_TO_BYTES[precision]
    # The terms divided by bytes_per_value cost one byte per element whatever
    # the activation precision is (dropout masks in the referenced paper).
    return bytes_per_value * sequence_length * micro_batch_size * hidden_size * (
        16
        + 2 / bytes_per_value
        + 2 * num_heads * sequence_length / hidden_size
        + num_heads * sequence_length / (bytes_per_value * hidden_size)
    )


def activations_memory_per_layer(sequence_length, micro_batch_size, hidden_size, num_heads, precision="fp32"):
    """Return the GB needed to store the intermediate activations of one transformer layer.

    For a 2-byte precision the formula reduces to
    ``s * b * h * (34 + 5 * a * s / h)`` bytes per layer. The default stays at
    ``"fp32"`` so existing callers get the same numbers as before.

    Reference: https://arxiv.org/pdf/2205.05198

    Returns:
        Gigabytes (10**9 bytes), rounded to two decimals.

    Raises:
        KeyError: If *precision* is not a key of ``PRECISION_TO_BYTES``.
    """
    mem_bytes = _activation_bytes_per_layer(sequence_length, micro_batch_size, hidden_size, num_heads, precision)
    return round(mem_bytes / 10**9, 2)


def activations_memory(num_layers, sequence_length, micro_batch_size, hidden_size, num_heads, precision="fp32"):
    """Return the GB needed to store the activations of all *num_layers* layers.

    The per-layer size is scaled before rounding, so the rounding error of a
    single layer is not multiplied by the number of layers.

    Reference: https://arxiv.org/pdf/2205.05198

    Returns:
        Gigabytes (10**9 bytes), rounded to two decimals.

    Raises:
        KeyError: If *precision* is not a key of ``PRECISION_TO_BYTES``.
    """
    bytes_per_layer = _activation_bytes_per_layer(sequence_length, micro_batch_size, hidden_size, num_heads, precision)
    return round(num_layers * bytes_per_layer / 10**9, 2)


def kv_cache_memory(batch_size, total_sequence_length, num_layers, num_heads, hidden_size, precision):
    """Return the GB needed for the key/value cache during inference.

    Total sequence length means input prompt length + completion, so the
    context size of the model is assumed as the upper bound.

    Raises:
        KeyError: If *precision* is not a key of ``PRECISION_TO_BYTES``.
    """
    # NOTE(review): the product uses num_heads * hidden_size, which is only the
    # per-token K (or V) width if callers pass the per-head dimension as
    # hidden_size -- confirm against the call sites.
    cache_bytes = (
        2 * batch_size * total_sequence_length * num_layers * num_heads * hidden_size * PRECISION_TO_BYTES[precision]
    )
    return cache_bytes / 10**9