Maxime committed
Commit • 2fe95cd
1 Parent(s): c1382e7

fix distributed devices (#612)

* fix distributed devices
* Update distributed.py
* Update distributed.py

src/axolotl/utils/distributed.py CHANGED
@@ -77,7 +77,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     value_scalar = fn()
     if not is_distributed():
         return [value_scalar]
-    value_tensor = torch.tensor(
+    value_tensor = torch.tensor(
+        value_scalar, device=torch.cuda.current_device()
+    ).float()
 
     if not is_main_process():
         dist.gather(value_tensor, dst=0)
@@ -137,9 +139,13 @@ def compute_and_broadcast(fn):  # pylint: disable=invalid-name
     """
     if is_main_process():
         value_scalar = fn()
-        value_tensor = torch.tensor(
+        value_tensor = torch.tensor(
+            value_scalar, device=torch.cuda.current_device()
+        ).float()
     else:
-        value_tensor = torch.tensor(
+        value_tensor = torch.tensor(
+            0.0, device=torch.cuda.current_device()
+        )  # Placeholder tensor
 
     # Broadcast the tensor to all processes.
     barrier()
@@ -164,7 +170,9 @@ def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
     value_scalar = fn()
-    value_tensor = torch.tensor(
+    value_tensor = torch.tensor(
+        value_scalar, device=torch.cuda.current_device()
+    ).float()
 
     # Placeholder tensor for gathering results
    if is_main_process():
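All three hunks place the tensor that feeds the collective (dist.gather or dist.broadcast) on torch.cuda.current_device(), so each rank allocates it on its own GPU. For that to resolve correctly, each process has to be bound to its local GPU before any CUDA tensor is created. Below is a minimal, self-contained sketch of that setup; it is not axolotl code, it assumes a torchrun-style launch with the NCCL backend (LOCAL_RANK set by the launcher), and the names init_distributed and gather_scalar are illustrative.

# Illustrative sketch only, not the axolotl helpers. Assumes a torchrun/NCCL
# launch where the launcher sets LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR/PORT.
import os

import torch
import torch.distributed as dist


def init_distributed():
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    # Bind this process to its own GPU before creating any CUDA tensors, so
    # torch.cuda.current_device() returns this rank's device everywhere below.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")


def gather_scalar(value, world_size):
    # Same pattern as the patched helpers: the tensor handed to dist.gather
    # must live on this rank's current CUDA device for NCCL to accept it.
    value_tensor = torch.tensor(value, device=torch.cuda.current_device()).float()
    if dist.get_rank() == 0:
        buckets = [torch.zeros_like(value_tensor) for _ in range(world_size)]
        dist.gather(value_tensor, gather_list=buckets, dst=0)
        return [t.item() for t in buckets]
    dist.gather(value_tensor, dst=0)
    return None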
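The second hunk also touches the non-main branch of compute_and_broadcast: ranks other than 0 never call fn(), but dist.broadcast still needs a destination tensor of matching shape and dtype on each rank's own GPU, which is what the 0.0 placeholder on torch.cuda.current_device() provides. A hedged sketch of that flow, under the same launch assumptions as above (the function name compute_and_broadcast_scalar is illustrative):

import torch
import torch.distributed as dist


def compute_and_broadcast_scalar(fn):
    # Rank 0 computes the value; every other rank allocates a placeholder of
    # the same shape/dtype on its *own* GPU for dist.broadcast to overwrite.
    if dist.get_rank() == 0:
        value_tensor = torch.tensor(
            fn(), device=torch.cuda.current_device()
        ).float()
    else:
        value_tensor = torch.tensor(0.0, device=torch.cuda.current_device())

    dist.barrier()  # make sure every rank has its tensor ready
    dist.broadcast(value_tensor, src=0)
    return value_tensor.item()

Used like compute_and_broadcast_scalar(lambda: some_value_computed_on_rank_0), this keeps a value computed once on the main process consistent across all ranks.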