Jitted function runs slower than non-jitted function for EM updates. #23822
Edit: This may be OS specific; see the additional comment below.

I'm observing that a jitted function takes around 60% longer to complete than a non-jitted function. There are some discussions on GitHub and StackOverflow related to the same observation (e.g., very small functions where the overhead is larger than the benefit from jit, placing a large number of Python objects on the device, or compilation overhead due to Python loops), but I don't think I'm in any of these settings.

For context, I'm using jax to estimate the factors of a tensor decomposition model using classic EM-style updates (rather than gradient-based optimization). The motivation to jit is to speed up these updates; I try to minimize the L2 loss between the observations and the model's predictions. For the larger analysis, I'm using shrinkage priors on all components of the model and variational Bayes updates with a mean-field approximation for the posterior, but the function below exhibits the same behavior and is much more readable. The main update function is as follows (a full example is here).

```python
import jax.numpy as jnp


def update(factors, i, j, k, y):
"""
Update the factors given the COO representation of observations.
Args:
factors: Mapping from factor names to jax arrays.
i: Indices of obs. along first tensor dimension with shape `(n_obs,)`.
j: Indices of obs. along second tensor dimension with shape `(n_obs,)`.
k: Indices of obs. along third tensor dimension with shape `(n_obs,)`.
y: Observations at indices `(i, j, k)` with shape `(n_obs,)`.
Returns:
Updated factors and predictions as a tuple.
"""
    # Create a shallow copy to ensure the function is pure.
    factors = factors.copy()
    # Indices to expand each factor to the same size as the observations `y`.
    indices = {
        "a": (i,),
        "b": (j,),
        "c": (k,),
        "A": (i, j),
        "B": (j, k),
        "C": (k, i),
    }
    # Expand the factors so we can easily construct an estimate of `y`.
    summands = {key: factors[key][idx] for key, idx in indices.items()}
    # First update the grand mean `mu` separately because it doesn't have indices.
    y_hat = sum(summands.values())
    mu = (y - y_hat).mean()
    factors["mu"] = mu
    summands["mu"] = mu
    # Iterate over all factors and update them.
    for key, idx in indices.items():
        # Pop the factor we're currently updating.
        summands.pop(key)
        # Evaluate the number of observations per element of the factor. We add
        # 0.001 to avoid division by zero. This is just the prior precision in a
        # Bayesian context.
        precision = 0.001 + jnp.zeros_like(factors[key]).at[idx].add(1)
        # Evaluate the totals for each element and divide by the precision to get a
        # point estimate.
        y_hat = sum(summands.values())
        residuals = y - y_hat
        prod = jnp.zeros_like(precision).at[idx].add(residuals)
        factor = prod / precision
        # Update the factor and add to the summands.
        factors[key] = factor
        summands[key] = factor[idx]
    y_hat = sum(summands.values())
    return factors, y_hat
```
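For reference, here is a minimal sketch of how the inputs can be constructed and the function called. The dimensions and random data below are placeholders chosen purely for illustration; the post does not state the actual tensor dimensions used in the timings.

```python
import jax
import jax.numpy as jnp

# Placeholder dimensions, not the ones used in the timings below.
n_i, n_j, n_k, n_obs = 50, 60, 70, 100_000
keys = jax.random.split(jax.random.PRNGKey(0), 4)

# Random COO indices and observations.
i = jax.random.randint(keys[0], (n_obs,), 0, n_i)
j = jax.random.randint(keys[1], (n_obs,), 0, n_j)
k = jax.random.randint(keys[2], (n_obs,), 0, n_k)
y = jax.random.normal(keys[3], (n_obs,))

# Initialize all factors at zero; shapes follow the indexing in `update`.
factors = {
    "mu": jnp.zeros(()),
    "a": jnp.zeros(n_i),
    "b": jnp.zeros(n_j),
    "c": jnp.zeros(n_k),
    "A": jnp.zeros((n_i, n_j)),
    "B": jnp.zeros((n_j, n_k)),
    "C": jnp.zeros((n_k, n_i)),
}

# Run a few EM-style sweeps, feeding the updated factors back in.
for _ in range(5):
    factors, y_hat = update(factors, i, j, k, y)
```

The timings below then call `update` once per iteration, first eagerly and then through `jax.jit`.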
Timings are as follows (all run on the CPU of a 2020 MacBook Pro with an M1 chip).

```
>>> %timeit jax.block_until_ready(update(factors, i, j, k, y))
159 ms ± 2.35 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> # Jit and run once to remove compilation overhead in timing.
>>> jitted = jax.jit(update)
>>> jax.block_until_ready(jitted(factors, i, j, k, y))
>>> %timeit jax.block_until_ready(jitted(factors, i, j, k, y))
266 ms ± 11.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

I expected the jitted version to be faster.

My environment is as follows.

```
>>> jax.print_environment_info()
jax: 0.4.33
jaxlib: 0.4.33
numpy: 2.1.1
python: 3.11.5 (main, Dec 8 2023, 17:04:09) [Clang 15.0.0 (clang-1500.0.40.1)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Tills-MacBook-Pro.local', release='24.0.0', version='Darwin Kernel Version 24.0.0: Mon Aug 12 20:49:48 PDT 2024; root:xnu-11215.1.10~2/RELEASE_ARM64_T8103', machine='arm64')
```
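As a side note for readers skimming the function: the `precision` and `prod` lines rely on JAX's functional indexed updates, where duplicate indices accumulate under `.add`. A tiny standalone illustration with toy values (my own, purely for exposition):

```python
import jax.numpy as jnp

# Toy indices and residuals, chosen only to illustrate the scatter-add pattern.
idx = (jnp.array([0, 1, 1, 2]),)
residuals = jnp.array([1.0, 2.0, 3.0, 4.0])

counts = jnp.zeros(3).at[idx].add(1)          # [1., 2., 1.] -> observations per element
totals = jnp.zeros(3).at[idx].add(residuals)  # [1., 5., 4.] -> residual sums per element
estimate = totals / (0.001 + counts)          # per-element point estimate, mirroring `prod / precision`
print(counts, totals, estimate)
```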
Update: This may be OS specific (macOS Sequoia on my machine), because the timings are very different when running on a Colab CPU: the jitted function is about 4.5x faster.

```
>>> %timeit jax.block_until_ready(update(factors, i, j, k, y))
84.6 ms ± 11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> # Jit and run once to remove compilation overhead in timing.
>>> jitted = jax.jit(update)
>>> jax.block_until_ready(jitted(factors, i, j, k, y))
>>> %timeit jax.block_until_ready(jitted(factors, i, j, k, y))
18.6 ms ± 2.95 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
Thanks for trying that. For now perhaps you can set the following environment variable as a workaround for v0.4.33:
I expect that we'll be able to come up with a longer term solution for use cases like this, but I don't know the answer yet!