Skip to content

Commit

Permalink
Revises handling of quantization offsets.
Browse files Browse the repository at this point in the history
- The quantization offset for DeepFactorized is now determined numerically
  instead of assuming zero.
- For batched entropy models, the `non_integer_offset` argument controls
  whether the quantization offset heuristic is used or not (as before).
- For location-scale family entropy models, always quantize to integers modulo
  location parameter of the prior distribution.
- For general indexed entropy models, do not use quantization offset heuristic,
  and always quantize to integers.
- Universal entropy models use their own logic, as before.

The above is accomplished by some refactoring:
- The logic for creating the range coding tables is moved from the initializer
  of the base class in continuous_base.py to the initializers of the subclasses.
  This makes it possible to streamline the building of the range coding tables
  and make that logic available as a private method to be called by subclasses
  instead.
- Models in in universal.py now depend directly on the base class. This way,
  they don't need to inherit the quantization offset logic and can implement
  their own.
Both of these changes make it possible to remove indirection. They also free
parent classes from having to implement functionality they don't need, and child
classes from inheriting functionality that doesn't make sense for them.

PiperOrigin-RevId: 420564225
Change-Id: I57cdd9627b83db3a2455a23d9481ccb23309f957
  • Loading branch information
Johannes Ballé authored and copybara-github committed Jan 9, 2022
1 parent c60a5a9 commit edb8df5
Show file tree
Hide file tree
Showing 24 changed files with 826 additions and 516 deletions.
2 changes: 1 addition & 1 deletion BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ py_library(
"//tensorflow_compression/python/ops:gen_ops",
"//tensorflow_compression/python/ops:math_ops",
"//tensorflow_compression/python/ops:padding_ops",
"//tensorflow_compression/python/ops:soft_round_ops",
"//tensorflow_compression/python/ops:round_ops",
"//tensorflow_compression/python/util:packed_tensors",
],
)
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_compression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from tensorflow_compression.python.ops.gen_ops import *
from tensorflow_compression.python.ops.math_ops import *
from tensorflow_compression.python.ops.padding_ops import *
from tensorflow_compression.python.ops.soft_round_ops import *
from tensorflow_compression.python.ops.round_ops import *

from tensorflow_compression.python.util.packed_tensors import *
# pylint: enable=wildcard-import
2 changes: 1 addition & 1 deletion tensorflow_compression/all_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from tensorflow_compression.python.ops.math_ops_test import *
from tensorflow_compression.python.ops.padding_ops_test import *
from tensorflow_compression.python.ops.range_coding_ops_test import *
from tensorflow_compression.python.ops.soft_round_ops_test import *
from tensorflow_compression.python.ops.round_ops_test import *

from tensorflow_compression.python.util.packed_tensors_test import *
# pylint: enable=wildcard-import
Expand Down
3 changes: 1 addition & 2 deletions tensorflow_compression/python/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ py_library(
":deep_factorized",
":helpers",
":uniform_noise",
"//tensorflow_compression/python/ops:soft_round_ops",
"//tensorflow_compression/python/ops:round_ops",
],
)

Expand All @@ -76,7 +76,6 @@ py_test(
deps = [
":deep_factorized",
":round_adapters",
"//tensorflow_compression/python/ops:soft_round_ops",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,18 @@ def _log_prob(self, inputs):
return log_s_logits + log_s_neg_logits + tf.math.log(dlogits)

def _quantization_offset(self):
return tf.constant(0, dtype=self.dtype)
return helpers.estimate_tails(
self._logits_cumulative, 0., self.batch_shape_tensor(), self.dtype)

def _lower_tail(self, tail_mass):
logits = tf.math.log(tail_mass / 2 / (1. - tail_mass / 2))
logits = tf.math.log(
tf.cast(tail_mass / 2 / (1. - tail_mass / 2), self.dtype))
return helpers.estimate_tails(
self._logits_cumulative, logits, self.batch_shape_tensor(), self.dtype)

def _upper_tail(self, tail_mass):
logits = -tf.math.log(tail_mass / 2 / (1. - tail_mass / 2))
logits = -tf.math.log(
tf.cast(tail_mass / 2 / (1. - tail_mass / 2), self.dtype))
return helpers.estimate_tails(
self._logits_cumulative, logits, self.batch_shape_tensor(), self.dtype)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,6 @@ def test_uniform_is_special_case(self):
x = tf.linspace(-1., 1., 10)
self.assertAllClose(df.prob(x), [0, 0, 0, 1, 1, 1, 1, 0, 0, 0])

def test_quantization_offset_is_zero(self):
df = deep_factorized.NoisyDeepFactorized()
self.assertEqual(helpers.quantization_offset(df), 0)

def test_tails_are_in_order(self):
df = deep_factorized.NoisyDeepFactorized()
lower_tail = helpers.lower_tail(df, 2**-8)
Expand Down
8 changes: 5 additions & 3 deletions tensorflow_compression/python/distributions/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,10 @@ def lower_tail(distribution, tail_mass):
tail = distribution.quantile(tail_mass / 2)
except NotImplementedError:
try:
target = tf.math.log(tf.cast(tail_mass / 2, distribution.dtype))
tail = estimate_tails(
distribution.log_cdf, tf.math.log(tail_mass / 2),
distribution.batch_shape_tensor(), distribution.dtype)
distribution.log_cdf, target, distribution.batch_shape_tensor(),
distribution.dtype)
except NotImplementedError:
raise NotImplementedError(
"`distribution` must implement `_lower_tail()`, `quantile()`, or "
Expand Down Expand Up @@ -193,8 +194,9 @@ def upper_tail(distribution, tail_mass):
tail = distribution.quantile(1 - tail_mass / 2)
except NotImplementedError:
try:
target = tf.math.log(tf.cast(tail_mass / 2, distribution.dtype))
tail = estimate_tails(
distribution.log_survival_function, tf.math.log(tail_mass / 2),
distribution.log_survival_function, target,
distribution.batch_shape_tensor(), distribution.dtype)
except NotImplementedError:
raise NotImplementedError(
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_compression/python/distributions/round_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tensorflow_compression.python.distributions import deep_factorized
from tensorflow_compression.python.distributions import helpers
from tensorflow_compression.python.distributions import uniform_noise
from tensorflow_compression.python.ops import soft_round_ops
from tensorflow_compression.python.ops import round_ops


__all__ = [
Expand Down Expand Up @@ -239,10 +239,10 @@ def __init__(self, base, alpha, name="SoftRoundAdapter"):
self._alpha = alpha

def transform(self, x):
return soft_round_ops.soft_round(x, self._alpha)
return round_ops.soft_round(x, self._alpha)

def inverse_transform(self, y):
return soft_round_ops.soft_round_inverse(y, self._alpha)
return round_ops.soft_round_inverse(y, self._alpha)


class NoisySoftRoundAdapter(uniform_noise.UniformNoiseAdapter):
Expand Down
42 changes: 14 additions & 28 deletions tensorflow_compression/python/distributions/round_adapters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import tensorflow_probability as tfp
from tensorflow_compression.python.distributions import deep_factorized
from tensorflow_compression.python.distributions import round_adapters
from tensorflow_compression.python.ops import soft_round_ops


def _test_log_prob_gradient_is_bounded(self, dist_cls, values, params=()):
Expand All @@ -42,44 +41,42 @@ class AdaptersTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
("softround_deepfactorized",
lambda d: round_adapters.SoftRoundAdapter(d, alpha=5.0),
deep_factorized.DeepFactorized, 0.0),
deep_factorized.DeepFactorized),
("softround_logistic",
lambda d: round_adapters.SoftRoundAdapter(d, alpha=5.0),
lambda: tfp.distributions.Logistic(loc=10.3, scale=1.5),
lambda: soft_round_ops.soft_round(0.3, alpha=5.0)),
lambda: tfp.distributions.Logistic(loc=10.3, scale=1.5)),
("softround_normal",
lambda d: round_adapters.SoftRoundAdapter(d, alpha=4.0),
lambda: tfp.distributions.Normal(loc=10.4, scale=1.5),
lambda: soft_round_ops.soft_round(0.4, alpha=4.0)),
lambda: tfp.distributions.Normal(loc=10.4, scale=1.5)),
("noisysoftround_deepfactorized",
lambda d: round_adapters.NoisySoftRoundAdapter(d, alpha=5.0),
deep_factorized.DeepFactorized, 0.0),
deep_factorized.DeepFactorized),
("noisysoftround_logistic",
lambda d: round_adapters.NoisySoftRoundAdapter(d, alpha=5.0),
lambda: tfp.distributions.Logistic(loc=10, scale=1.5), 0.0),
lambda: tfp.distributions.Logistic(loc=10, scale=1.5)),
("noisysoftround_normal",
lambda d: round_adapters.NoisySoftRoundAdapter(d, alpha=5.0),
lambda: tfp.distributions.Normal(loc=10, scale=1.5), 0.0),
lambda: tfp.distributions.Normal(loc=10, scale=1.5)),
("round_deepfactorized",
round_adapters.RoundAdapter,
lambda: deep_factorized.DeepFactorized(init_scale=1.0), 0.0),
lambda: deep_factorized.DeepFactorized(init_scale=1.0)),
("round_logistic",
round_adapters.RoundAdapter,
lambda: tfp.distributions.Logistic(loc=1.5, scale=1.5), 0.0),
lambda: tfp.distributions.Logistic(loc=1.5, scale=1.5)),
("round_normal",
round_adapters.RoundAdapter,
lambda: tfp.distributions.Normal(loc=1.5, scale=1.5), 0.0),
lambda: tfp.distributions.Normal(loc=1.5, scale=1.5)),
("noisyround_deepfactorized",
round_adapters.NoisyRoundAdapter,
lambda: deep_factorized.DeepFactorized(init_scale=1.0), 0.0),
lambda: deep_factorized.DeepFactorized(init_scale=1.0)),
("noisyround_logistic",
round_adapters.NoisyRoundAdapter,
lambda: tfp.distributions.Logistic(loc=1.5, scale=1.5), 0.0),
lambda: tfp.distributions.Logistic(loc=1.5, scale=1.5)),
("noisyround_normal",
round_adapters.NoisyRoundAdapter,
lambda: tfp.distributions.Normal(loc=1.5, scale=1.5), 0.0),
lambda: tfp.distributions.Normal(loc=1.5, scale=1.5)),
)
def test_tails_and_offset(self, adapter, distribution, expected_offset):
def test_tails(self, adapter, distribution):
dist = adapter(distribution())
lower_tail = dist._lower_tail(2**-8)
try:
Expand All @@ -98,12 +95,6 @@ def test_tails_and_offset(self, adapter, distribution, expected_offset):
self.assertLessEqual(right_mass, 2**-8)

self.assertGreater(upper_tail, lower_tail)
offset = dist._quantization_offset()
if not isinstance(expected_offset, float):
# We cannot run tf inside the parameterized test declaration, hence
# non-float values are wrapped in a lambda.
expected_offset = expected_offset()
self.assertAllClose(offset, expected_offset)

@parameterized.named_parameters(
("softround_logistic",
Expand Down Expand Up @@ -210,16 +201,11 @@ def test_sampling_works(self):
sample = dist.sample((5, 4))
self.assertEqual(sample.shape, (5, 4, 2))

def test_tails_and_offset_are_in_order(self):
def test_tails_are_in_order(self):
dist = self.dist_cls(loc=10, scale=1.5)
offset = dist._quantization_offset()
lower_tail = dist._lower_tail(2**-8)
upper_tail = dist._upper_tail(2**-8)
self.assertGreater(upper_tail, lower_tail)
if offset:
# If quantization offset is 0.0, it doesn't need to be between the tails.
self.assertGreater(upper_tail, offset)
self.assertGreater(offset, lower_tail)

def test_stats_throw_error(self):
dist = self.dist_cls(loc=1, scale=2)
Expand Down
9 changes: 5 additions & 4 deletions tensorflow_compression/python/entropy_models/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ py_library(
"//tensorflow_compression/python/distributions:helpers",
"//tensorflow_compression/python/ops:gen_ops",
"//tensorflow_compression/python/ops:math_ops",
"//tensorflow_compression/python/ops:round_ops",
],
)

Expand All @@ -42,16 +43,17 @@ py_library(
srcs_version = "PY3",
deps = [
":continuous_base",
"//tensorflow_compression/python/distributions:helpers",
"//tensorflow_compression/python/ops:gen_ops",
"//tensorflow_compression/python/ops:math_ops",
"//tensorflow_compression/python/ops:round_ops",
],
)

py_test(
name = "continuous_indexed_test",
srcs = ["continuous_indexed_test.py"],
python_version = "PY3",
shard_count = 5,
deps = [
":continuous_indexed",
"//tensorflow_compression/python/distributions:uniform_noise",
Expand All @@ -63,15 +65,14 @@ py_library(
srcs = ["universal.py"],
srcs_version = "PY3",
deps = [
":continuous_batched",
":continuous_indexed",
":continuous_base",
"//tensorflow_compression/python/ops:gen_ops",
"//tensorflow_compression/python/ops:math_ops",
],
)

py_test(
name = "universal_test",
timeout = "long",
srcs = ["universal_test.py"],
python_version = "PY3",
shard_count = 3,
Expand Down
Loading

0 comments on commit edb8df5

Please sign in to comment.