from vram_helpers import (
    activations_memory_per_layer,
    model_memory,
    gradients_memory,
    optimizer_memory,
    activations_memory,
    kv_cache_memory,
)


def training_vram_required(model_config, training_config):
    # Reference: https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/

    trainable_parameters = model_config.model_size
    if training_config.qlora:
        model_config.precision = "int4"
        # Assume ~0.02% of parameters are trainable; the LoRA paper reports
        # reductions in trainable parameters of up to 10,000x
        # (https://arxiv.org/pdf/2106.09685)
        trainable_parameters = 0.0002 * model_config.model_size
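        # Illustrative arithmetic under that assumption: a 7e9-parameter
        # model yields 7e9 * 0.0002 = 1.4e6 trainable parameters.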

    # The full set of weights stays resident during training (quantized to
    # int4 under QLoRA); only the trainable subset contributes gradient and
    # optimizer state.
    model_vram = model_memory(parameters=model_config.model_size,
                              precision=model_config.precision,
                              mixed_precision=model_config.mixed_precision)

    gradients_vram = gradients_memory(parameters=trainable_parameters)
    optimizer_vram = optimizer_memory(parameters=trainable_parameters, optimizer=training_config.optimizer)

    # ZeRO stage 0: baseline, nothing partitioned
    # ZeRO stage 1: optimizer state partitioning
    if training_config.zero_stage >= 1:
        optimizer_vram = optimizer_vram / training_config.num_gpus
    # ZeRO stage 2: gradient + optimizer state partitioning
    if training_config.zero_stage >= 2:
        gradients_vram = gradients_vram / training_config.num_gpus
    # ZeRO stage 3: parameter + gradient + optimizer state partitioning
    if training_config.zero_stage == 3:
        model_vram = model_vram / training_config.num_gpus

    aggregated_vram = model_vram + gradients_vram + optimizer_vram
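    # Illustrative example: with num_gpus=8 and zero_stage=2, the gradient
    # and optimizer terms above are each divided by 8, while every GPU
    # still holds a full replica of the model weights.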

    activations_vram = activations_memory(model_config.num_layers, 
                                          model_config.sequence_length, 
                                          training_config.micro_batch_size, 
                                          model_config.hidden_size, 
                                          model_config.num_heads)
    
    if training_config.gradient_checkpointing:
        # Coarse heuristic used here: checkpointing shrinks activation
        # memory to roughly its square root (the standard analysis gives
        # sqrt scaling in the number of layers).
        activations_vram = round(activations_vram ** 0.5, 2)
    
    total_vram = aggregated_vram + activations_vram
    return {k: round(v, 2) for k, v in {
        "total": total_vram,
        "model": model_vram,
        "gradients": gradients_vram,
        "optimizer": optimizer_vram,
        "activations": activations_vram
    }.items()}


def inference_vram_required(model_config, training_config):
    # Mixed precision is a training concern (it typically keeps an extra
    # full-precision weight copy); for inference the model is held at its
    # stated precision only.
    model_config.mixed_precision = False
    # Total inference VRAM = model weights + KV cache size + activations
    model_vram = model_memory(parameters=model_config.model_size,
                              precision=model_config.precision, 
                              mixed_precision=model_config.mixed_precision)
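    # Per token of context the KV cache stores one key and one value vector
    # per layer, i.e. 2 * num_layers * hidden_size values at the cache
    # precision (assuming hidden_size == num_heads * head_dim).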
    kv_cache_vram = kv_cache_memory(batch_size=training_config.micro_batch_size,
                                    total_sequence_length=model_config.total_sequence_length,
                                    num_layers=model_config.num_layers,
                                    num_heads=model_config.num_heads,
                                    hidden_size=model_config.hidden_size,
                                    precision=model_config.precision)
    activations_vram = activations_memory_per_layer(sequence_length=model_config.sequence_length,
                                                    micro_batch_size=training_config.micro_batch_size,
                                                    hidden_size=model_config.hidden_size,
                                                    num_heads=model_config.num_heads)
    total_vram = model_vram + kv_cache_vram + activations_vram
    return {k: round(v, 2) for k, v in {
        "total": total_vram,
        "model": model_vram,
        "kv_cache": kv_cache_vram,
        "activations": activations_vram
    }.items()}
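

# Usage sketch: a minimal, illustrative example, not part of the estimator
# itself. It assumes model_config / training_config are plain attribute
# containers exposing exactly the fields read above (the real config
# classes live elsewhere in this repo), and that the precision / optimizer
# strings below are keys vram_helpers understands. All values are
# hypothetical, loosely shaped like a 7B decoder-only model.
if __name__ == "__main__":
    from types import SimpleNamespace

    model_config = SimpleNamespace(
        model_size=7e9,               # total parameter count
        precision="bf16",             # assumed to be a supported key
        mixed_precision=True,
        num_layers=32,
        hidden_size=4096,
        num_heads=32,
        sequence_length=4096,         # training sequence length
        total_sequence_length=4096,   # prompt + generated tokens (KV cache)
    )
    training_config = SimpleNamespace(
        micro_batch_size=1,
        optimizer="adamw",            # assumed to be a supported key
        qlora=False,
        zero_stage=2,
        num_gpus=8,
        gradient_checkpointing=True,
    )

    print("training:", training_vram_required(model_config, training_config))
    print("inference:", inference_vram_required(model_config, training_config))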