winglian committed
Commit f6060a6
Parent: a4e1bb6

Model parallel (#538)


* model-parallel for single process

* fix device/device_map

* fix handling for device

src/axolotl/utils/bench.py CHANGED
@@ -28,7 +28,7 @@ def gpu_memory_usage_smi(device=0):
 
 
 def log_gpu_memory_usage(log, msg, device):
-    if not torch.cuda.is_available():
+    if not torch.cuda.is_available() or device == "auto":
         return (0, 0, 0)
 
     usage, cache, misc = gpu_memory_usage_all(device)
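
The added condition matters because "auto" (the device-map value the config change below resolves to for single-process runs) is not a device PyTorch can resolve, so per-device memory queries would raise. A minimal sketch of the failure mode, illustrative only and not part of the patch:

import torch

if torch.cuda.is_available():
    # Per-device memory queries accept a concrete index or device string:
    print(torch.cuda.memory_allocated(0))         # ok: device index
    print(torch.cuda.memory_allocated("cuda:0"))  # ok: device string
    # But "auto" is not a device: torch.device("auto") raises RuntimeError,
    # so gpu_memory_usage_all("auto") would fail. With the guard above,
    # log_gpu_memory_usage returns (0, 0, 0) instead of attempting the query.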
src/axolotl/utils/config.py CHANGED
@@ -25,7 +25,9 @@ def choose_device(cfg):
             return "cpu"
 
     cfg.device = get_device()
-    if cfg.device_map != "auto":
+    if cfg.world_size == 1:
+        cfg.device_map = "auto"
+    else:
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": cfg.local_rank}
         else:
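
Downstream effect, for context: in a single process the resolved device_map="auto" lets transformers/accelerate split the model's layers across all visible GPUs (naive model parallelism), while the multi-process path keeps the whole model on each rank's own GPU. A hedged sketch of how such a device map is typically consumed; the checkpoint name and standalone variables are illustrative, not from this commit:

from transformers import AutoModelForCausalLM

world_size = 1  # illustrative stand-in for cfg.world_size
local_rank = 0  # illustrative stand-in for cfg.local_rank

if world_size == 1:
    device_map = "auto"            # shard layers across all visible GPUs
else:
    device_map = {"": local_rank}  # pin the full model to this rank's GPU

# device_map is a real from_pretrained parameter (it requires the
# `accelerate` package to be installed); the checkpoint is only an example.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m",
    device_map=device_map,
)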