From 1e4af434f35ad43208e7e5df569c5ff5eb79681b Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Mon, 28 Nov 2022 10:45:38 +0400
Subject: [PATCH] Support torch.bfloat16 in hivemind.compression (#524)

This PR implements bfloat16 support for `CompressionType.NONE` and
`CompressionType.BLOCKWISE_8BIT`. This is important for the Petals client;
see https://github.com/bigscience-workshop/petals/issues/79
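
For reference, a minimal round-trip sketch of the new behavior. It relies only
on the existing public helpers (`serialize_torch_tensor`,
`deserialize_torch_tensor`, and `CompressionType` from
`hivemind.proto.runtime_pb2`); shapes and values are arbitrary:

    import torch
    from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
    from hivemind.proto.runtime_pb2 import CompressionType

    tensor = torch.randn(128, 128, dtype=torch.bfloat16)

    # NONE: bfloat16 is upcast to float32 on the wire, then cast back on extract.
    # The round trip is exact because every bfloat16 value is representable in float32.
    restored = deserialize_torch_tensor(serialize_torch_tensor(tensor, CompressionType.NONE))
    assert restored.dtype == torch.bfloat16 and torch.equal(restored, tensor)

    # BLOCKWISE_8BIT: lossy 8-bit quantization; dequantize_blockwise() returns float32,
    # which extract() now casts back to the original dtype.
    restored = deserialize_torch_tensor(serialize_torch_tensor(tensor, CompressionType.BLOCKWISE_8BIT))
    assert restored.dtype == torch.bfloat16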
---
 hivemind/compression/base.py         | 21 +++++++++++++++------
 hivemind/compression/quantization.py | 17 ++++++++++++-----
 tests/test_compression.py            | 24 +++++++++++++++++-------
 3 files changed, 44 insertions(+), 18 deletions(-)

diff --git a/hivemind/compression/base.py b/hivemind/compression/base.py
index 2137869f5..afa8b6b19 100644
--- a/hivemind/compression/base.py
+++ b/hivemind/compression/base.py
@@ -80,18 +80,27 @@ class NoCompression(CompressionBase):
     compression_type = runtime_pb2.CompressionType.NONE

     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        array = tensor.detach().numpy()
+        tensor = tensor.detach()
+        dtype_name = str(tensor.dtype).replace("torch.", "")  # e.g. "torch.bfloat16" -> "bfloat16"
+        if tensor.dtype == torch.bfloat16:
+            tensor = tensor.to(torch.float32)
+
         return runtime_pb2.Tensor(
             compression=self.compression_type,
-            buffer=array.tobytes(),
-            size=array.shape,
-            dtype=array.dtype.name,
+            buffer=tensor.numpy().tobytes(),
+            size=tensor.shape,
+            dtype=dtype_name,
             requires_grad=tensor.requires_grad,
         )

     def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
-        return torch.as_tensor(array).reshape(tuple(serialized_tensor.size))
+        if serialized_tensor.dtype == "bfloat16":
+            array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
+            tensor = torch.as_tensor(array, dtype=torch.bfloat16)
+        else:
+            array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
+            tensor = torch.as_tensor(array)
+        return tensor.reshape(tuple(serialized_tensor.size))

     def estimate_compression_ratio(self, info: CompressionInfo) -> float:
         return 1.0
diff --git a/hivemind/compression/quantization.py b/hivemind/compression/quantization.py
index 133b2357a..e0c2bd1a4 100644
--- a/hivemind/compression/quantization.py
+++ b/hivemind/compression/quantization.py
@@ -120,8 +120,8 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_siz
     return np.quantile(partition_quantiles, quantiles)


-BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
-Please install it with `pip install bitsandbytes`
+BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
+Please install it with `pip install bitsandbytes` or by following the instructions at
 https://github.com/TimDettmers/bitsandbytes."""


@@ -139,7 +139,12 @@ def quantize(
         return quantized.numpy(), (absmax.numpy(), codebook.numpy())

     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        quantized, (absmax, codebook) = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+        tensor = tensor.detach()
+        dtype_name = str(tensor.dtype).replace("torch.", "")  # e.g. "torch.bfloat16" -> "bfloat16"
+        if tensor.dtype == torch.bfloat16:
+            tensor = tensor.to(torch.float32)
+
+        quantized, (absmax, codebook) = self.quantize(tensor, allow_inplace=allow_inplace)

         serialized_data = (
             np.int64(len(absmax)).tobytes(),
@@ -153,7 +158,7 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
             buffer=b"".join(serialized_data),
             size=tensor.shape,
             requires_grad=tensor.requires_grad,
-            dtype=tensor.numpy().dtype.name,
+            dtype=dtype_name,
             compression=self.compression_type,
         )

@@ -172,6 +177,8 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
         codebook = torch.as_tensor(codebook)
         quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
         try:
-            return dequantize_blockwise(quantized, (absmax, codebook))
+            result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
         except NameError:
             raise ImportError(BNB_MISSING_MESSAGE)
+        result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
+        return result
diff --git a/tests/test_compression.py b/tests/test_compression.py
index cf34cdde9..6f868c387 100644
--- a/tests/test_compression.py
+++ b/tests/test_compression.py
@@ -46,15 +46,18 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
         assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()


+def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
+    serialized_tensor = serialize_torch_tensor(tensor, compression)
+    chunks = list(split_for_streaming(serialized_tensor, chunk_size))
+    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+    restored = combine_from_streaming(chunks)
+    result = deserialize_torch_tensor(restored)
+    assert torch.allclose(result, tensor, rtol=rtol, atol=atol)
+    assert result.dtype == tensor.dtype
+
+
 @pytest.mark.forked
 def test_serialize_tensor():
-    def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
-        serialized_tensor = serialize_torch_tensor(tensor, compression)
-        chunks = list(split_for_streaming(serialized_tensor, chunk_size))
-        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-        restored = combine_from_streaming(chunks)
-        assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
-
     tensor = torch.randn(512, 12288)
     for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
         _check(tensor, CompressionType.NONE, chunk_size=chunk_size)
@@ -65,6 +68,13 @@ def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
     _check(torch.tensor(1.0), CompressionType.FLOAT16)


+@pytest.mark.forked
+def test_serialize_bfloat16():
+    tensor = torch.randn(4096, 16, dtype=torch.bfloat16)
+    _check(tensor, CompressionType.NONE)
+    _check(tensor, CompressionType.BLOCKWISE_8BIT, rtol=0.1, atol=0.01, chunk_size=1024)
+
+
 @pytest.mark.forked
 def test_allreduce_compression():
     """this test ensures that compression works correctly when multiple tensors have different compression types"""
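
Note on the `CompressionType.NONE` path: numpy has no native bfloat16 dtype,
which is why `compress()` upcasts to float32 (doubling the wire size from 2 to
4 bytes per value) and records the original dtype name in the protobuf, while
`extract()` reads the buffer as float32 and casts back. A standalone sketch of
that mechanism, using only torch and numpy (variable names here are
illustrative, not part of hivemind's API):

    import numpy as np
    import torch

    original = torch.randn(8, dtype=torch.bfloat16)

    # compress: upcast to float32 so numpy can handle it; keep the dtype name separately
    dtype_name = str(original.dtype).replace("torch.", "")  # "bfloat16"
    payload = original.to(torch.float32).numpy().tobytes()  # 4 bytes per value

    # extract: read the buffer as float32, then cast back to the recorded dtype
    array = np.frombuffer(payload, dtype=np.float32)
    restored = torch.as_tensor(array, dtype=getattr(torch, dtype_name))

    assert torch.equal(restored, original)  # exact: bfloat16 embeds losslessly in float32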