Commit 76c096c

Split out a double-check-cache job, so that we double-check the cache before starting either rsc or zinc work, and then never again (otherwise hitting the cache might race with the ongoing work, which we are unable to cancel).
stuhood committed Aug 29, 2019
1 parent df6a2fc commit 76c096c
Showing 2 changed files with 69 additions and 61 deletions.
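
The scheduling idea behind the change, in sketch form: previously each work function re-checked the cache at the start of its own execution, so a late cache hit could land while sibling work for the same target was already running and could not be cancelled. This commit instead models the double check as its own node in the execution graph: it runs exactly once, after the target's dependencies but strictly before any of the target's real work. Below is a minimal, self-contained sketch of that shape (toy Job class and scheduler; Pants' real ExecutionGraph API is richer than this):

import functools

class Job:
  def __init__(self, key, fn, dependencies=()):
    self.key = key
    self.fn = fn
    self.dependencies = list(dependencies)

def run_in_dependency_order(jobs):
  # Toy single-threaded scheduler: run each job once, after its dependencies.
  by_key = {job.key: job for job in jobs}
  done = set()
  def run(job):
    if job.key in done:
      return
    for dep_key in job.dependencies:
      run(by_key[dep_key])
    job.fn()
    done.add(job.key)
  for job in jobs:
    run(job)

valid = {'t': False}  # stands in for vts.valid

def double_check_cache(target):
  # Pretend the cache hit: mark the target valid so its compile becomes a noop.
  valid[target] = True
  print('double_check_cache({}): hit'.format(target))

def compile_target(target):
  if valid.get(target):
    print('compile({}): noop, cache already hit'.format(target))
  else:
    print('compile({}): compiling'.format(target))

jobs = [
  Job('compile(d)', functools.partial(compile_target, 'd')),
  # Same dependencies as the compile job, so the check runs as late as
  # possible, but always strictly before the compile for its target.
  Job('double_check_cache(t)', functools.partial(double_check_cache, 't'),
      dependencies=['compile(d)']),
  Job('compile(t)', functools.partial(compile_target, 't'),
      dependencies=['double_check_cache(t)', 'compile(d)']),
]
run_in_dependency_order(jobs)

Running it prints compile(d), then the single cache check for t, then a noop compile(t).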
61 changes: 38 additions & 23 deletions src/python/pants/backend/jvm/tasks/jvm_compile/jvm_compile.py
@@ -711,53 +711,64 @@ def _upstream_analysis(self, compile_contexts, classpath_entries):
       else:
         yield compile_context.classes_dir.path, compile_context.analysis_file
 
+  def exec_graph_double_check_cache_key_for_target(self, target):
+    return 'double_check_cache({})'.format(target.address.spec)
+
   def exec_graph_key_for_target(self, compile_target):
     return "compile({})".format(compile_target.address.spec)
 
   def _create_compile_jobs(self, compile_contexts, invalid_targets, invalid_vts, classpath_product):
     class Counter:
-      def __init__(self, size, initial=0):
+      def __init__(self, size=0):
         self.size = size
-        self.count = initial
+        self.count = 0
 
       def __call__(self):
         self.count += 1
         return self.count
 
+      def increment_size(self, by=1):
+        self.size += by
+
       def format_length(self):
         return len(str(self.size))
-    counter = Counter(len(invalid_vts))
 
     jobs = []
+    counter = Counter()
 
     jobs.extend(self.pre_compile_jobs(counter))
     invalid_target_set = set(invalid_targets)
     for ivts in invalid_vts:
       # Invalidated targets are a subset of relevant targets: get the context for this one.
       compile_target = ivts.target
       invalid_dependencies = self._collect_invalid_compile_dependencies(compile_target,
                                                                         invalid_target_set)
 
-      jobs.extend(
-        self.create_compile_jobs(compile_target, compile_contexts, invalid_dependencies, ivts,
-          counter, classpath_product))
+      new_jobs, new_count = self.create_compile_jobs(
+        compile_target, compile_contexts, invalid_dependencies, ivts, counter, classpath_product)
+      jobs.extend(new_jobs)
+      counter.increment_size(by=new_count)
 
-    counter.size = len(jobs)
     return jobs
 
   def pre_compile_jobs(self, counter):
     """Override this to provide jobs that are not related to particular targets.
 
     This is only called when there are invalid targets."""
     return []
 
   def create_compile_jobs(self, compile_target, all_compile_contexts, invalid_dependencies, ivts,
                           counter, classpath_product):
+    """Return a list of jobs, and a count of those jobs that represent meaningful ("countable") work."""
+
     context_for_target = all_compile_contexts[compile_target]
     compile_context = self.select_runtime_context(context_for_target)
 
-    job = Job(self.exec_graph_key_for_target(compile_target),
+    compile_deps = [self.exec_graph_key_for_target(target) for target in invalid_dependencies]
+
+    # The cache checking job doesn't technically have any dependencies, but we want to delay it
+    # until immediately before we would otherwise try compiling, so we indicate that it depends on
+    # all compile dependencies.
+    double_check_cache_job = Job(self.exec_graph_double_check_cache_key_for_target(compile_target),
+                                 functools.partial(self._default_double_check_cache_for_vts, ivts),
+                                 compile_deps)
+    # The compile job depends on the cache check job. This decomposition is necessary in order to
+    # support more complex situations where compilation runs multiple jobs in parallel, and wants to
+    # double check the cache before starting any of them.
+    compile_job = Job(self.exec_graph_key_for_target(compile_target),
               functools.partial(
                 self._default_work_for_vts,
                 ivts,
@@ -766,15 +777,15 @@ def create_compile_jobs(self, compile_target, all_compile_contexts, invalid_dependencies, ivts,
                 counter,
                 all_compile_contexts,
                 classpath_product),
-              [self.exec_graph_key_for_target(target) for target in invalid_dependencies],
+              [double_check_cache_job.key] + compile_deps,
               self._size_estimator(compile_context.sources),
               # If compilation and analysis work succeeds, validate the vts.
               # Otherwise, fail it.
               on_success=ivts.update,
               on_failure=ivts.force_invalidate)
-    return [job]
+    return ([compile_job], 1)
 
-  def check_cache(self, vts, counter):
+  def check_cache(self, vts):
     """Manually checks the artifact cache (usually immediately before compilation.)
 
     Returns true if the cache was hit successfully, indicating that no compilation is necessary.
@@ -790,7 +801,6 @@ def check_cache(self, vts, counter):
Expand All @@ -790,7 +801,6 @@ def check_cache(self, vts, counter):
         'Cache returned unexpected target: {} vs {}'.format(cached_vts, [vts])
       )
       self.context.log.info('Hit cache during double check for {}'.format(vts.target.address.spec))
-      counter()
       return True
 
@@ -916,13 +926,18 @@ def _get_jvm_distribution(self):
       self.HERMETIC: lambda: self._HermeticDistribution('.jdk', local_distribution),
     })()
 
+  def _default_double_check_cache_for_vts(self, vts):
+    # Double check the cache before beginning compilation
+    if self.check_cache(vts):
+      vts.update()
+
   def _default_work_for_vts(self, vts, ctx, input_classpath_product_key, counter, all_compile_contexts, output_classpath_product):
     progress_message = ctx.target.address.spec
 
-    # Double check the cache before beginning compilation
-    hit_cache = self.check_cache(vts, counter)
-
-    if not hit_cache:
+    # See whether the cache-doublecheck job hit the cache: if so, noop: otherwise, compile.
+    if vts.valid:
+      counter()
+    else:
       # Compute the compile classpath for this target.
       dependency_cp_entries = self._zinc.compile_classpath_entries(
         input_classpath_product_key,
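
The handoff between the two jobs above is the crux of the jvm_compile.py change: the dedicated job calls check_cache and, on a hit, marks the versioned target set valid via vts.update(); the compile job then consults only vts.valid and never talks to the cache itself. A minimal sketch of that handoff with stand-in objects (Pants' real VersionedTargetSet and artifact cache APIs are richer than this):

class FakeVts:
  # Stand-in for a versioned target set: just tracks validity.
  def __init__(self):
    self.valid = False

  def update(self):
    self.valid = True

def double_check_cache_for_vts(vts, cache_hit):
  # The dedicated cache job: query the cache once, record a hit on the vts.
  if cache_hit:
    vts.update()

def work_for_vts(vts, counter):
  # The compile job: consults only vts.valid -- never the cache itself.
  if vts.valid:
    counter()  # count the target as done without compiling
  else:
    print('compiling')  # real compilation would happen here
    counter()

counts = []
vts = FakeVts()
double_check_cache_for_vts(vts, cache_hit=True)
work_for_vts(vts, lambda: counts.append(1))  # noop: the cache job validated the vts
assert vts.valid and len(counts) == 1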
69 changes: 31 additions & 38 deletions src/python/pants/backend/jvm/tasks/jvm_compile/rsc/rsc_compile.py
@@ -292,25 +292,6 @@ def _zinc_key_for_target(self, target, workflow):
   def _write_to_cache_key_for_target(self, target):
     return 'write_to_cache({})'.format(target.address.spec)
 
-  def _check_cache_before_work(self, work_str, vts, ctx, counter, debug = False, work_fn = lambda: None):
-    hit_cache = self.check_cache(vts, counter)
-
-    if not hit_cache:
-      counter_val = str(counter()).rjust(counter.format_length(), ' ')
-      counter_str = '[{}/{}] '.format(counter_val, counter.size)
-      log_fn = self.context.log.debug if debug else self.context.log.info
-      log_fn(
-        counter_str,
-        f'{work_str} ',
-        items_to_report_element(ctx.sources, '{} source'.format(self.name())),
-        ' in ',
-        items_to_report_element([t.address.reference() for t in vts.targets], 'target'),
-        ' (',
-        ctx.target.address.spec,
-        ').')
-
-      work_fn()
-
   def create_compile_jobs(self,
                           compile_target,
                           compile_contexts,
@@ -323,7 +304,19 @@ def work_for_vts_rsc(vts, ctx):
       target = ctx.target
       tgt, = vts.targets
 
-      def work_fn():
+      # If we didn't hit the cache in the cache job, run rsc.
+      if not vts.valid:
+        counter_val = str(counter()).rjust(counter.format_length(), ' ')
+        counter_str = '[{}/{}] '.format(counter_val, counter.size)
+        self.context.log.info(
+          counter_str,
+          'Rsc-ing ',
+          items_to_report_element(ctx.sources, '{} source'.format(self.name())),
+          ' in ',
+          items_to_report_element([t.address.reference() for t in vts.targets], 'target'),
+          ' (',
+          ctx.target.address.spec,
+          ').')
         # This does the following
         # - Collect the rsc classpath elements, including zinc compiles of rsc incompatible targets
         #   and rsc compiles of rsc compatible targets.
Expand Down Expand Up @@ -391,15 +384,9 @@ def nonhermetic_digest_classpath():
           'rsc'
         )
 
-      # Double check the cache before beginning compilation
-      self._check_cache_before_work('Rsc-ing', vts, ctx, counter, work_fn=work_fn)
-
       # Update the products with the latest classes.
       self.register_extra_products_from_contexts([ctx.target], compile_contexts)
 
-    def work_for_vts_write_to_cache(vts, ctx):
-      self._check_cache_before_work('Writing to cache for', vts, ctx, counter, debug=True)
-
     ### Create Jobs for ExecutionGraph
     rsc_jobs = []
     zinc_jobs = []
@@ -420,6 +407,13 @@ def all_zinc_rsc_invalid_dep_keys(invalid_deps):
         # Rely on the results of zinc compiles for zinc-compatible targets
         yield self._key_for_target_as_dep(tgt, tgt_rsc_cc.workflow)
 
+    # As in JvmCompile.create_compile_jobs, we create a cache-double-check job that all "real" work
+    # depends on. It depends on completion of the same dependencies as the rsc job in order to run
+    # as late as possible, while still running before rsc or zinc.
+    double_check_cache_job = Job(self.exec_graph_double_check_cache_key_for_target(compile_target),
+                                 functools.partial(self._default_double_check_cache_for_vts, ivts),
+                                 dependencies=list(all_zinc_rsc_invalid_dep_keys(invalid_dependencies)))
+
     def make_rsc_job(target, dep_targets):
       return Job(
         key=self._rsc_key_for_target(target),
@@ -432,7 +426,7 @@ def make_rsc_job(target, dep_targets):
         ),
         # The rsc jobs depend on other rsc jobs, and on zinc jobs for targets that are not
         # processed by rsc.
-        dependencies=list(all_zinc_rsc_invalid_dep_keys(dep_targets)),
+        dependencies=[double_check_cache_job.key] + list(all_zinc_rsc_invalid_dep_keys(dep_targets)),
         size=self._size_estimator(rsc_compile_context.sources),
       )
 
@@ -453,7 +447,7 @@ def make_zinc_job(target, input_product_key, output_products, dep_keys):
           counter,
           compile_contexts,
           CompositeProductAdder(*output_products)),
-        dependencies=list(dep_keys),
+        dependencies=[double_check_cache_job.key] + list(dep_keys),
         size=self._size_estimator(zinc_compile_context.sources),
       )
 
@@ -520,24 +514,23 @@ def record(k, v):
     })()
 
     all_jobs = rsc_jobs + zinc_jobs
+    real_job_count = len(all_jobs)
 
-    if all_jobs:
+    if real_job_count > 0:
+      # Create a job that depends on all real work having completed, and that eagerly writes to
+      # the cache by calling `vts.update()`.
       write_to_cache_job = Job(
         key=self._write_to_cache_key_for_target(compile_target),
-        fn=functools.partial(
-          work_for_vts_write_to_cache,
-          ivts,
-          rsc_compile_context,
-        ),
+        fn=ivts.update,
         dependencies=[job.key for job in all_jobs],
         run_asap=True,
-        # If compilation and analysis work succeeds, validate the vts.
-        # Otherwise, fail it.
-        on_success=ivts.update,
         on_failure=ivts.force_invalidate)
       all_jobs.append(write_to_cache_job)
+      # And, since there is work to do, record the double_check_cache_job to check the cache
+      # immediately before that work.
+      all_jobs.append(double_check_cache_job)
 
-    return all_jobs
+    return (all_jobs, real_job_count)
 
 class RscZincMergedCompileContexts(datatype([
     ('rsc_cc', RscCompileContext),
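
Putting the rsc_compile.py changes together: the graph for a single target now has one shared cache-check node feeding both the rsc and zinc jobs, with the eager write_to_cache job after them. For a target t whose only invalid dependency is d, the dependency edges look roughly like this (the key strings here are illustrative, not the exact formats produced by the *_key_for_target helpers):

job_dependencies = {
  # Same dependencies as the rsc job, so the check runs as late as possible
  # while still preceding any rsc or zinc work for t.
  'double_check_cache(t)': ['rsc(d)'],
  'rsc(t)': ['double_check_cache(t)', 'rsc(d)'],
  'zinc(t)': ['double_check_cache(t)', 'rsc(t)'],
  # Depends on all real work; runs asap and eagerly writes to the cache via
  # vts.update(), or force-invalidates the vts if any of that work failed.
  'write_to_cache(t)': ['rsc(t)', 'zinc(t)'],
}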
