From 539ecf439823586588bca098c3291d2f67ecd15c Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 18 Nov 2024 14:02:50 -0800 Subject: [PATCH] re-fix the sharded to cpu (backport from levanter) (#111) --- src/haliax/_src/state_dict.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/haliax/_src/state_dict.py b/src/haliax/_src/state_dict.py index 10304eb..581b982 100644 --- a/src/haliax/_src/state_dict.py +++ b/src/haliax/_src/state_dict.py @@ -320,6 +320,7 @@ def get_to_cpu(arr): process_mesh = Mesh( np.array(jax.devices()).reshape((jax.process_count(), -1)), ("process", "device") ) + # now we need to find an axis along which we can shard the array. # for this, we need to find an axis s.t. size(axis) % local_devices == 0 @@ -332,7 +333,7 @@ def get_to_cpu(arr): shardings = [None if i != axis_to_shard else "device" for i in range(len(arr.shape))] sharding = NamedSharding(process_mesh, PartitionSpec(*shardings)) - out = jax.jit(lambda x: x, out_shardings=sharding)(arr) + out = jax.device_put(arr, sharding) return np.array(out) elif is_scalarish(arr): return np.asarray(arr)