Skip to content

Commit

Permalink
re-fix the sharded to cpu (backport from levanter) (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Nov 18, 2024
1 parent ceb90ce commit 539ecf4
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/haliax/_src/state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ def get_to_cpu(arr):
process_mesh = Mesh(
np.array(jax.devices()).reshape((jax.process_count(), -1)), ("process", "device")
)

# now we need to find an axis along which we can shard the array.
# for this, we need to find an axis s.t. size(axis) % local_devices == 0

Expand All @@ -332,7 +333,7 @@ def get_to_cpu(arr):

shardings = [None if i != axis_to_shard else "device" for i in range(len(arr.shape))]
sharding = NamedSharding(process_mesh, PartitionSpec(*shardings))
out = jax.jit(lambda x: x, out_shardings=sharding)(arr)
out = jax.device_put(arr, sharding)
return np.array(out)
elif is_scalarish(arr):
return np.asarray(arr)
Expand Down

0 comments on commit 539ecf4

Please sign in to comment.