diff --git a/BUILD b/BUILD index 79453dc..aa27ec3 100644 --- a/BUILD +++ b/BUILD @@ -24,7 +24,7 @@ py_library( "//tensorflow_compression/python/ops:gen_ops", "//tensorflow_compression/python/ops:math_ops", "//tensorflow_compression/python/ops:padding_ops", - "//tensorflow_compression/python/ops:soft_round_ops", + "//tensorflow_compression/python/ops:round_ops", "//tensorflow_compression/python/util:packed_tensors", ], ) diff --git a/tensorflow_compression/__init__.py b/tensorflow_compression/__init__.py index 4e0d801..a40d401 100644 --- a/tensorflow_compression/__init__.py +++ b/tensorflow_compression/__init__.py @@ -35,7 +35,7 @@ from tensorflow_compression.python.ops.gen_ops import * from tensorflow_compression.python.ops.math_ops import * from tensorflow_compression.python.ops.padding_ops import * -from tensorflow_compression.python.ops.soft_round_ops import * +from tensorflow_compression.python.ops.round_ops import * from tensorflow_compression.python.util.packed_tensors import * # pylint: enable=wildcard-import diff --git a/tensorflow_compression/all_tests.py b/tensorflow_compression/all_tests.py index 405bdc5..da75f48 100644 --- a/tensorflow_compression/all_tests.py +++ b/tensorflow_compression/all_tests.py @@ -40,7 +40,7 @@ from tensorflow_compression.python.ops.math_ops_test import * from tensorflow_compression.python.ops.padding_ops_test import * from tensorflow_compression.python.ops.range_coding_ops_test import * -from tensorflow_compression.python.ops.soft_round_ops_test import * +from tensorflow_compression.python.ops.round_ops_test import * from tensorflow_compression.python.util.packed_tensors_test import * # pylint: enable=wildcard-import diff --git a/tensorflow_compression/python/distributions/BUILD b/tensorflow_compression/python/distributions/BUILD index f2c783c..7122e02 100644 --- a/tensorflow_compression/python/distributions/BUILD +++ b/tensorflow_compression/python/distributions/BUILD @@ -65,7 +65,7 @@ py_library( ":deep_factorized", ":helpers", ":uniform_noise", - "//tensorflow_compression/python/ops:soft_round_ops", + "//tensorflow_compression/python/ops:round_ops", ], ) @@ -76,7 +76,6 @@ py_test( deps = [ ":deep_factorized", ":round_adapters", - "//tensorflow_compression/python/ops:soft_round_ops", ], ) diff --git a/tensorflow_compression/python/distributions/deep_factorized.py b/tensorflow_compression/python/distributions/deep_factorized.py index a5faeae..b69ac68 100644 --- a/tensorflow_compression/python/distributions/deep_factorized.py +++ b/tensorflow_compression/python/distributions/deep_factorized.py @@ -239,15 +239,18 @@ def _log_prob(self, inputs): return log_s_logits + log_s_neg_logits + tf.math.log(dlogits) def _quantization_offset(self): - return tf.constant(0, dtype=self.dtype) + return helpers.estimate_tails( + self._logits_cumulative, 0., self.batch_shape_tensor(), self.dtype) def _lower_tail(self, tail_mass): - logits = tf.math.log(tail_mass / 2 / (1. - tail_mass / 2)) + logits = tf.math.log( + tf.cast(tail_mass / 2 / (1. - tail_mass / 2), self.dtype)) return helpers.estimate_tails( self._logits_cumulative, logits, self.batch_shape_tensor(), self.dtype) def _upper_tail(self, tail_mass): - logits = -tf.math.log(tail_mass / 2 / (1. - tail_mass / 2)) + logits = -tf.math.log( + tf.cast(tail_mass / 2 / (1. - tail_mass / 2), self.dtype)) return helpers.estimate_tails( self._logits_cumulative, logits, self.batch_shape_tensor(), self.dtype) diff --git a/tensorflow_compression/python/distributions/deep_factorized_test.py b/tensorflow_compression/python/distributions/deep_factorized_test.py index 23cd684..8d441c0 100644 --- a/tensorflow_compression/python/distributions/deep_factorized_test.py +++ b/tensorflow_compression/python/distributions/deep_factorized_test.py @@ -111,10 +111,6 @@ def test_uniform_is_special_case(self): x = tf.linspace(-1., 1., 10) self.assertAllClose(df.prob(x), [0, 0, 0, 1, 1, 1, 1, 0, 0, 0]) - def test_quantization_offset_is_zero(self): - df = deep_factorized.NoisyDeepFactorized() - self.assertEqual(helpers.quantization_offset(df), 0) - def test_tails_are_in_order(self): df = deep_factorized.NoisyDeepFactorized() lower_tail = helpers.lower_tail(df, 2**-8) diff --git a/tensorflow_compression/python/distributions/helpers.py b/tensorflow_compression/python/distributions/helpers.py index b3bb65b..f662a5a 100644 --- a/tensorflow_compression/python/distributions/helpers.py +++ b/tensorflow_compression/python/distributions/helpers.py @@ -157,9 +157,10 @@ def lower_tail(distribution, tail_mass): tail = distribution.quantile(tail_mass / 2) except NotImplementedError: try: + target = tf.math.log(tf.cast(tail_mass / 2, distribution.dtype)) tail = estimate_tails( - distribution.log_cdf, tf.math.log(tail_mass / 2), - distribution.batch_shape_tensor(), distribution.dtype) + distribution.log_cdf, target, distribution.batch_shape_tensor(), + distribution.dtype) except NotImplementedError: raise NotImplementedError( "`distribution` must implement `_lower_tail()`, `quantile()`, or " @@ -193,8 +194,9 @@ def upper_tail(distribution, tail_mass): tail = distribution.quantile(1 - tail_mass / 2) except NotImplementedError: try: + target = tf.math.log(tf.cast(tail_mass / 2, distribution.dtype)) tail = estimate_tails( - distribution.log_survival_function, tf.math.log(tail_mass / 2), + distribution.log_survival_function, target, distribution.batch_shape_tensor(), distribution.dtype) except NotImplementedError: raise NotImplementedError( diff --git a/tensorflow_compression/python/distributions/round_adapters.py b/tensorflow_compression/python/distributions/round_adapters.py index 30c2a38..a61dd5e 100644 --- a/tensorflow_compression/python/distributions/round_adapters.py +++ b/tensorflow_compression/python/distributions/round_adapters.py @@ -19,7 +19,7 @@ from tensorflow_compression.python.distributions import deep_factorized from tensorflow_compression.python.distributions import helpers from tensorflow_compression.python.distributions import uniform_noise -from tensorflow_compression.python.ops import soft_round_ops +from tensorflow_compression.python.ops import round_ops __all__ = [ @@ -239,10 +239,10 @@ def __init__(self, base, alpha, name="SoftRoundAdapter"): self._alpha = alpha def transform(self, x): - return soft_round_ops.soft_round(x, self._alpha) + return round_ops.soft_round(x, self._alpha) def inverse_transform(self, y): - return soft_round_ops.soft_round_inverse(y, self._alpha) + return round_ops.soft_round_inverse(y, self._alpha) class NoisySoftRoundAdapter(uniform_noise.UniformNoiseAdapter): diff --git a/tensorflow_compression/python/distributions/round_adapters_test.py b/tensorflow_compression/python/distributions/round_adapters_test.py index 37bf65c..1412a7b 100644 --- a/tensorflow_compression/python/distributions/round_adapters_test.py +++ b/tensorflow_compression/python/distributions/round_adapters_test.py @@ -20,7 +20,6 @@ import tensorflow_probability as tfp from tensorflow_compression.python.distributions import deep_factorized from tensorflow_compression.python.distributions import round_adapters -from tensorflow_compression.python.ops import soft_round_ops def _test_log_prob_gradient_is_bounded(self, dist_cls, values, params=()): @@ -42,44 +41,42 @@ class AdaptersTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( ("softround_deepfactorized", lambda d: round_adapters.SoftRoundAdapter(d, alpha=5.0), - deep_factorized.DeepFactorized, 0.0), + deep_factorized.DeepFactorized), ("softround_logistic", lambda d: round_adapters.SoftRoundAdapter(d, alpha=5.0), - lambda: tfp.distributions.Logistic(loc=10.3, scale=1.5), - lambda: soft_round_ops.soft_round(0.3, alpha=5.0)), + lambda: tfp.distributions.Logistic(loc=10.3, scale=1.5)), ("softround_normal", lambda d: round_adapters.SoftRoundAdapter(d, alpha=4.0), - lambda: tfp.distributions.Normal(loc=10.4, scale=1.5), - lambda: soft_round_ops.soft_round(0.4, alpha=4.0)), + lambda: tfp.distributions.Normal(loc=10.4, scale=1.5)), ("noisysoftround_deepfactorized", lambda d: round_adapters.NoisySoftRoundAdapter(d, alpha=5.0), - deep_factorized.DeepFactorized, 0.0), + deep_factorized.DeepFactorized), ("noisysoftround_logistic", lambda d: round_adapters.NoisySoftRoundAdapter(d, alpha=5.0), - lambda: tfp.distributions.Logistic(loc=10, scale=1.5), 0.0), + lambda: tfp.distributions.Logistic(loc=10, scale=1.5)), ("noisysoftround_normal", lambda d: round_adapters.NoisySoftRoundAdapter(d, alpha=5.0), - lambda: tfp.distributions.Normal(loc=10, scale=1.5), 0.0), + lambda: tfp.distributions.Normal(loc=10, scale=1.5)), ("round_deepfactorized", round_adapters.RoundAdapter, - lambda: deep_factorized.DeepFactorized(init_scale=1.0), 0.0), + lambda: deep_factorized.DeepFactorized(init_scale=1.0)), ("round_logistic", round_adapters.RoundAdapter, - lambda: tfp.distributions.Logistic(loc=1.5, scale=1.5), 0.0), + lambda: tfp.distributions.Logistic(loc=1.5, scale=1.5)), ("round_normal", round_adapters.RoundAdapter, - lambda: tfp.distributions.Normal(loc=1.5, scale=1.5), 0.0), + lambda: tfp.distributions.Normal(loc=1.5, scale=1.5)), ("noisyround_deepfactorized", round_adapters.NoisyRoundAdapter, - lambda: deep_factorized.DeepFactorized(init_scale=1.0), 0.0), + lambda: deep_factorized.DeepFactorized(init_scale=1.0)), ("noisyround_logistic", round_adapters.NoisyRoundAdapter, - lambda: tfp.distributions.Logistic(loc=1.5, scale=1.5), 0.0), + lambda: tfp.distributions.Logistic(loc=1.5, scale=1.5)), ("noisyround_normal", round_adapters.NoisyRoundAdapter, - lambda: tfp.distributions.Normal(loc=1.5, scale=1.5), 0.0), + lambda: tfp.distributions.Normal(loc=1.5, scale=1.5)), ) - def test_tails_and_offset(self, adapter, distribution, expected_offset): + def test_tails(self, adapter, distribution): dist = adapter(distribution()) lower_tail = dist._lower_tail(2**-8) try: @@ -98,12 +95,6 @@ def test_tails_and_offset(self, adapter, distribution, expected_offset): self.assertLessEqual(right_mass, 2**-8) self.assertGreater(upper_tail, lower_tail) - offset = dist._quantization_offset() - if not isinstance(expected_offset, float): - # We cannot run tf inside the parameterized test declaration, hence - # non-float values are wrapped in a lambda. - expected_offset = expected_offset() - self.assertAllClose(offset, expected_offset) @parameterized.named_parameters( ("softround_logistic", @@ -210,16 +201,11 @@ def test_sampling_works(self): sample = dist.sample((5, 4)) self.assertEqual(sample.shape, (5, 4, 2)) - def test_tails_and_offset_are_in_order(self): + def test_tails_are_in_order(self): dist = self.dist_cls(loc=10, scale=1.5) - offset = dist._quantization_offset() lower_tail = dist._lower_tail(2**-8) upper_tail = dist._upper_tail(2**-8) self.assertGreater(upper_tail, lower_tail) - if offset: - # If quantization offset is 0.0, it doesn't need to be between the tails. - self.assertGreater(upper_tail, offset) - self.assertGreater(offset, lower_tail) def test_stats_throw_error(self): dist = self.dist_cls(loc=1, scale=2) diff --git a/tensorflow_compression/python/entropy_models/BUILD b/tensorflow_compression/python/entropy_models/BUILD index c09d0f1..230c56c 100644 --- a/tensorflow_compression/python/entropy_models/BUILD +++ b/tensorflow_compression/python/entropy_models/BUILD @@ -23,6 +23,7 @@ py_library( "//tensorflow_compression/python/distributions:helpers", "//tensorflow_compression/python/ops:gen_ops", "//tensorflow_compression/python/ops:math_ops", + "//tensorflow_compression/python/ops:round_ops", ], ) @@ -42,9 +43,9 @@ py_library( srcs_version = "PY3", deps = [ ":continuous_base", - "//tensorflow_compression/python/distributions:helpers", "//tensorflow_compression/python/ops:gen_ops", "//tensorflow_compression/python/ops:math_ops", + "//tensorflow_compression/python/ops:round_ops", ], ) @@ -52,6 +53,7 @@ py_test( name = "continuous_indexed_test", srcs = ["continuous_indexed_test.py"], python_version = "PY3", + shard_count = 5, deps = [ ":continuous_indexed", "//tensorflow_compression/python/distributions:uniform_noise", @@ -63,15 +65,14 @@ py_library( srcs = ["universal.py"], srcs_version = "PY3", deps = [ - ":continuous_batched", - ":continuous_indexed", + ":continuous_base", + "//tensorflow_compression/python/ops:gen_ops", "//tensorflow_compression/python/ops:math_ops", ], ) py_test( name = "universal_test", - timeout = "long", srcs = ["universal_test.py"], python_version = "PY3", shard_count = 3, diff --git a/tensorflow_compression/python/entropy_models/continuous_base.py b/tensorflow_compression/python/entropy_models/continuous_base.py index e30f5f2..60e1796 100644 --- a/tensorflow_compression/python/entropy_models/continuous_base.py +++ b/tensorflow_compression/python/entropy_models/continuous_base.py @@ -38,7 +38,6 @@ class ContinuousEntropyModelBase(tf.Module, metaclass=abc.ABCMeta): @abc.abstractmethod def __init__(self, - prior=None, coding_rank=None, compression=False, stateless=False, @@ -46,21 +45,10 @@ def __init__(self, tail_mass=2**-8, range_coder_precision=12, dtype=None, - prior_shape=None, - cdf=None, - cdf_offset=None, - cdf_length=None, - cdf_max_length=None, laplace_tail_mass=0): """Initializes the instance. Args: - prior: A `tfp.distributions.Distribution` object. A density model fitting - the marginal distribution of the bottleneck data with additive uniform - noise, which is shared a priori between the sender and the receiver. For - best results, the distribution should be flexible enough to have a - unit-width uniform distribution as a special case, since this is the - marginal distribution for bottleneck dimensions that are constant. coding_rank: Integer. Number of innermost dimensions considered a coding unit. Each coding unit is compressed to its own bit string, and the `__call__()` method sums over each coding unit. @@ -80,83 +68,22 @@ def __init__(self, tail_mass: Float. Approximate probability mass which is range encoded with less precision, by using a Golomb-like code. range_coder_precision: Integer. Precision passed to the range coding op. - dtype: Data type of prior. Must be provided when `prior` is omitted. - prior_shape: Batch shape of the prior (dimensions which are not assumed - i.i.d.). Must be provided when `prior` is omitted. - cdf: `tf.Tensor` or `None`. When provided, is used for range coding rather - than tables built from the prior. - cdf_offset: `tf.Tensor` or `None`. Must be provided along with `cdf`. - cdf_length: `tf.Tensor` or `None`. Must be provided along with `cdf`. - cdf_max_length: Maximum `cdf_length`. When provided, an empty range coding - table is created, which can then be restored using `set_weights`. - Requires `compression=True` and `stateless=False`. + dtype: `tf.dtypes.DType`. Data type of this entropy model (i.e. dtype of + prior, decompressed values). laplace_tail_mass: Float. If positive, will augment the prior with a Laplace mixture for training stability. (experimental) """ super().__init__() - - self._prior = prior + self._prior = None # This will be set by subclasses, if appropriate. self._coding_rank = int(coding_rank) self._compression = bool(compression) self._stateless = bool(stateless) self._expected_grads = bool(expected_grads) self._tail_mass = float(tail_mass) self._range_coder_precision = int(range_coder_precision) + self._dtype = tf.as_dtype(dtype) self._laplace_tail_mass = float(laplace_tail_mass) - - if not (prior is not None) == (dtype is None) == (prior_shape is None): - raise ValueError( - "Either `prior` or both `dtype` and `prior_shape` must be provided.") - if prior is None: - self._dtype = tf.as_dtype(dtype) - self._prior_shape = tf.TensorShape(prior_shape) - else: - if prior.event_shape.rank: - raise ValueError( - "`prior` must be a (batch of) scalar distribution(s).") - self._dtype = tf.as_dtype(prior.dtype) - self._prior_shape = tf.TensorShape(prior.batch_shape) - with self.name_scope: - if self.compression: - if not (cdf is None) == (cdf_offset is None) == (cdf_length is None): - raise ValueError( - "Either all or none of `cdf`, `cdf_offset`, and `cdf_length` " - "must be provided.") - if (prior is None) + (cdf_max_length is None) + (cdf is None) != 2: - raise ValueError( - "With `compression=True`, must provide exactly one of `prior`, " - "`cdf`, or `cdf_max_length`.") - if prior is not None: - cdf, cdf_offset, cdf_length = self._build_tables(prior) - elif cdf_max_length is not None: - if self.stateless: - raise ValueError( - "With `stateless=True`, can't provide `cdf_max_length`.") - cdf_max_length = int(cdf_max_length) - context_size = int(self.context_shape.num_elements()) - zeros = tf.zeros([context_size, cdf_max_length], dtype=tf.int32) - cdf = zeros - cdf_offset = zeros[:, 0] - cdf_length = zeros[:, 0] - if self.stateless: - self._cdf = tf.convert_to_tensor(cdf, dtype=tf.int32, name="cdf") - self._cdf_offset = tf.convert_to_tensor( - cdf_offset, dtype=tf.int32, name="cdf_offset") - self._cdf_length = tf.convert_to_tensor( - cdf_length, dtype=tf.int32, name="cdf_length") - else: - self._cdf = tf.Variable( - cdf, dtype=tf.int32, trainable=False, name="cdf") - self._cdf_offset = tf.Variable( - cdf_offset, dtype=tf.int32, trainable=False, name="cdf_offset") - self._cdf_length = tf.Variable( - cdf_length, dtype=tf.int32, trainable=False, name="cdf_length") - else: - if not (cdf is None and cdf_offset is None and cdf_length is None and - cdf_max_length is None): - raise ValueError("CDFs can't be provided with `compression=False`") - self._laplace_prior = (tfp.distributions.Laplace(loc=0.0, scale=1.0) if laplace_tail_mass else None) @@ -210,31 +137,6 @@ def laplace_tail_mass(self): """Whether to augment the prior with a Laplace mixture.""" return self._laplace_tail_mass - @property - def prior_shape(self): - """Batch shape of `prior` (dimensions which are not assumed i.i.d.).""" - return self._prior_shape - - @property - def prior_shape_tensor(self): - """Batch shape of `prior` as a `Tensor`.""" - return tf.constant(self.prior_shape.as_list(), dtype=tf.int32) - - @property - def context_shape(self): - """The shape of the non-flattened PDF/CDF tables for range coding. - - This is typically the same as the prior shape, but can differ e.g. in - universal entropy models. In any case, the context_shape contains the prior - shape (in the trailing dimensions). - """ - return self.prior_shape - - @property - def context_shape_tensor(self): - """The context shape as a `Tensor`.""" - return tf.constant(self.context_shape.as_list(), dtype=tf.int32) - @property def coding_rank(self): """Number of innermost dimensions considered a coding unit.""" @@ -260,26 +162,63 @@ def range_coder_precision(self): """Precision passed to range coding op.""" return self._range_coder_precision - @tf.custom_gradient - def _quantize_no_offset(self, inputs): - return tf.round(inputs), lambda x: x + def _init_compression(self, cdf, cdf_offset, cdf_length, cdf_shape): + """Sets up this entropy model for using the range coder. - @tf.custom_gradient - def _quantize_offset(self, inputs, offset): - return tf.round(inputs - offset) + offset, lambda x: (x, None) + This is done by storing `cdf`, `cdf_offset`, and `cdf_length` in + `tf.Variable`s (`stateless=False`) or `tf.Tensor`s (`stateless=True`) as + attributes of this object, or creating the variables as placeholders if + `cdf_shape` is provided. - def _quantize(self, inputs, offset=None): - if offset is None: - outputs = self._quantize_no_offset(inputs) - else: - outputs = self._quantize_offset(inputs, offset) - return outputs + The reason for pre-computing the tables is that they must not be + re-generated independently on the sending and receiving side, since small + numerical discrepancies between both sides can occur in this process. If the + tables differ slightly, this in turn would very likely cause catastrophic + error propagation during range decoding. For a more in-depth discussion of + this, see: - def _offset_from_prior(self, prior): - """Computes quantization offset from the prior distribution.""" - return helpers.quantization_offset(prior) + > "Integer Networks for Data Compression with Latent-Variable Models"
+ > J. Ballé, N. Johnston, D. Minnen
+ > https://openreview.net/forum?id=S1zz2i0cY7 - def _build_tables(self, prior): + Args: + cdf: CDF table for range coder. + cdf_offset: CDF offset table for range coder. + cdf_length: CDF length table for range coder. + cdf_shape: Iterable of 2 integers, the shape of `cdf`. Mutually exclusive + with the other three arguments. If provided, creates placeholder values + for them. + """ + if not ((cdf is None) == (cdf_offset is None) == (cdf_length is None) == + (cdf_shape is not None)): + raise ValueError( + "Either all of `cdf`, `cdf_offset`, and `cdf_length`; or `cdf_shape` " + "must be provided.") + if cdf_shape is not None: + if self.stateless: + raise ValueError("With `stateless=True`, can't provide `cdf_shape`.") + cdf_shape = tuple(map(int, cdf_shape)) + if len(cdf_shape) != 2: + raise ValueError("`cdf_shape` must consist of 2 integers.") + zeros = tf.zeros(cdf_shape, dtype=tf.int32) + cdf = zeros + cdf_offset = zeros[:, 0] + cdf_length = zeros[:, 0] + if self.stateless: + self._cdf = tf.convert_to_tensor(cdf, dtype=tf.int32, name="cdf") + self._cdf_offset = tf.convert_to_tensor( + cdf_offset, dtype=tf.int32, name="cdf_offset") + self._cdf_length = tf.convert_to_tensor( + cdf_length, dtype=tf.int32, name="cdf_length") + else: + self._cdf = tf.Variable( + cdf, dtype=tf.int32, trainable=False, name="cdf") + self._cdf_offset = tf.Variable( + cdf_offset, dtype=tf.int32, trainable=False, name="cdf_offset") + self._cdf_length = tf.Variable( + cdf_length, dtype=tf.int32, trainable=False, name="cdf_length") + + def _build_tables(self, prior, offset=None, context_shape=None): """Computes integer-valued probability tables used by the range coder. These tables must not be re-generated independently on the sending and @@ -292,20 +231,23 @@ def _build_tables(self, prior): > J. Ballé, N. Johnston, D. Minnen
> https://openreview.net/forum?id=S1zz2i0cY7 - The tables are stored in `tf.Variable`s (`stateless=False`) or `tf.Tensor`s - (`stateless=True`) as attributes of this object. The recommended way is to - train the model with `compression=False`, then instantiate an entropy model - with `compression=True`, and distribute it to a sender and a receiver. - Args: prior: The `tfp.distributions.Distribution` object (see initializer). + offset: Quantization offsets to use for sampling prior probabilities. + Defaults to 0. + context_shape: Shape of innermost dimensions to evaluate the prior on. + Defaults to and must include `prior.batch_shape`. Returns: CDF table, CDF offsets, CDF lengths. """ - # TODO(jonycgn, relational): Consider not using offset when soft quantization - # is used. - offset = self._offset_from_prior(prior) + if offset is None: + offset = 0. + if context_shape is None: + context_shape = tf.TensorShape(prior.batch_shape) + # Subclasses should have already caught this, but better be safe. + assert not prior.event_shape.rank + lower_tail = helpers.lower_tail(prior, self.tail_mass) upper_tail = helpers.upper_tail(prior, self.tail_mass) # Integers such that: @@ -324,10 +266,10 @@ def _build_tables(self, prior): if tf.executing_eagerly() and max_length > 2048: logging.warning( "Very wide PMF with %d elements may lead to out of memory issues. " - "Consider priors with smaller dispersion or increasing `tail_mass` " + "Consider priors with smaller variance, or increasing `tail_mass` " "parameter.", int(max_length)) samples = tf.range(tf.cast(max_length, self.dtype), dtype=self.dtype) - samples = tf.reshape(samples, [-1] + self.context_shape.rank * [1]) + samples = tf.reshape(samples, [-1] + context_shape.rank * [1]) samples += pmf_start pmf = prior.prob(samples) @@ -335,10 +277,11 @@ def _build_tables(self, prior): pmf = tf.reshape(pmf, [max_length, -1]) pmf = tf.transpose(pmf) - pmf_length = tf.broadcast_to(pmf_length, self.context_shape_tensor) + context_shape = tf.constant(context_shape.as_list(), dtype=tf.int32) + pmf_length = tf.broadcast_to(pmf_length, context_shape) pmf_length = tf.reshape(pmf_length, [-1]) cdf_length = pmf_length + 2 - cdf_offset = tf.broadcast_to(minima, self.context_shape_tensor) + cdf_offset = tf.broadcast_to(minima, context_shape) cdf_offset = tf.reshape(cdf_offset, [-1]) # Prevent tensors from bouncing back and forth between host and GPU. @@ -349,7 +292,7 @@ def loop_body(args): overflow = tf.math.maximum(1 - tf.reduce_sum(prob, keepdims=True), 0.) prob = tf.concat([prob, overflow], axis=0) cdf = gen_ops.pmf_to_quantized_cdf( - prob, precision=self.range_coder_precision) + tf.cast(prob, tf.float32), precision=self.range_coder_precision) return tf.pad( cdf, [[0, max_length - length]], mode="CONSTANT", constant_values=0) @@ -359,7 +302,7 @@ def loop_body(args): return cdf, cdf_offset, cdf_length - def _log_prob_from_prior(self, prior, bottleneck_perturbed): + def _log_prob(self, prior, bottleneck_perturbed): """Evaluates prior.log_prob(bottleneck + noise).""" if self.laplace_tail_mass: laplace_prior = self._laplace_prior @@ -399,11 +342,8 @@ def get_config(self): expected_grads=self.expected_grads, tail_mass=self.tail_mass, range_coder_precision=self.range_coder_precision, + cdf_shape=tuple(map(int, self.cdf.shape)), dtype=self.dtype.name, - # TODO(jonycgn): pytype thinks TensorShape is not iterable, even though - # it defines __iter__. - prior_shape=tuple(map(int, self.prior_shape)), # pytype:disable=wrong-arg-types - cdf_max_length=int(self.cdf.shape[1]), laplace_tail_mass=self.laplace_tail_mass, ) diff --git a/tensorflow_compression/python/entropy_models/continuous_batched.py b/tensorflow_compression/python/entropy_models/continuous_batched.py index 29a65bb..fbb7e29 100644 --- a/tensorflow_compression/python/entropy_models/continuous_batched.py +++ b/tensorflow_compression/python/entropy_models/continuous_batched.py @@ -20,6 +20,7 @@ from tensorflow_compression.python.entropy_models import continuous_base from tensorflow_compression.python.ops import gen_ops from tensorflow_compression.python.ops import math_ops +from tensorflow_compression.python.ops import round_ops __all__ = [ @@ -97,8 +98,8 @@ def __init__(self, cdf=None, cdf_offset=None, cdf_length=None, - cdf_max_length=None, - non_integer_offsets=True, + cdf_shape=None, + non_integer_offset=True, quantization_offset=None, laplace_tail_mass=0): """Initializes the instance. @@ -131,52 +132,76 @@ def __init__(self, tail_mass: Float. Approximate probability mass which is range encoded with less precision, by using a Golomb-like code. range_coder_precision: Integer. Precision passed to the range coding op. - dtype: Data type of prior. Must be provided when `prior` is omitted. + dtype: `tf.dtypes.DType`. Data type of this entropy model (i.e. dtype of + prior, decompressed values). Must be provided if `prior` is omitted. prior_shape: Batch shape of the prior (dimensions which are not assumed i.i.d.). Must be provided when `prior` is omitted. cdf: `tf.Tensor` or `None`. When provided, is used for range coding rather than tables built from the prior. cdf_offset: `tf.Tensor` or `None`. Must be provided along with `cdf`. cdf_length: `tf.Tensor` or `None`. Must be provided along with `cdf`. - cdf_max_length: Maximum `cdf_length`. When provided, an empty range coding - table is created, which can then be restored using `set_weights`. - Requires `compression=True` and `stateless=False`. - non_integer_offsets: Boolean. Whether to quantize to non-integer offsets + cdf_shape: Shape of `cdf`. When provided, an empty range coding table is + created, which can then be restored using `set_weights`. Requires + `compression=True` and `stateless=False`. + non_integer_offset: Boolean. Whether to quantize to non-integer offsets heuristically determined from mode/median of prior. Set to `False` when using soft quantization during training. quantization_offset: `tf.Tensor` or `None`. If `cdf` is provided and - `non_integer_offsets=True`, must be provided. + `non_integer_offset=True`, must be provided as well. laplace_tail_mass: Float. If positive, will augment the prior with a Laplace mixture for training stability. (experimental) - - Raises: - RuntimeError: when attempting to instantiate an entropy model with - `compression=True` and not in eager execution mode. """ + if not (prior is not None) == (dtype is None) == (prior_shape is None): + raise ValueError( + "Either `prior` or both `dtype` and `prior_shape` must be provided.") + if (prior is None) + (cdf_shape is None) + (cdf is None) != 2: + raise ValueError( + "Must provide exactly one of `prior`, `cdf`, or `cdf_shape`.") + if not compression and not ( + cdf is None and cdf_offset is None and cdf_length is None and + cdf_shape is None): + raise ValueError("CDFs can't be provided with `compression=False`") + if prior is not None and prior.event_shape.rank: + raise ValueError("`prior` must be a (batch of) scalar distribution(s).") + super().__init__( - prior=prior, coding_rank=coding_rank, compression=compression, stateless=stateless, expected_grads=expected_grads, tail_mass=tail_mass, range_coder_precision=range_coder_precision, - dtype=dtype, - prior_shape=prior_shape, - cdf=cdf, - cdf_offset=cdf_offset, - cdf_length=cdf_length, - cdf_max_length=cdf_max_length, + dtype=dtype if dtype is not None else prior.dtype, laplace_tail_mass=laplace_tail_mass, ) - self._non_integer_offsets = bool(non_integer_offsets) + self._prior = prior + self._non_integer_offset = bool(non_integer_offset) + self._prior_shape = tf.TensorShape( + prior_shape if prior is None else prior.batch_shape) if self.coding_rank < self.prior_shape.rank: raise ValueError("`coding_rank` can't be smaller than `prior_shape`.") with self.name_scope: - if not self.non_integer_offsets: + if quantization_offset is not None: + # If quantization offset is passed in manually, use it. + pass + elif not self.non_integer_offset: + # If not using the offset heuristic, always quantize to integers. quantization_offset = None - elif prior is not None: + elif cdf_shape is not None: + # `cdf_shape` being set indicates that we are using the `SavedModel` + # protocol. So create a placeholder value. + quantization_offset = tf.zeros( + self.prior_shape_tensor, dtype=self.dtype) + elif cdf is not None: + # CDF is passed in manually. So assume the same about the offsets. + if quantization_offset is None: + raise ValueError( + "When providing `cdf` and `non_integer_offset=True`, must also " + "provide `quantization_offset`.") + else: + assert self._prior is not None + # If prior is available, determine offsets from it using the heuristic. quantization_offset = helpers.quantization_offset(self.prior) # Optimization: if the quantization offset is zero, we don't need to # subtract/add it when quantizing, and we don't need to serialize its @@ -187,15 +212,6 @@ def __init__(self, else: quantization_offset = tf.broadcast_to( quantization_offset, self.prior_shape_tensor) - elif cdf_max_length is not None: - quantization_offset = tf.zeros( - self.prior_shape_tensor, dtype=self.dtype) - else: - assert cdf is not None - if quantization_offset is None: - raise ValueError( - "When providing `cdf` and `non_integer_offsets=True`, must also " - "provide `quantization_offset`.") if quantization_offset is None: self._quantization_offset = None elif self.compression and not self.stateless: @@ -205,10 +221,25 @@ def __init__(self, else: self._quantization_offset = tf.convert_to_tensor( quantization_offset, dtype=self.dtype, name="quantization_offset") + if self.compression: + if cdf is None and cdf_shape is None: + cdf, cdf_offset, cdf_length = self._build_tables( + self.prior, offset=quantization_offset) + self._init_compression(cdf, cdf_offset, cdf_length, cdf_shape) + + @property + def prior_shape(self): + """Batch shape of `prior` (dimensions which are not assumed i.i.d.).""" + return self._prior_shape + + @property + def prior_shape_tensor(self): + """Batch shape of `prior` as a `Tensor`.""" + return tf.constant(self.prior_shape.as_list(), dtype=tf.int32) @property - def non_integer_offsets(self): - return self._non_integer_offsets + def non_integer_offset(self): + return self._non_integer_offset @property def quantization_offset(self): @@ -216,15 +247,14 @@ def quantization_offset(self): return None return tf.convert_to_tensor(self._quantization_offset) - def _compute_indexes_and_offset(self, broadcast_shape): + def _compute_indexes(self, broadcast_shape): """Returns the indexes for range coding and the quantization offset.""" # TODO(jonycgn, ssjhv): Investigate broadcasting in range coding op. prior_size = functools.reduce(lambda x, y: x * y, self.prior_shape, 1) indexes = tf.range(prior_size, dtype=tf.int32) indexes = tf.reshape(indexes, self.prior_shape_tensor) - indexes = tf.broadcast_to( + return tf.broadcast_to( indexes, tf.concat([broadcast_shape, self.prior_shape_tensor], 0)) - return indexes, self.quantization_offset @tf.Module.with_name_scope def __call__(self, bottleneck, training=True): @@ -246,7 +276,7 @@ def __call__(self, bottleneck, training=True): `bits` has the same shape as `bottleneck` without the `self.coding_rank` innermost dimensions. """ - log_prob_fn = functools.partial(self._log_prob_from_prior, self.prior) + log_prob_fn = functools.partial(self._log_prob, self.prior) if training: log_probs, bottleneck_perturbed = math_ops.perturb_and_apply( log_prob_fn, bottleneck, expected_grads=self.expected_grads) @@ -264,8 +294,9 @@ def quantize(self, bottleneck): """Quantizes a floating-point bottleneck tensor. The tensor is rounded to integer values potentially shifted by offsets (if - `self.non_integer_offsets==True`). These offsets depend on `self.prior`. For - instance, for a Gaussian distribution, the returned values would be rounded + `self.quantization_offset is not None`). These offsets can depend on + `self.prior`. For instance, for a Gaussian distribution, when + `self.non_integer_offset == True`, the returned values would be rounded to the location of the mode of the distribution plus or minus an integer. The gradient of this rounding operation is overridden with the identity @@ -278,7 +309,7 @@ def quantize(self, bottleneck): Returns: A `tf.Tensor` containing the quantized values. """ - return self._quantize(bottleneck, self.quantization_offset) + return round_ops.round_st(bottleneck, self.quantization_offset) @tf.Module.with_name_scope def compress(self, bottleneck): @@ -286,7 +317,7 @@ def compress(self, bottleneck): Compresses the tensor to bit strings. `bottleneck` is first quantized as in `quantize()`, and then compressed using the probability tables in - `self.cdf` derived from `self.prior`. The quantized tensor can later be + `self.cdf` (derived from `self.prior`). The quantized tensor can later be recovered by calling `decompress()`. The innermost `self.coding_rank` dimensions are treated as one coding unit, @@ -310,7 +341,8 @@ def compress(self, bottleneck): broadcast_shape = coding_shape[ :self.coding_rank - len(self.prior_shape)] - indexes, offset = self._compute_indexes_and_offset(broadcast_shape) + indexes = self._compute_indexes(broadcast_shape) + offset = self.quantization_offset if offset is not None: bottleneck -= offset symbols = tf.cast(tf.round(bottleneck), tf.int32) @@ -356,7 +388,7 @@ def decompress(self, strings, broadcast_shape): symbols_shape = tf.concat( [batch_shape, broadcast_shape, self.prior_shape_tensor], 0) - indexes, offset = self._compute_indexes_and_offset(broadcast_shape) + indexes = self._compute_indexes(broadcast_shape) strings = tf.reshape(strings, [-1]) # Prevent tensors from bouncing back and forth between host and GPU. @@ -376,6 +408,7 @@ def loop_body(string): symbols = tf.reshape(symbols, symbols_shape) outputs = tf.cast(symbols, self.dtype) + offset = self.quantization_offset return outputs + offset if offset is not None else outputs def get_config(self): @@ -386,6 +419,10 @@ def get_config(self): """ config = super().get_config() config.update( - non_integer_offsets=self.quantization_offset is not None, + prior_shape=tuple(map(int, self.prior_shape)), + # Since the prior is never passed when using the `SavedModel` protocol, + # we can reuse this flag to indicate whether the offsets need to be + # loaded from a variable. + non_integer_offset=self.quantization_offset is not None, ) return config diff --git a/tensorflow_compression/python/entropy_models/continuous_batched_test.py b/tensorflow_compression/python/entropy_models/continuous_batched_test.py index 3824a85..6897572 100644 --- a/tensorflow_compression/python/entropy_models/continuous_batched_test.py +++ b/tensorflow_compression/python/entropy_models/continuous_batched_test.py @@ -14,13 +14,15 @@ # ============================================================================== """Tests of batched continuous entropy model.""" +from absl.testing import parameterized import tensorflow as tf import tensorflow_probability as tfp from tensorflow_compression.python.distributions import uniform_noise from tensorflow_compression.python.entropy_models.continuous_batched import ContinuousBatchedEntropyModel -class ContinuousBatchedEntropyModelTest(tf.test.TestCase): +class ContinuousBatchedEntropyModelTest(tf.test.TestCase, + parameterized.TestCase): def test_can_instantiate(self): noisy = uniform_noise.NoisyNormal(loc=0., scale=1.) @@ -107,45 +109,41 @@ def test_compression_consistent_with_quantization(self): x_decompressed = em.decompress(em.compress(x), [100]) self.assertAllEqual(x_decompressed, x_quantized) - def test_information_bounds(self): - # bits w/ `training=True` should be greater than bits w/ `training=False` - # because it is defined as an upper bound (albeit for infinite data). The - # actual length of the bit string should always be greater than - # bits w/ `training=False` because range coding is only asymptotically - # optimal, and because it operates on quantized probabilities. - for scale in 2 ** tf.linspace(-2., 7., 10): - noisy = uniform_noise.NoisyNormal(loc=0., scale=scale) - em = ContinuousBatchedEntropyModel(noisy, 1, compression=True) - x = noisy.base.sample([10000]) - _, bits_eval = em(x, training=False) - _, bits_training = em(x, training=True) - bits_compressed = 8 * len(em.compress(x).numpy()) - self.assertGreater(bits_training, .9975 * bits_eval) - self.assertGreater(bits_compressed, bits_eval) - - def test_low_entropy_bounds(self): - # For low entropy distributions, the training bound should be very loose, - # and the overhead of range coding manageable. - noisy = uniform_noise.NoisyNormal(loc=0., scale=.25) - em = ContinuousBatchedEntropyModel(noisy, 1, compression=True) - x = noisy.base.sample([10000]) + @parameterized.parameters(*[2. ** i for i in range(-2, 8)]) + def test_information_bounds(self, scale): + # Off-center prior to test quantization offset heuristic. Without it, it + # should be harder to achieve the bounds below. + prior = uniform_noise.NoisyNormal(loc=.5, scale=scale) + em = ContinuousBatchedEntropyModel(prior, coding_rank=1, compression=True) + x = prior.base.sample([1000000]) _, bits_eval = em(x, training=False) _, bits_training = em(x, training=True) bits_compressed = 8 * len(em.compress(x).numpy()) - self.assertAllClose(bits_training, bits_eval, atol=0, rtol=1.25) - self.assertAllClose(bits_compressed, bits_eval, atol=0, rtol=5e-3) - - def test_high_entropy_bounds(self): - # For high entropy distributions, the training bound should be very tight, - # and the overhead of range coding manageable. - noisy = uniform_noise.NoisyNormal(loc=0., scale=100.) - em = ContinuousBatchedEntropyModel(noisy, 1, compression=True) - x = noisy.base.sample([10000]) - _, bits_eval = em(x, training=False) - _, bits_training = em(x, training=True) - bits_compressed = 8 * len(em.compress(x).numpy()) - self.assertAllClose(bits_training, bits_eval, atol=0, rtol=5e-5) - self.assertAllClose(bits_compressed, bits_eval, atol=0, rtol=5e-3) + # Asymptotically, the entropy estimate with `training=True` is an upper + # bound on the entropy estimate with `training=False`. (With limited data, + # fluctuations are possible.) + with self.subTest("training bits > eval bits"): + # Sample size is too small for the bound to be asymptotic. Increasing it + # would make tests run too long. + self.assertGreater(bits_training, 0.999999 * bits_eval) + # Asymptotically, the length of the bit string should be greater than the + # entropy estimate with `training=False` because range coding is only + # asymptotically optimal, and because it operates on quantized + # probabilities. + with self.subTest("compressed bits > eval bits"): + self.assertGreater(bits_compressed, bits_eval) + # For low entropy distributions, the training bound can be very loose. + if scale <= .5: + with self.subTest("training bound loose"): + self.assertAllClose(bits_training, bits_eval, atol=0, rtol=1.25) + self.assertNotAllClose(bits_training, bits_eval, atol=0, rtol=1e-2) + # For high entropy distributions, the training bound should be tight. + if scale >= 64: + with self.subTest("training bound tight"): + self.assertAllClose(bits_training, bits_eval, atol=0, rtol=1e-5) + # The overhead of range coding should always be manageable. + with self.subTest("range coding overhead"): + self.assertAllClose(bits_compressed, bits_eval, atol=0, rtol=5e-3) def test_compression_works_after_serialization(self): noisy = uniform_noise.NoisyNormal(loc=.5, scale=8.) @@ -177,7 +175,7 @@ def test_compression_works_after_serialization_no_offset(self): def test_compression_works_in_tf_function(self): noisy = uniform_noise.NoisyNormal(loc=0, scale=5.) - sample = noisy.base.sample([100]) + samples = noisy.base.sample([100]) # Since tf.function traces each function twice, and only allows variable # creation in the first call, we need to have a stateful object in which we @@ -190,11 +188,11 @@ def compress(self, values): if not hasattr(self, "em"): self.em = ContinuousBatchedEntropyModel(noisy, 1, compression=True) compressed = self.em.compress(values) - decompressed = self.em.decompress(compressed, []) - return decompressed + return self.em.decompress(compressed, [100]) - values_eager = Compressor().compress(sample) - values_function = tf.function(Compressor().compress)(sample) + values_eager = Compressor().compress(samples) + values_function = tf.function(Compressor().compress)(samples) + self.assertAllClose(samples, values_eager, rtol=0., atol=.5) self.assertAllEqual(values_eager, values_function) def test_small_cdfs_for_dirac_prior_without_quantization_offset(self): @@ -220,7 +218,7 @@ def test_small_bitcost_for_dirac_prior(self): self.assertAllLessEqual(bits_estimate, 16) self.assertAllLessEqual(bitstring_bits, 16) # Quantization noise should be between -.5 and .5 - self.assertAllLessEqual(tf.abs(x - x_decoded), 0.5) + self.assertAllClose(x, x_decoded, rtol=0., atol=.5) if __name__ == "__main__": diff --git a/tensorflow_compression/python/entropy_models/continuous_indexed.py b/tensorflow_compression/python/entropy_models/continuous_indexed.py index 4578c72..f377034 100644 --- a/tensorflow_compression/python/entropy_models/continuous_indexed.py +++ b/tensorflow_compression/python/entropy_models/continuous_indexed.py @@ -15,10 +15,10 @@ """Indexed entropy model for continuous random variables.""" import tensorflow as tf -from tensorflow_compression.python.distributions import helpers from tensorflow_compression.python.entropy_models import continuous_base from tensorflow_compression.python.ops import gen_ops from tensorflow_compression.python.ops import math_ops +from tensorflow_compression.python.ops import round_ops __all__ = [ @@ -181,38 +181,53 @@ def __init__(self, tail_mass: Float. Approximate probability mass which is range encoded with less precision, by using a Golomb-like code. range_coder_precision: Integer. Precision passed to the range coding op. - dtype: `tf.dtypes.DType`. The data type of all floating-point - computations carried out in this class. + dtype: `tf.dtypes.DType`. Data type of this entropy model (i.e. dtype of + prior, decompressed values). laplace_tail_mass: Float. If positive, will augment the prior with a laplace mixture for training stability. (experimental) """ if coding_rank <= 0: - raise ValueError("coding_rank must be larger than 0.") + raise ValueError("`coding_rank` must be larger than 0.") if not callable(prior_fn): - raise TypeError("prior_fn must be a class or factory function.") + raise TypeError("`prior_fn` must be a class or factory function.") for name, fn in parameter_fns.items(): if not isinstance(name, str): - raise TypeError("parameter_fns must have string keys.") + raise TypeError("`parameter_fns` must have string keys.") if not callable(fn): - raise TypeError(f"parameter_fns['{name}'] must be callable.") - self._index_ranges = tuple(int(r) for r in index_ranges) - if not self.index_ranges: - raise ValueError("index_ranges must have at least one element.") - self._channel_axis = None if channel_axis is None else int(channel_axis) - if self.channel_axis is None and len(self.index_ranges) > 1: - raise ValueError("channel_axis can't be None for len(index_ranges) > 1.") - self._prior_fn = prior_fn - self._parameter_fns = dict(parameter_fns) + raise TypeError(f"`parameter_fns['{name}']` must be callable.") + super().__init__( - prior=self._make_range_coding_prior(self.index_ranges, dtype), coding_rank=coding_rank, compression=compression, stateless=stateless, expected_grads=expected_grads, tail_mass=tail_mass, range_coder_precision=range_coder_precision, + dtype=dtype, laplace_tail_mass=laplace_tail_mass, ) + self._index_ranges = tuple(int(r) for r in index_ranges) + if not self.index_ranges: + raise ValueError("`index_ranges` must have at least one element.") + self._channel_axis = None if channel_axis is None else int(channel_axis) + if self.channel_axis is None and len(self.index_ranges) > 1: + raise ValueError( + "`channel_axis` can't be `None` for `len(index_ranges) > 1`.") + self._prior_fn = prior_fn + self._parameter_fns = dict(parameter_fns) + + with self.name_scope: + if self.compression: + if self.channel_axis is None: + index_range, = index_ranges + indexes = tf.range(index_range, dtype=self.dtype) + else: + indexes = [tf.range(r, dtype=self.dtype) for r in index_ranges] + indexes = tf.meshgrid(*indexes, indexing="ij") + indexes = tf.stack(indexes, axis=self.channel_axis) + self._prior = self._make_prior(indexes) + cdf, cdf_offset, cdf_length = self._build_tables(self.prior) + self._init_compression(cdf, cdf_offset, cdf_length, None) @property def index_ranges(self): @@ -234,22 +249,14 @@ def channel_axis(self): """Position of channel axis in `indexes` tensor.""" return self._channel_axis - def _make_prior(self, indexes, dtype=None): - indexes = tf.cast(indexes, dtype or self.dtype) + def _make_prior(self, indexes): + indexes = tf.cast(indexes, self.dtype) parameters = {k: f(indexes) for k, f in self.parameter_fns.items()} - return self.prior_fn(**parameters) - - def _make_range_coding_prior(self, index_ranges, dtype): - """Instantiates the range coding prior.""" - dtype = tf.as_dtype(dtype) - if self.channel_axis is None: - index_range, = index_ranges - indexes = tf.range(index_range, dtype=dtype) - else: - indexes = [tf.range(r, dtype=dtype) for r in index_ranges] - indexes = tf.meshgrid(*indexes, indexing="ij") - indexes = tf.stack(indexes, axis=self.channel_axis) - return self._make_prior(indexes, dtype=dtype) + prior = self.prior_fn(**parameters) + assert prior.dtype == self.dtype + if prior.event_shape.rank: + raise ValueError("`prior` must be a (batch of) scalar distribution(s).") + return prior def _normalize_indexes(self, indexes): indexes = math_ops.lower_bound(indexes, 0) @@ -260,8 +267,7 @@ def _normalize_indexes(self, indexes): axes = [1] * indexes.shape.rank axes[self.channel_axis] = len(self.index_ranges) bounds = tf.reshape([s - 1 for s in self.index_ranges], axes) - indexes = math_ops.upper_bound(indexes, tf.cast(bounds, indexes.dtype)) - return indexes + return math_ops.upper_bound(indexes, tf.cast(bounds, indexes.dtype)) def _flatten_indexes(self, indexes): indexes = tf.cast(indexes, tf.int32) @@ -271,11 +277,6 @@ def _flatten_indexes(self, indexes): strides = tf.math.cumprod(self.index_ranges, exclusive=True, reverse=True) return tf.linalg.tensordot(indexes, strides, [[self.channel_axis], [0]]) - def _offset_from_indexes(self, indexes): - """Compute the quantization offset from the respective prior.""" - prior = self._make_prior(indexes) - return helpers.quantization_offset(prior) - @tf.Module.with_name_scope def __call__(self, bottleneck, indexes, training=True): """Perturbs a tensor with (quantization) noise and estimates rate. @@ -297,10 +298,7 @@ def __call__(self, bottleneck, indexes, training=True): innermost dimensions. """ indexes = self._normalize_indexes(indexes) - prior = self._make_prior(indexes) if training: - bottleneck_perturbed = bottleneck + tf.random.uniform( - tf.shape(bottleneck), minval=-.5, maxval=.5, dtype=bottleneck.dtype) def log_prob_fn(bottleneck_perturbed, indexes): # When using expected_grads=True, we will use a tf.custom_gradient on # this function. In this case, all non-Variable tensors that determine @@ -310,42 +308,35 @@ def log_prob_fn(bottleneck_perturbed, indexes): # reference here via a closure, we would get a `None` gradient for # `indexes`. prior = self._make_prior(indexes) - return self._log_prob_from_prior(prior, bottleneck_perturbed) + return self._log_prob(prior, bottleneck_perturbed) log_probs, bottleneck_perturbed = math_ops.perturb_and_apply( - log_prob_fn, bottleneck, indexes, expected_grads=self._expected_grads) + log_prob_fn, bottleneck, indexes, expected_grads=self.expected_grads) else: - offset = helpers.quantization_offset(prior) - bottleneck_perturbed = self._quantize(bottleneck, offset) - log_probs = self._log_prob_from_prior(prior, bottleneck_perturbed) + prior = self._make_prior(indexes) + bottleneck_perturbed = self.quantize(bottleneck) + log_probs = self._log_prob(prior, bottleneck_perturbed) axes = tuple(range(-self.coding_rank, 0)) bits = tf.reduce_sum(log_probs, axis=axes) / ( -tf.math.log(tf.constant(2, dtype=log_probs.dtype))) return bottleneck_perturbed, bits @tf.Module.with_name_scope - def quantize(self, bottleneck, indexes): + def quantize(self, bottleneck): """Quantizes a floating-point tensor. To use this entropy model as an information bottleneck during training, pass - a tensor through this function. The tensor is rounded to integer values - modulo a quantization offset, which depends on `indexes`. For instance, for - Gaussian distributions, the returned values are rounded to the location of - the mode of the distributions plus or minus an integer. + a tensor through this function. The tensor is rounded to integer values. The gradient of this rounding operation is overridden with the identity (straight-through gradient estimator). Args: bottleneck: `tf.Tensor` containing the data to be quantized. - indexes: `tf.Tensor` specifying the scalar distribution for each element - in `bottleneck`. See class docstring for examples. Returns: A `tf.Tensor` containing the quantized values. """ - indexes = self._normalize_indexes(indexes) - offset = self._offset_from_indexes(indexes) - return self._quantize(bottleneck, offset) + return round_ops.round_st(bottleneck) @tf.Module.with_name_scope def compress(self, bottleneck, indexes): @@ -379,8 +370,7 @@ def compress(self, bottleneck, indexes): flat_indexes = tf.reshape(flat_indexes, flat_shape) - offset = self._offset_from_indexes(indexes) - symbols = tf.cast(tf.round(bottleneck - offset), tf.int32) + symbols = tf.cast(tf.round(bottleneck), tf.int32) symbols = tf.reshape(symbols, flat_shape) # Prevent tensors from bouncing back and forth between host and GPU. @@ -442,8 +432,7 @@ def loop_body(args): loop_body, (strings, flat_indexes), dtype=tf.int32, name="decompress") symbols = tf.reshape(symbols, symbols_shape) - offset = self._offset_from_indexes(indexes) - return tf.cast(symbols, self.dtype) + offset + return tf.cast(symbols, self.dtype) def get_config(self): """Returns the configuration of the entropy model.""" @@ -463,9 +452,17 @@ class LocationScaleIndexedEntropyModel(ContinuousIndexedEntropyModel): This class is a common special case of `ContinuousIndexedEntropyModel`. The specified distribution is parameterized with `num_scales` values of scale parameters. An element-wise location parameter is handled by shifting the - distributions to zero. Note: this only works for shift-invariant - distributions, where the `loc` parameter really denotes a translation (i.e., - not for the log-normal distribution). + distributions to zero. + + This method is illustrated in Figure 10 of: + > "Nonlinear Transform Coding"
+ > J. Ballé, P.A. Chou, D. Minnen, S. Singh, N. Johnston, E. Agustsson, + > S.J. Hwang, G. Toderici
+ > https://doi.org/10.1109/JSTSP.2020.3034501 + + Note: this only works for shift-invariant `tfpd.Distribution` objects, where + the `loc` parameter really denotes a translation (i.e., not for the log-normal + distribution). """ def __init__(self, @@ -564,30 +561,25 @@ def __call__(self, bottleneck, scale_indexes, loc=None, training=True): innermost dimensions. """ if loc is None: - loc = 0.0 - bottleneck_centered = bottleneck - loc - bottleneck_centered_perturbed, bits = super().__call__( - bottleneck_centered, scale_indexes, training=training) - bottleneck_perturbed = bottleneck_centered_perturbed + loc - return bottleneck_perturbed, bits + return super().__call__(bottleneck, scale_indexes, training=training) + else: + bottleneck, bits = super().__call__( + bottleneck - loc, scale_indexes, training=training) + return bottleneck + loc, bits @tf.Module.with_name_scope - def quantize(self, bottleneck, scale_indexes, loc=None): + def quantize(self, bottleneck, loc=None): """Quantizes a floating-point tensor. To use this entropy model as an information bottleneck during training, pass a tensor through this function. The tensor is rounded to integer values - modulo a quantization offset, which depends on `indexes`. For instance, for - Gaussian distributions, the returned values are rounded to the location of - the mode of the distributions plus or minus an integer. + modulo the location parameters of the prior distribution given in `loc`. The gradient of this rounding operation is overridden with the identity (straight-through gradient estimator). Args: bottleneck: `tf.Tensor` containing the data to be quantized. - scale_indexes: `tf.Tensor` indexing the scale parameter for each element - in `bottleneck`. Must have the same shape as `bottleneck`. loc: `None` or `tf.Tensor`. If `None`, the location parameter for all elements is assumed to be zero. Otherwise, specifies the location parameter for each element in `bottleneck`. Must have the same shape as @@ -596,10 +588,7 @@ def quantize(self, bottleneck, scale_indexes, loc=None): Returns: A `tf.Tensor` containing the quantized values. """ - if loc is None: - return super().quantize(bottleneck, scale_indexes) - else: - return super().quantize(bottleneck - loc, scale_indexes) + loc + return round_ops.round_st(bottleneck, loc) @tf.Module.with_name_scope def compress(self, bottleneck, scale_indexes, loc=None): diff --git a/tensorflow_compression/python/entropy_models/continuous_indexed_test.py b/tensorflow_compression/python/entropy_models/continuous_indexed_test.py index e36cdff..07b5c73 100644 --- a/tensorflow_compression/python/entropy_models/continuous_indexed_test.py +++ b/tensorflow_compression/python/entropy_models/continuous_indexed_test.py @@ -14,86 +14,221 @@ # ============================================================================== """Tests of indexed continuous entropy model.""" +from absl.testing import parameterized import tensorflow as tf +import tensorflow_probability as tfp from tensorflow_compression.python.distributions import uniform_noise from tensorflow_compression.python.entropy_models import continuous_indexed -# TODO(jonycgn): add further unit tests. +class ContinuousIndexedEntropyModelTest(tf.test.TestCase, + parameterized.TestCase): + def get_model(self, prior_fn=uniform_noise.NoisyLogisticMixture, + coding_rank=1, scale=1., **kwargs): + return continuous_indexed.ContinuousIndexedEntropyModel( + prior_fn, + (2, 3, 5), + dict( + loc=lambda i: i[..., :2] - [0., 1.5], + scale=lambda _: scale, + weight=lambda i: tf.nn.softmax((i[..., 2:] - 2.) * [-1., 1.]), + ), + coding_rank, + **kwargs) -class ContinuousIndexedEntropyModelTest(tf.test.TestCase): + def get_samples(self, shape, scale=2., dtype=tf.float32): + # This produces samples from a smoothed Laplacian with the requested scale. + # They're not really samples from the prior, but approximately cover it, + # and have the same tail behavior. + x = tf.random.stateless_uniform( + shape, minval=0., maxval=1., seed=(0, 1), dtype=dtype) + s = tf.random.stateless_uniform( + shape, minval=-1., maxval=1., seed=(1, 2), dtype=dtype) + u = tf.random.stateless_uniform( + shape, minval=-.5, maxval=.5, seed=(3, 4), dtype=dtype) + x = (tf.math.log(x) * tf.math.sign(s) + u) * scale + indexes = tf.random.stateless_uniform( + tuple(shape) + (3,), minval=-.4, maxval=(2.4, 3.4, 5), seed=(5, 6), + dtype=tf.float32) + return x, indexes - def test_can_instantiate_one_dimensional(self): - em = continuous_indexed.ContinuousIndexedEntropyModel( - uniform_noise.NoisyNormal, (64,), - dict(loc=lambda _: 0, scale=lambda i: tf.exp(i / 8 - 5)), 1, - compression=True, channel_axis=None) - self.assertIsInstance(em.prior, uniform_noise.NoisyNormal) + def test_can_instantiate_and_compress(self): + em = self.get_model(compression=True) + self.assertIsInstance(em.prior, uniform_noise.NoisyLogisticMixture) self.assertEqual(em.coding_rank, 1) + self.assertEqual(em.channel_axis, -1) self.assertEqual(em.tail_mass, 2**-8) self.assertEqual(em.range_coder_precision, 12) self.assertEqual(em.dtype, tf.float32) - x = tf.random.stateless_normal((3, 8, 16), seed=(0, 0)) - indexes = tf.cast(64 * tf.random.stateless_uniform((3, 8, 16), seed=(0, 0)), - tf.int32) - em(x, indexes) + x, indexes = self.get_samples((2, 5)) + x_tilde, bits = em(x, indexes) bitstring = em.compress(x, indexes) x_hat = em.decompress(bitstring, indexes) - self.assertAllLess(x - x_hat, 0.5) - self.assertAllGreater(x - x_hat, -0.5) + self.assertAllClose(x, x_hat, rtol=0, atol=.5) + self.assertAllClose(x, x_tilde, rtol=0, atol=.5) + self.assertEqual(bits.shape, (2,)) + self.assertAllGreaterEqual(bits, 0.) - def test_can_instantiate_and_compress_n_dimensional(self): - em = continuous_indexed.ContinuousIndexedEntropyModel( - uniform_noise.NoisyLogisticMixture, - (10, 10, 5), - dict( - loc=lambda i: i[..., 0:2] - 5, - scale=lambda _: 1, - weight=lambda i: tf.nn.softmax((i[..., 2:3] - 2) * [-1, 1]), - ), - 1, - compression=True - ) + def test_can_instantiate_and_compress_statelessly(self): + em = self.get_model(compression=True, stateless=True, dtype=tf.float64) + self.assertEqual(em.compression, True) + self.assertEqual(em.stateless, True) self.assertIsInstance(em.prior, uniform_noise.NoisyLogisticMixture) self.assertEqual(em.coding_rank, 1) - self.assertEqual(em.channel_axis, -1) self.assertEqual(em.tail_mass, 2**-8) self.assertEqual(em.range_coder_precision, 12) - self.assertEqual(em.dtype, tf.float32) - x = tf.random.stateless_normal((3, 8, 16), seed=(0, 0)) - indexes = tf.cast( - 10 * tf.random.stateless_uniform((3, 8, 16, 3), seed=(0, 0)), tf.int32) - em(x, indexes) + self.assertEqual(em.dtype, tf.float64) + x, indexes = self.get_samples((7,), dtype=tf.float64) + x_tilde, bits = em(x, indexes) bitstring = em.compress(x, indexes) x_hat = em.decompress(bitstring, indexes) - self.assertAllLess(x - x_hat, 0.5) - self.assertAllGreater(x - x_hat, -0.5) + self.assertAllClose(x, x_hat, rtol=0, atol=.5) + self.assertAllClose(x, x_tilde, rtol=0, atol=.5) + self.assertEqual(bits.shape, ()) + self.assertAllGreaterEqual(bits, 0.) + + def test_indexes_are_clipped_correctly(self): + em = self.get_model(compression=True, coding_rank=2) + x, indexes = self.get_samples((7, 23)) + x_float_idx = em.decompress(em.compress(x, indexes), indexes) + indexes = tf.cast(tf.round(indexes), tf.int32) + x_int_idx = em.decompress(em.compress(x, indexes), indexes) + self.assertAllEqual(x_float_idx, x_int_idx) + self.assertAllClose(x, x_float_idx, rtol=0, atol=.5) + + def test_requires_scalar_distributions(self): + def prior_fn(**_): + return uniform_noise.UniformNoiseAdapter( + tfp.distributions.MultivariateNormalDiag( + loc=[-3, .2], scale_diag=[1, 2])) + with self.assertRaises(ValueError): + self.get_model(prior_fn=prior_fn, compression=True) + + def test_quantizes_to_integers(self): + em = self.get_model() + x = tf.range(-20., 20.) + x_perturbed = x + tf.random.uniform(x.shape, -.49, .49) + x_quantized = em.quantize(x_perturbed) + self.assertAllEqual(x, x_quantized) + + def test_gradients_are_straight_through(self): + em = self.get_model() + x = tf.range(-20., 20.) + x_perturbed = x + tf.random.uniform(x.shape, -.49, .49) + with tf.GradientTape() as tape: + tape.watch(x_perturbed) + x_quantized = em.quantize(x_perturbed) + gradients = tape.gradient(x_quantized, x_perturbed) + self.assertAllEqual(gradients, tf.ones_like(gradients)) + + def test_default_kwargs_throw_error_on_compression(self): + em = self.get_model() + x, indexes = self.get_samples((5,)) + with self.assertRaises(RuntimeError): + em.compress(x, indexes) + s = tf.zeros((), dtype=tf.string) + with self.assertRaises(RuntimeError): + em.decompress(s, indexes) + + def test_compression_consistent_with_quantization(self): + em = self.get_model(compression=True) + x, indexes = self.get_samples((100,)) + x_quantized = em.quantize(x) + x_decompressed = em.decompress(em.compress(x, indexes), indexes) + self.assertAllEqual(x_decompressed, x_quantized) + + @parameterized.parameters(*[2. ** i for i in range(-2, 8)]) + def test_information_bounds(self, scale): + em = self.get_model(scale=scale, compression=True) + x, indexes = self.get_samples([200000], scale=2. * scale) + _, bits_eval = em(x, indexes, training=False) + _, bits_training = em(x, indexes, training=True) + bits_compressed = 8 * len(em.compress(x, indexes).numpy()) + # Asymptotically, the entropy estimate with `training=True` is an upper + # bound on the entropy estimate with `training=False`. (With limited data, + # fluctuations are possible.) + with self.subTest("training bits > eval bits"): + # Sample size is too small for the bound to be asymptotic. Increasing it + # would make tests run too long. + self.assertGreater(bits_training, 0.99999 * bits_eval) + # Asymptotically, the length of the bit string should be greater than the + # entropy estimate with `training=False` because range coding is only + # asymptotically optimal, and because it operates on quantized + # probabilities. + with self.subTest("compressed bits > eval bits"): + self.assertGreater(bits_compressed, bits_eval) + # For low entropy distributions, the training bound can be very loose. + if scale <= .5: + with self.subTest("training bound loose"): + self.assertAllClose(bits_training, bits_eval, atol=0, rtol=1e-2) + self.assertNotAllClose(bits_training, bits_eval, atol=0, rtol=1e-4) + # For high entropy distributions, the training bound should be tight. + if scale >= 64: + with self.subTest("training bound tight"): + self.assertAllClose(bits_training, bits_eval, atol=0, rtol=1e-5) + # The overhead of range coding should always be manageable. + with self.subTest("range coding overhead"): + self.assertAllClose(bits_compressed, bits_eval, atol=0, rtol=4e-2) + + def test_compression_works_in_tf_function(self): + samples, indexes = self.get_samples((100,)) + + # Since tf.function traces each function twice, and only allows variable + # creation in the first call, we need to have a stateful object in which we + # create the entropy model only the first time the function is called, and + # store it for the second time. + + # We need this since `self` below shadows the test object. + get_model = self.get_model + + class Compressor: + + def compress(self, values, indexes): + if not hasattr(self, "em"): + self.em = get_model(compression=True) + compressed = self.em.compress(values, indexes) + return self.em.decompress(compressed, indexes) + + values_eager = Compressor().compress(samples, indexes) + values_function = tf.function(Compressor().compress)(samples, indexes) + self.assertAllClose(samples, values_eager, rtol=0., atol=.5) + self.assertAllEqual(values_eager, values_function) class LocationScaleIndexedEntropyModelTest(tf.test.TestCase): - def test_can_instantiate_and_compress(self): - em = continuous_indexed.LocationScaleIndexedEntropyModel( - uniform_noise.NoisyNormal, + def get_model(self, prior_fn=uniform_noise.NoisyNormal, + coding_rank=1, **kwargs): + return continuous_indexed.LocationScaleIndexedEntropyModel( + prior_fn, 64, - lambda i: tf.exp(i / 8 - 5), - 1, - compression=True) + lambda i: tf.exp(i / 8. - 5.), + coding_rank, + **kwargs) + + def get_samples(self, shape): + x = tf.random.stateless_normal(shape, stddev=5., seed=(0, 1)) + indexes = tf.random.stateless_uniform( + shape, minval=-.4, maxval=64.4, seed=(0, 0), dtype=tf.float32) + loc = tf.random.stateless_normal(shape, stddev=5., seed=(2, 3)) + return x, indexes, loc + + def test_can_instantiate_and_compress(self): + em = self.get_model(compression=True) self.assertIsInstance(em.prior, uniform_noise.NoisyNormal) self.assertEqual(em.coding_rank, 1) self.assertEqual(em.tail_mass, 2**-8) self.assertEqual(em.range_coder_precision, 12) self.assertEqual(em.dtype, tf.float32) - x = tf.random.stateless_normal((3, 8, 16), seed=(0, 0)) - indexes = tf.cast(10 * tf.random.stateless_uniform((3, 8, 16), seed=(0, 0)), - tf.int32) - loc = tf.random.stateless_uniform((3, 8, 16), seed=(0, 0)) - em(x, indexes, loc=loc) + x, indexes, loc = self.get_samples((7, 4)) + x_tilde, bits = em(x, indexes, loc=loc) bitstring = em.compress(x, indexes, loc=loc) x_hat = em.decompress(bitstring, indexes, loc=loc) - self.assertAllLessEqual(x - x_hat, 0.5) - self.assertAllGreaterEqual(x - x_hat, -0.5) + self.assertAllClose(x, x_hat, rtol=0, atol=.5) + self.assertAllClose(x, x_tilde, rtol=0, atol=.5) + self.assertEqual(bits.shape, (7,)) + self.assertAllGreaterEqual(bits, 0.) if __name__ == "__main__": diff --git a/tensorflow_compression/python/entropy_models/universal.py b/tensorflow_compression/python/entropy_models/universal.py index b5ab696..a87d974 100644 --- a/tensorflow_compression/python/entropy_models/universal.py +++ b/tensorflow_compression/python/entropy_models/universal.py @@ -14,8 +14,8 @@ import functools import tensorflow as tf -from tensorflow_compression.python.entropy_models import continuous_batched -from tensorflow_compression.python.entropy_models import continuous_indexed +from tensorflow_compression.python.entropy_models import continuous_base +from tensorflow_compression.python.ops import gen_ops from tensorflow_compression.python.ops import math_ops @@ -61,8 +61,7 @@ def _range_coding_offsets(num_noise_levels, prior_shape, dtype=tf.float32): return offset -class UniversalBatchedEntropyModel( - continuous_batched.ContinuousBatchedEntropyModel): +class UniversalBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase): """Batched entropy model model which implements Universal Quantization. In contrast to the base class, which uses rounding for quantization, here @@ -118,49 +117,52 @@ def __init__(self, allows it to be constructed within a `tf.function` body. If `compression=False`, then `stateless=True` is implied and the provided value is ignored. - - Raises: - RuntimeError: when attempting to instantiate an entropy model with - `compression=True` and not in eager execution mode. """ - # This attribute is used in methods we override in this class which - # are used during used during super().__init__(...), so we set it first. - self._num_noise_levels = num_noise_levels + if prior.event_shape.rank: + raise ValueError("`prior` must be a (batch of) scalar distribution(s).") super().__init__( - prior=prior, coding_rank=coding_rank, compression=compression, - laplace_tail_mass=laplace_tail_mass, + stateless=stateless, expected_grads=expected_grads, tail_mass=tail_mass, range_coder_precision=range_coder_precision, - stateless=stateless) + dtype=prior.dtype, + laplace_tail_mass=laplace_tail_mass, + ) + self._prior = prior + self._num_noise_levels = num_noise_levels + if self.coding_rank < self.prior_shape.rank: + raise ValueError("`coding_rank` can't be smaller than `prior_shape`.") + + with self.name_scope: + if self.compression: + offset = _range_coding_offsets( + self._num_noise_levels, self.prior_shape, self.dtype) + cdf, cdf_offset, cdf_length = self._build_tables( + self.prior, + offset=offset, + context_shape=(self._num_noise_levels,) + self.prior_shape) + self._init_compression(cdf, cdf_offset, cdf_length, None) @property - def context_shape(self): - """See base class.""" - return (self._num_noise_levels,) + self.prior_shape - - def _cache_quantization_offset(self): - """See base class.""" - # Universal Quantization derives offsets from a pseudorandom source. - self._quantization_offset = None + def prior_shape(self): + """Batch shape of `prior` (dimensions which are not assumed i.i.d.).""" + return tf.TensorShape(self.prior.batch_shape) - def _offset_from_prior(self, prior): - """See base class.""" - return _range_coding_offsets(self._num_noise_levels, self.prior_shape, - self.dtype) + @property + def prior_shape_tensor(self): + """Batch shape of `prior` as a `Tensor`.""" + return tf.constant(self.prior_shape.as_list(), dtype=tf.int32) def _compute_indexes_and_offset(self, broadcast_shape): - """See base class.""" + """Returns the indexes for range coding and the quantization offset.""" prior_size = int(self.prior_shape.num_elements()) # Create index for each dimension in prior_shape. indexes = tf.range(prior_size, dtype=tf.int32) indexes = tf.broadcast_to( indexes, tf.concat((broadcast_shape, tf.shape(indexes)), axis=0)) - # Add channel dimension. - channel_axis = -1 indexes = indexes[..., None] # Add in offset indexes. @@ -172,17 +174,13 @@ def _compute_indexes_and_offset(self, broadcast_shape): # Flatten prior + offset indexes. index_ranges = [self._num_noise_levels, prior_size] strides = tf.math.cumprod(index_ranges, exclusive=True, reverse=True) - indexes = tf.linalg.tensordot(indexes, strides, [[channel_axis], [0]]) + indexes = tf.linalg.tensordot(indexes, strides, [[-1], [0]]) # Now bring to full shape. full_shape = tf.concat([broadcast_shape, self.prior_shape_tensor], 0) indexes = tf.reshape(indexes, full_shape) offset = tf.reshape(offset, full_shape) return indexes, offset - @tf.Module.with_name_scope - def quantize(self, bottleneck, indexes=None): - raise NotImplementedError() - @tf.Module.with_name_scope def __call__(self, bottleneck, training=True): """Perturbs a tensor with additive uniform noise and estimates bitcost. @@ -202,8 +200,7 @@ def __call__(self, bottleneck, training=True): and `bits` is the bitcost of transmitting such a sample having the same shape as `bottleneck` without the `self.coding_rank` innermost dimensions. """ - - log_prob_fn = functools.partial(self._log_prob_from_prior, self.prior) + log_prob_fn = functools.partial(self._log_prob, self.prior) if training: log_probs, bottleneck_perturbed = math_ops.perturb_and_apply( log_prob_fn, bottleneck, expected_grads=self._expected_grads) @@ -224,13 +221,109 @@ def __call__(self, bottleneck, training=True): -tf.math.log(tf.constant(2., dtype=log_probs.dtype))) return bottleneck_perturbed, bits + @tf.Module.with_name_scope + def compress(self, bottleneck): + """Compresses a floating-point tensor. + + Compresses the tensor to bit strings. `bottleneck` is first quantized + as in `quantize()`, and then compressed using the probability tables in + `self.cdf` derived from `self.prior`. The quantized tensor can later be + recovered by calling `decompress()`. + + The innermost `self.coding_rank` dimensions are treated as one coding unit, + i.e. are compressed into one string each. Any additional dimensions to the + left are treated as batch dimensions. + + Args: + bottleneck: `tf.Tensor` containing the data to be compressed. Must have at + least `self.coding_rank` dimensions, and the innermost dimensions must + be broadcastable to `self.prior_shape`. + + Returns: + A `tf.Tensor` having the same shape as `bottleneck` without the + `self.coding_rank` innermost dimensions, containing a string for each + coding unit. + """ + input_shape = tf.shape(bottleneck) + input_rank = tf.shape(input_shape)[0] + batch_shape, coding_shape = tf.split( + input_shape, [input_rank - self.coding_rank, self.coding_rank]) + broadcast_shape = coding_shape[ + :self.coding_rank - len(self.prior_shape)] + + indexes, offset = self._compute_indexes_and_offset(broadcast_shape) + bottleneck -= offset + symbols = tf.cast(tf.round(bottleneck), tf.int32) + symbols = tf.reshape(symbols, tf.concat([[-1], coding_shape], 0)) + + # Prevent tensors from bouncing back and forth between host and GPU. + with tf.device("/cpu:0"): + cdf = self.cdf + cdf_length = self.cdf_length + cdf_offset = self.cdf_offset + def loop_body(symbols): + return gen_ops.unbounded_index_range_encode( + symbols, indexes, cdf, cdf_length, cdf_offset, + precision=self.range_coder_precision, + overflow_width=4, debug_level=1) + + # TODO(jonycgn,ssjhv): Consider switching to Python control flow. + strings = tf.map_fn( + loop_body, symbols, dtype=tf.string, name="compress") + + return tf.reshape(strings, batch_shape) + + @tf.Module.with_name_scope + def decompress(self, strings, broadcast_shape): + """Decompresses a tensor. + + Reconstructs the quantized tensor from bit strings produced by `compress()`. + It is necessary to provide a part of the output shape in `broadcast_shape`. + + Args: + strings: `tf.Tensor` containing the compressed bit strings. + broadcast_shape: Iterable of ints. The part of the output tensor shape + between the shape of `strings` on the left and `self.prior_shape` on the + right. This must match the shape of the input to `compress()`. + + Returns: + A `tf.Tensor` of shape `strings.shape + broadcast_shape + + self.prior_shape`. + """ + strings = tf.convert_to_tensor(strings, dtype=tf.string) + broadcast_shape = tf.convert_to_tensor(broadcast_shape, dtype=tf.int32) + batch_shape = tf.shape(strings) + symbols_shape = tf.concat( + [batch_shape, broadcast_shape, self.prior_shape_tensor], 0) + + indexes, offset = self._compute_indexes_and_offset(broadcast_shape) + strings = tf.reshape(strings, [-1]) + + # Prevent tensors from bouncing back and forth between host and GPU. + with tf.device("/cpu:0"): + cdf = self.cdf + cdf_length = self.cdf_length + cdf_offset = self.cdf_offset + def loop_body(string): + return gen_ops.unbounded_index_range_decode( + string, indexes, cdf, cdf_length, cdf_offset, + precision=self.range_coder_precision, + overflow_width=4, debug_level=1) + + # TODO(jonycgn,ssjhv): Consider switching to Python control flow. + symbols = tf.map_fn( + loop_body, strings, dtype=tf.int32, name="decompress") + + symbols = tf.reshape(symbols, symbols_shape) + outputs = tf.cast(symbols, self.dtype) + return outputs + offset + def get_config(self): # TODO(relational): Implement this when we need serialization. raise NotImplementedError() -class UniversalIndexedEntropyModel( - continuous_indexed.ContinuousIndexedEntropyModel): +class UniversalIndexedEntropyModel(continuous_base.ContinuousEntropyModelBase): """Indexed entropy model model which implements Universal Quantization. In contrast to the base class, which uses rounding for quantization, here @@ -241,7 +334,6 @@ class UniversalIndexedEntropyModel( > "Universally Quantized Neural Compression"
> Eirikur Agustsson & Lucas Theis
> https://arxiv.org/abs/2006.09952 - """ def __init__(self, @@ -297,34 +389,69 @@ def __init__(self, rather than `Variable`s. num_noise_levels: Integer. The number of levels used to quantize the uniform noise. - - Raises: - RuntimeError: when attempting to instantiate an entropy model with - `compression=True` and not in eager execution mode. """ - # Add extra indexes for noise levels. - index_ranges_with_offsets = tuple([num_noise_levels] + - [int(r) for r in index_ranges]) + if coding_rank <= 0: + raise ValueError("`coding_rank` must be larger than 0.") + if not callable(prior_fn): + raise TypeError("`prior_fn` must be a class or factory function.") + for name, fn in parameter_fns.items(): + if not isinstance(name, str): + raise TypeError("`parameter_fns` must have string keys.") + if not callable(fn): + raise TypeError(f"`parameter_fns['{name}']` must be callable.") - # This attribute is used in methods we override in this class which - # are used during used during super().__init__(...), so we set it first. - self._num_noise_levels = num_noise_levels - - # We only support channel axis at the last dimension. - channel_axis = -1 super().__init__( - prior_fn=prior_fn, - index_ranges=index_ranges_with_offsets, - parameter_fns=parameter_fns, coding_rank=coding_rank, compression=compression, - channel_axis=channel_axis, - dtype=dtype, - tail_mass=tail_mass, - laplace_tail_mass=laplace_tail_mass, + stateless=stateless, expected_grads=expected_grads, + tail_mass=tail_mass, range_coder_precision=range_coder_precision, - stateless=stateless) + dtype=dtype, + laplace_tail_mass=laplace_tail_mass, + ) + # Add extra indexes for noise levels. + self._index_ranges = tuple( + [num_noise_levels] + [int(r) for r in index_ranges]) + if not self.index_ranges: + raise ValueError("`index_ranges` must have at least one element.") + self._prior_fn = prior_fn + self._parameter_fns = dict(parameter_fns) + self._num_noise_levels = num_noise_levels + + with self.name_scope: + if self.compression: + index_ranges = self.index_ranges_without_offsets + indexes = [tf.range(r, dtype=self.dtype) for r in index_ranges] + indexes = tf.meshgrid(*indexes, indexing="ij") + indexes = tf.stack(indexes, axis=-1) + self._prior = self._make_prior(indexes) + cdf, cdf_offset, cdf_length = self._build_tables( + self.prior, + offset=_range_coding_offsets( + self._num_noise_levels, self.prior_shape, self.dtype), + context_shape=self.context_shape) + self._init_compression(cdf, cdf_offset, cdf_length, None) + + @property + def index_ranges(self): + """Upper bound(s) on values allowed in `indexes` tensor.""" + return self._index_ranges + + @property + def parameter_fns(self): + """Functions mapping `indexes` to each distribution parameter.""" + return self._parameter_fns + + @property + def prior_fn(self): + """Class or factory function returning a `Distribution` object.""" + return self._prior_fn + + @property + def prior_shape(self): + """Batch shape of `prior` (dimensions which are not assumed i.i.d.).""" + return tf.TensorShape(self.prior.batch_shape) @property def context_shape(self): @@ -336,6 +463,16 @@ def index_ranges_without_offsets(self): """Upper bound(s) on values allowed in `indexes` , excluding offsets.""" return _index_ranges_without_offsets(self.index_ranges) + def _make_prior(self, indexes): + indexes = tf.cast(indexes, self.dtype) + parameters = {k: f(indexes) for k, f in self.parameter_fns.items()} + return self.prior_fn(**parameters) + + def _flatten_indexes(self, indexes): + indexes = tf.cast(indexes, tf.int32) + strides = tf.math.cumprod(self.index_ranges, exclusive=True, reverse=True) + return tf.linalg.tensordot(indexes, strides, [[-1], [0]]) + def _normalize_indexes(self, indexes): """See base class.""" num_indexes = indexes.shape[-1] # Last dim of `indexes` should be static. @@ -348,30 +485,17 @@ def _normalize_indexes(self, indexes): assert num_indexes == len(index_ranges) indexes = math_ops.lower_bound(indexes, 0) axes = [1] * indexes.shape.rank - axes[self.channel_axis] = len(index_ranges) + axes[-1] = len(index_ranges) bounds = tf.reshape([s - 1 for s in index_ranges], axes) return math_ops.upper_bound(indexes, tf.cast(bounds, indexes.dtype)) def _offset_from_indexes(self, indexes_with_offsets): - """Computes the offset for universal quantization (overrides base class).""" + """Computes the offset for universal quantization.""" offset_indexes = indexes_with_offsets[..., 0] offset = _offset_indexes_to_offset( offset_indexes, self._num_noise_levels, dtype=self.dtype) return offset - def _make_range_coding_prior(self, index_ranges, dtype): - """Instantiates the range coding prior.""" - return super()._make_range_coding_prior( - _index_ranges_without_offsets(index_ranges), dtype) - - def _offset_from_prior(self, prior): - return _range_coding_offsets(self._num_noise_levels, self.prior_shape, - self.dtype) - - @tf.Module.with_name_scope - def quantize(self, bottleneck, indexes=None): - raise NotImplementedError() - @tf.Module.with_name_scope def __call__(self, bottleneck, indexes, training=True): """Perturbs a tensor with additive uniform noise and estimates bitcost. @@ -392,7 +516,6 @@ def __call__(self, bottleneck, indexes, training=True): and `bits` is the bitcost of transmitting such a sample having the same shape as `bottleneck` without the `self.coding_rank` innermost dimensions. """ - indexes = self._normalize_indexes(indexes) if training: # Here we compute `h(bottleneck + noise)`. @@ -405,7 +528,7 @@ def log_prob_fn(bottleneck_perturbed, indexes): # reference here via a closure, we would get a `None` gradient for # `indexes`. prior = self._make_prior(indexes) - return self._log_prob_from_prior(prior, bottleneck_perturbed) + return self._log_prob(prior, bottleneck_perturbed) log_probs, bottleneck_perturbed = math_ops.perturb_and_apply( log_prob_fn, bottleneck, indexes, expected_grads=self._expected_grads) @@ -417,7 +540,7 @@ def log_prob_fn(bottleneck_perturbed, indexes): self._num_noise_levels, self.dtype) symbols = tf.round(bottleneck - offset) bottleneck_perturbed = symbols + offset - log_probs = self._log_prob_from_prior(prior, bottleneck_perturbed) + log_probs = self._log_prob(prior, bottleneck_perturbed) axes = tuple(range(-self.coding_rank, 0)) bits = tf.reduce_sum(log_probs, axis=axes) / ( @@ -426,15 +549,103 @@ def log_prob_fn(bottleneck_perturbed, indexes): @tf.Module.with_name_scope def compress(self, bottleneck, indexes): - """See base class.""" - indexes_with_offset = _add_offset_indexes(indexes, self._num_noise_levels) - return super().compress(bottleneck, indexes_with_offset) + """Compresses a floating-point tensor. + + Compresses the tensor to bit strings. `bottleneck` is first quantized + as in `quantize()`, and then compressed using the probability tables derived + from `indexes`. The quantized tensor can later be recovered by calling + `decompress()`. + + The innermost `self.coding_rank` dimensions are treated as one coding unit, + i.e. are compressed into one string each. Any additional dimensions to the + left are treated as batch dimensions. + + Args: + bottleneck: `tf.Tensor` containing the data to be compressed. + indexes: `tf.Tensor` specifying the scalar distribution for each element + in `bottleneck`. See class docstring for examples. + + Returns: + A `tf.Tensor` having the same shape as `bottleneck` without the + `self.coding_rank` innermost dimensions, containing a string for each + coding unit. + """ + indexes = _add_offset_indexes(indexes, self._num_noise_levels) + indexes = self._normalize_indexes(indexes) + flat_indexes = self._flatten_indexes(indexes) + + symbols_shape = tf.shape(flat_indexes) + batch_shape = symbols_shape[:-self.coding_rank] + flat_shape = tf.concat([[-1], symbols_shape[-self.coding_rank:]], 0) + + flat_indexes = tf.reshape(flat_indexes, flat_shape) + + offset = self._offset_from_indexes(indexes) + symbols = tf.cast(tf.round(bottleneck - offset), tf.int32) + symbols = tf.reshape(symbols, flat_shape) + + # Prevent tensors from bouncing back and forth between host and GPU. + with tf.device("/cpu:0"): + cdf = self.cdf + cdf_length = self.cdf_length + cdf_offset = self.cdf_offset + def loop_body(args): + return gen_ops.unbounded_index_range_encode( + args[0], args[1], cdf, cdf_length, cdf_offset, + precision=self.range_coder_precision, + overflow_width=4, debug_level=1) + + # TODO(jonycgn,ssjhv): Consider switching to Python control flow. + strings = tf.map_fn( + loop_body, (symbols, flat_indexes), dtype=tf.string, name="compress") + + strings = tf.reshape(strings, batch_shape) + return strings @tf.Module.with_name_scope def decompress(self, strings, indexes): - """See base class.""" - indexes_with_offset = _add_offset_indexes(indexes, self._num_noise_levels) - return super().decompress(strings, indexes_with_offset) + """Decompresses a tensor. + + Reconstructs the quantized tensor from bit strings produced by `compress()`. + + Args: + strings: `tf.Tensor` containing the compressed bit strings. + indexes: `tf.Tensor` specifying the scalar distribution for each output + element. See class docstring for examples. + + Returns: + A `tf.Tensor` of the same shape as `indexes` (without the optional channel + dimension). + """ + indexes = _add_offset_indexes(indexes, self._num_noise_levels) + indexes = self._normalize_indexes(indexes) + flat_indexes = self._flatten_indexes(indexes) + + symbols_shape = tf.shape(flat_indexes) + flat_shape = tf.concat([[-1], symbols_shape[-self.coding_rank:]], 0) + + flat_indexes = tf.reshape(flat_indexes, flat_shape) + + strings = tf.reshape(strings, [-1]) + + # Prevent tensors from bouncing back and forth between host and GPU. + with tf.device("/cpu:0"): + cdf = self.cdf + cdf_length = self.cdf_length + cdf_offset = self.cdf_offset + def loop_body(args): + return gen_ops.unbounded_index_range_decode( + args[0], args[1], cdf, cdf_length, cdf_offset, + precision=self.range_coder_precision, + overflow_width=4, debug_level=1) + + # TODO(jonycgn,ssjhv): Consider switching to Python control flow. + symbols = tf.map_fn( + loop_body, (strings, flat_indexes), dtype=tf.int32, name="decompress") + + symbols = tf.reshape(symbols, symbols_shape) + offset = self._offset_from_indexes(indexes) + return tf.cast(symbols, self.dtype) + offset def get_config(self): # TODO(relational): Implement this when we need serialization. diff --git a/tensorflow_compression/python/entropy_models/universal_test.py b/tensorflow_compression/python/entropy_models/universal_test.py index 3e0b46e..281da16 100644 --- a/tensorflow_compression/python/entropy_models/universal_test.py +++ b/tensorflow_compression/python/entropy_models/universal_test.py @@ -161,7 +161,6 @@ def test_can_instantiate_n_dimensional(self): coding_rank=1, ) self.assertEqual(em.coding_rank, 1) - self.assertEqual(em.channel_axis, -1) self.assertEqual(em._laplace_tail_mass, 0.0) self.assertEqual(em.tail_mass, 2**-8) self.assertEqual(em.range_coder_precision, 12) @@ -235,7 +234,6 @@ def test_bitstring_length_matches_estimates(self, training): coding_rank=1, compression=True) self.assertEqual(em.coding_rank, 1) - self.assertEqual(em.channel_axis, -1) self.assertEqual(em._laplace_tail_mass, 0.0) self.assertEqual(em.tail_mass, 2**-8) self.assertEqual(em.range_coder_precision, 12) diff --git a/tensorflow_compression/python/layers/BUILD b/tensorflow_compression/python/layers/BUILD index 4f30f59..f765e05 100644 --- a/tensorflow_compression/python/layers/BUILD +++ b/tensorflow_compression/python/layers/BUILD @@ -75,7 +75,7 @@ py_library( name = "soft_round", srcs = ["soft_round.py"], srcs_version = "PY3", - deps = ["//tensorflow_compression/python/ops:soft_round_ops"], + deps = ["//tensorflow_compression/python/ops:round_ops"], ) py_test( @@ -84,7 +84,7 @@ py_test( python_version = "PY3", deps = [ ":soft_round", - "//tensorflow_compression/python/ops:soft_round_ops", + "//tensorflow_compression/python/ops:round_ops", ], ) diff --git a/tensorflow_compression/python/layers/soft_round.py b/tensorflow_compression/python/layers/soft_round.py index 3bc07d7..f0ba76d 100644 --- a/tensorflow_compression/python/layers/soft_round.py +++ b/tensorflow_compression/python/layers/soft_round.py @@ -15,7 +15,7 @@ """Layers for soft rounding.""" import tensorflow as tf -from tensorflow_compression.python.ops import soft_round_ops +from tensorflow_compression.python.ops import round_ops __all__ = [ @@ -45,8 +45,7 @@ def __init__(self, super().__init__(**kwargs) self._alpha = alpha self._transform = ( - soft_round_ops.soft_round_inverse - if inverse else soft_round_ops.soft_round) + round_ops.soft_round_inverse if inverse else round_ops.soft_round) def call(self, inputs): outputs = self._transform(inputs, self._alpha) @@ -66,8 +65,7 @@ def __init__(self, self._alpha = alpha def call(self, inputs): - return soft_round_ops.soft_round_conditional_mean( - inputs, alpha=self._alpha) + return round_ops.soft_round_conditional_mean(inputs, alpha=self._alpha) def compute_output_shape(self, input_shape): return input_shape diff --git a/tensorflow_compression/python/layers/soft_round_test.py b/tensorflow_compression/python/layers/soft_round_test.py index 043657e..fdfd17c 100644 --- a/tensorflow_compression/python/layers/soft_round_test.py +++ b/tensorflow_compression/python/layers/soft_round_test.py @@ -16,7 +16,7 @@ import tensorflow as tf from tensorflow_compression.python.layers import soft_round -from tensorflow_compression.python.ops import soft_round_ops +from tensorflow_compression.python.ops import round_ops class SoftRoundTest(tf.test.TestCase): @@ -32,16 +32,14 @@ def test_soft_round_layer_soft_rounds(self): layer = soft_round.SoftRound(alpha=alpha) x = tf.linspace(-5.0, 5.0, num=50) y = layer(x) - self.assertAllClose(y, - soft_round_ops.soft_round(x, alpha=alpha)) + self.assertAllClose(y, round_ops.soft_round(x, alpha=alpha)) def test_soft_round_layer_inverse_inverse_soft_rounds(self): alpha = 5.0 layer = soft_round.SoftRound(alpha=alpha, inverse=True) x = tf.linspace(-5.0, 5.0, num=50) y = layer(x) - self.assertAllClose( - y, soft_round_ops.soft_round_inverse(x, alpha=alpha)) + self.assertAllClose(y, round_ops.soft_round_inverse(x, alpha=alpha)) def test_conditional_mean_takes_conditional_mean(self): alpha = 5.0 @@ -49,7 +47,7 @@ def test_conditional_mean_takes_conditional_mean(self): x = tf.linspace(-5.0, 5.0, num=50) y = layer(x) self.assertAllClose( - y, soft_round_ops.soft_round_conditional_mean(x, alpha=alpha)) + y, round_ops.soft_round_conditional_mean(x, alpha=alpha)) if __name__ == "__main__": diff --git a/tensorflow_compression/python/ops/BUILD b/tensorflow_compression/python/ops/BUILD index 4fe2fc9..8c25ca5 100644 --- a/tensorflow_compression/python/ops/BUILD +++ b/tensorflow_compression/python/ops/BUILD @@ -23,7 +23,7 @@ py_test( python_version = "PY3", deps = [ ":math_ops", - ":soft_round_ops", + ":round_ops", ], ) @@ -48,16 +48,16 @@ py_test( ) py_library( - name = "soft_round_ops", - srcs = ["soft_round_ops.py"], + name = "round_ops", + srcs = ["round_ops.py"], srcs_version = "PY3", ) py_test( - name = "soft_round_ops_test", - srcs = ["soft_round_ops_test.py"], + name = "round_ops_test", + srcs = ["round_ops_test.py"], python_version = "PY3", - deps = [":soft_round_ops"], + deps = [":round_ops"], ) filegroup( diff --git a/tensorflow_compression/python/ops/math_ops_test.py b/tensorflow_compression/python/ops/math_ops_test.py index b183a7f..7a06e53 100644 --- a/tensorflow_compression/python/ops/math_ops_test.py +++ b/tensorflow_compression/python/ops/math_ops_test.py @@ -18,7 +18,7 @@ import scipy.stats import tensorflow as tf from tensorflow_compression.python.ops import math_ops -from tensorflow_compression.python.ops import soft_round_ops +from tensorflow_compression.python.ops import round_ops class MathTest(tf.test.TestCase, parameterized.TestCase): @@ -86,7 +86,7 @@ def test_perturb_and_apply_noise(self): self.assertGreater(p, 1e-6) def test_perturb_and_apply_gradient_soft_round(self): - f = soft_round_ops.soft_round + f = round_ops.soft_round x = tf.linspace(-2.0, 2.0, 200) temperature = 7.0 with tf.GradientTape(persistent=True) as g: diff --git a/tensorflow_compression/python/ops/soft_round_ops.py b/tensorflow_compression/python/ops/round_ops.py similarity index 89% rename from tensorflow_compression/python/ops/soft_round_ops.py rename to tensorflow_compression/python/ops/round_ops.py index 65eab30..c521bc3 100644 --- a/tensorflow_compression/python/ops/soft_round_ops.py +++ b/tensorflow_compression/python/ops/round_ops.py @@ -18,12 +18,31 @@ __all__ = [ + "round_st", "soft_round", "soft_round_inverse", "soft_round_conditional_mean", ] +@tf.custom_gradient +def _round_st_no_offset(inputs): + return tf.round(inputs), lambda x: x + + +@tf.custom_gradient +def _round_st_offset(inputs, offset): + return tf.round(inputs - offset) + offset, lambda x: (x, None) + + +def round_st(inputs, offset=None): + """Straight-through round with optional quantization offset.""" + if offset is None: + return _round_st_no_offset(inputs) + else: + return _round_st_offset(inputs, offset) + + def soft_round(x, alpha, eps=1e-3): """Differentiable approximation to round(). diff --git a/tensorflow_compression/python/ops/soft_round_ops_test.py b/tensorflow_compression/python/ops/round_ops_test.py similarity index 85% rename from tensorflow_compression/python/ops/soft_round_ops_test.py rename to tensorflow_compression/python/ops/round_ops_test.py index db9d4ca..6e9ad7a 100644 --- a/tensorflow_compression/python/ops/soft_round_ops_test.py +++ b/tensorflow_compression/python/ops/round_ops_test.py @@ -17,46 +17,46 @@ from absl.testing import parameterized import tensorflow as tf -from tensorflow_compression.python.ops import soft_round_ops +from tensorflow_compression.python.ops import round_ops class SoftRoundTest(tf.test.TestCase, parameterized.TestCase): def test_soft_round_small_alpha_is_identity(self): x = tf.linspace(-2., 2., 50) - y = soft_round_ops.soft_round(x, alpha=1e-13) + y = round_ops.soft_round(x, alpha=1e-13) self.assertAllClose(x, y) def test_soft_round_large_alpha_is_round(self): # We don't care what happens exactly near half-integer values: for offset in range(-5, 5): x = tf.linspace(offset - 0.499, offset + 0.499, 100) - y = soft_round_ops.soft_round(x, alpha=2000.0) + y = round_ops.soft_round(x, alpha=2000.0) self.assertAllClose(tf.round(x), y, atol=0.02) def test_soft_inverse_round_small_alpha_is_identity(self): x = tf.linspace(-2., 2., 50) - y = soft_round_ops.soft_round_inverse(x, alpha=1e-13) + y = round_ops.soft_round_inverse(x, alpha=1e-13) self.assertAllEqual(x, y) def test_soft_inverse_is_actual_inverse(self): x = tf.constant([-1.25, -0.75, 0.75, 1.25], dtype=tf.float32) - y = soft_round_ops.soft_round(x, alpha=2.0) - x2 = soft_round_ops.soft_round_inverse(y, alpha=2.0) + y = round_ops.soft_round(x, alpha=2.0) + x2 = round_ops.soft_round_inverse(y, alpha=2.0) self.assertAllClose(x, x2) def test_soft_round_inverse_large_alpha_is_ceil_minus_half(self): # We don't care what happens exactly near integer values: for offset in range(-5, 5): x = tf.linspace(offset + 0.001, offset + 0.999, 100) - y = soft_round_ops.soft_round_inverse(x, alpha=5000.0) + y = round_ops.soft_round_inverse(x, alpha=5000.0) self.assertAllClose(tf.math.ceil(x) - 0.5, y, atol=0.001) def test_conditional_mean_large_alpha_is_round(self): # We don't care what happens exactly near integer values: for offset in range(-5, 5): x = tf.linspace(offset + 0.001, offset + 0.999, 100) - y = soft_round_ops.soft_round_conditional_mean(x, alpha=5000.0) + y = round_ops.soft_round_conditional_mean(x, alpha=5000.0) self.assertAllClose(tf.math.round(x), y, atol=0.001) @parameterized.parameters(0., 1e-6, 1e-2, 5., 1e6) @@ -64,7 +64,7 @@ def test_soft_round_values_and_gradients_are_finite(self, alpha): x = tf.linspace(0., 1., 11) # covers exact integers and half-integers with tf.GradientTape() as tape: tape.watch(x) - y = soft_round_ops.soft_round(x, alpha=alpha) + y = round_ops.soft_round(x, alpha=alpha) dy = tape.gradient(y, x) self.assertAllEqual(tf.math.is_finite(y), tf.ones(x.shape, dtype=bool)) self.assertAllEqual(tf.math.is_finite(dy), tf.ones(x.shape, dtype=bool)) @@ -74,7 +74,7 @@ def test_soft_round_inverse_values_and_gradients_are_finite(self, alpha): x = tf.linspace(-.5, .5, 11) # covers exact integers and half-integers with tf.GradientTape() as tape: tape.watch(x) - y = soft_round_ops.soft_round_inverse(x, alpha=alpha) + y = round_ops.soft_round_inverse(x, alpha=alpha) dy = tape.gradient(y, x) self.assertAllEqual(tf.math.is_finite(y), tf.ones(x.shape, dtype=bool)) is_finite = tf.math.is_finite(dy)