Split out a double-check-cache job for jvm/rsc compile #8221

Merged
61 changes: 38 additions & 23 deletions src/python/pants/backend/jvm/tasks/jvm_compile/jvm_compile.py
@@ -711,53 +711,64 @@ def _upstream_analysis(self, compile_contexts, classpath_entries):
else:
yield compile_context.classes_dir.path, compile_context.analysis_file

def exec_graph_double_check_cache_key_for_target(self, target):
return 'double_check_cache({})'.format(target.address.spec)

def exec_graph_key_for_target(self, compile_target):
return "compile({})".format(compile_target.address.spec)

def _create_compile_jobs(self, compile_contexts, invalid_targets, invalid_vts, classpath_product):
class Counter:
def __init__(self, size, initial=0):
def __init__(self, size=0):
self.size = size
self.count = initial
self.count = 0

def __call__(self):
self.count += 1
return self.count

def increment_size(self, by=1):
Contributor:
Can this be folded into the __call__ function above?

Member Author:
Not without an awkward signature I think: this one increments size, while the other increments count.
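For illustration, a minimal standalone sketch of the Counter as changed in this diff (names taken from the lines above; the trailing usage lines are hypothetical), showing that `__call__` advances the completed count while `increment_size` grows the expected total:

```python
class Counter:
    """Tracks progress as "count completed out of size expected"."""

    def __init__(self, size=0):
        self.size = size
        self.count = 0

    def __call__(self):
        # Mark one unit of work as completed and return the new count.
        self.count += 1
        return self.count

    def increment_size(self, by=1):
        # Grow the total amount of expected work as jobs are discovered.
        self.size += by

    def format_length(self):
        # Width needed to right-align the count against the final size.
        return len(str(self.size))


counter = Counter()
counter.increment_size(by=3)                       # three jobs discovered
print('[{}/{}]'.format(counter(), counter.size))   # -> [1/3]
```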

self.size += by

def format_length(self):
return len(str(self.size))
counter = Counter(len(invalid_vts))

jobs = []
counter = Counter()

jobs.extend(self.pre_compile_jobs(counter))
invalid_target_set = set(invalid_targets)
for ivts in invalid_vts:
# Invalidated targets are a subset of relevant targets: get the context for this one.
compile_target = ivts.target
invalid_dependencies = self._collect_invalid_compile_dependencies(compile_target,
invalid_target_set)

jobs.extend(
self.create_compile_jobs(compile_target, compile_contexts, invalid_dependencies, ivts,
counter, classpath_product))
new_jobs, new_count = self.create_compile_jobs(
compile_target, compile_contexts, invalid_dependencies, ivts, counter, classpath_product)
jobs.extend(new_jobs)
counter.increment_size(by=new_count)

counter.size = len(jobs)
return jobs

def pre_compile_jobs(self, counter):
"""Override this to provide jobs that are not related to particular targets.

This is only called when there are invalid targets."""
return []

def create_compile_jobs(self, compile_target, all_compile_contexts, invalid_dependencies, ivts,
counter, classpath_product):
"""Return a list of jobs, and a count of those jobs that represent meaningful ("countable") work."""

context_for_target = all_compile_contexts[compile_target]
compile_context = self.select_runtime_context(context_for_target)

job = Job(self.exec_graph_key_for_target(compile_target),
compile_deps = [self.exec_graph_key_for_target(target) for target in invalid_dependencies]

# The cache checking job doesn't technically have any dependencies, but we want to delay it
# until immediately before we would otherwise try compiling, so we indicate that it depends on
# all compile dependencies.
double_check_cache_job = Job(self.exec_graph_double_check_cache_key_for_target(compile_target),
functools.partial(self._default_double_check_cache_for_vts, ivts),
compile_deps)
# The compile job depends on the cache check job. This decomposition is necessary in order to
# support more complex situations where compilation runs multiple jobs in parallel, and wants to
# double check the cache before starting any of them.
compile_job = Job(self.exec_graph_key_for_target(compile_target),
functools.partial(
self._default_work_for_vts,
ivts,
@@ -766,15 +777,15 @@ def create_compile_jobs(self, compile_target, all_compile_contexts, invalid_depe
counter,
all_compile_contexts,
classpath_product),
[self.exec_graph_key_for_target(target) for target in invalid_dependencies],
[double_check_cache_job.key] + compile_deps,
self._size_estimator(compile_context.sources),
# If compilation and analysis work succeeds, validate the vts.
# Otherwise, fail it.
on_success=ivts.update,
on_failure=ivts.force_invalidate)
return [job]
return ([double_check_cache_job, compile_job], 1)

def check_cache(self, vts, counter):
def check_cache(self, vts):
"""Manually checks the artifact cache (usually immediately before compilation.)

Returns true if the cache was hit successfully, indicating that no compilation is necessary.
Expand All @@ -790,7 +801,6 @@ def check_cache(self, vts, counter):
'Cache returned unexpected target: {} vs {}'.format(cached_vts, [vts])
)
self.context.log.info('Hit cache during double check for {}'.format(vts.target.address.spec))
counter()
return True

def should_compile_incrementally(self, vts, ctx):
@@ -916,13 +926,18 @@ def _get_jvm_distribution(self):
self.HERMETIC: lambda: self._HermeticDistribution('.jdk', local_distribution),
})()

def _default_double_check_cache_for_vts(self, vts):
# Double check the cache before beginning compilation
if self.check_cache(vts):
vts.update()

def _default_work_for_vts(self, vts, ctx, input_classpath_product_key, counter, all_compile_contexts, output_classpath_product):
progress_message = ctx.target.address.spec

# Double check the cache before beginning compilation
hit_cache = self.check_cache(vts, counter)

if not hit_cache:
# See whether the cache-doublecheck job hit the cache: if so, noop: otherwise, compile.
if vts.valid:
counter()
else:
# Compute the compile classpath for this target.
dependency_cp_entries = self._zinc.compile_classpath_entries(
input_classpath_product_key,
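To summarize the new control flow in jvm_compile.py, here is a hedged, self-contained sketch of the two-job decomposition (the `Job` class and helper names below are simplified stand-ins for the real ExecutionGraph machinery, not its actual API): the cache double-check job depends on the same keys as the compile job so it runs as late as possible, and the compile job no-ops when the check has already validated the target.

```python
import functools


class Job:
    """Illustrative stand-in for the execution-graph Job used in this diff."""
    def __init__(self, key, fn, dependencies, on_success=None, on_failure=None):
        self.key = key
        self.fn = fn
        self.dependencies = dependencies
        self.on_success = on_success
        self.on_failure = on_failure


def double_check_cache(vts, check_cache):
    # Runs immediately before compilation would start: if the artifact cache
    # now has a hit, mark the target valid so the compile job can no-op.
    if check_cache(vts):
        vts.update()


def compile_if_needed(vts, do_compile, counter):
    # The compile job keys off vts.valid, which the cache job may have set.
    if vts.valid:
        counter()          # count the target as done without compiling
    else:
        do_compile(vts)


def make_jobs_for_target(spec, vts, compile_deps, check_cache, do_compile, counter):
    cache_job = Job(
        key='double_check_cache({})'.format(spec),
        fn=functools.partial(double_check_cache, vts, check_cache),
        # No real inputs, but depend on the compile deps so it runs as late as possible.
        dependencies=list(compile_deps))
    compile_job = Job(
        key='compile({})'.format(spec),
        fn=functools.partial(compile_if_needed, vts, do_compile, counter),
        # Compilation waits for the cache check as well as its real dependencies.
        dependencies=[cache_job.key] + list(compile_deps),
        on_success=vts.update,
        on_failure=vts.force_invalidate)
    # Two jobs are returned, but only the compile job counts as "real" work.
    return [cache_job, compile_job], 1
```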
99 changes: 52 additions & 47 deletions src/python/pants/backend/jvm/tasks/jvm_compile/rsc/rsc_compile.py
@@ -292,25 +292,6 @@ def _zinc_key_for_target(self, target, workflow):
def _write_to_cache_key_for_target(self, target):
return 'write_to_cache({})'.format(target.address.spec)

def _check_cache_before_work(self, work_str, vts, ctx, counter, debug = False, work_fn = lambda: None):
hit_cache = self.check_cache(vts, counter)

if not hit_cache:
counter_val = str(counter()).rjust(counter.format_length(), ' ')
counter_str = '[{}/{}] '.format(counter_val, counter.size)
log_fn = self.context.log.debug if debug else self.context.log.info
log_fn(
counter_str,
f'{work_str} ',
items_to_report_element(ctx.sources, '{} source'.format(self.name())),
' in ',
items_to_report_element([t.address.reference() for t in vts.targets], 'target'),
' (',
ctx.target.address.spec,
').')

work_fn()

def create_compile_jobs(self,
compile_target,
compile_contexts,
@@ -323,7 +304,19 @@ def work_for_vts_rsc(vts, ctx):
target = ctx.target
tgt, = vts.targets

def work_fn():
# If we didn't hit the cache in the cache job, run rsc.
if not vts.valid:
counter_val = str(counter()).rjust(counter.format_length(), ' ')
counter_str = '[{}/{}] '.format(counter_val, counter.size)
self.context.log.info(
counter_str,
'Rsc-ing ',
items_to_report_element(ctx.sources, '{} source'.format(self.name())),
' in ',
items_to_report_element([t.address.reference() for t in vts.targets], 'target'),
' (',
ctx.target.address.spec,
').')
# This does the following
# - Collect the rsc classpath elements, including zinc compiles of rsc incompatible targets
# and rsc compiles of rsc compatible targets.
@@ -391,16 +384,11 @@ def nonhermetic_digest_classpath():
'rsc'
)

# Double check the cache before beginning compilation
self._check_cache_before_work('Rsc-ing', vts, ctx, counter, work_fn=work_fn)

# Update the products with the latest classes.
self.register_extra_products_from_contexts([ctx.target], compile_contexts)

def work_for_vts_write_to_cache(vts, ctx):
self._check_cache_before_work('Writing to cache for', vts, ctx, counter, debug=True)

### Create Jobs for ExecutionGraph
cache_doublecheck_jobs = []
rsc_jobs = []
zinc_jobs = []

@@ -410,6 +398,8 @@ def work_for_vts_write_to_cache(vts, ctx):
rsc_compile_context = merged_compile_context.rsc_cc
zinc_compile_context = merged_compile_context.zinc_cc

cache_doublecheck_key = self.exec_graph_double_check_cache_key_for_target(compile_target)

def all_zinc_rsc_invalid_dep_keys(invalid_deps):
"""Get the rsc key for an rsc-and-zinc target, or the zinc key for a zinc-only target."""
for tgt in invalid_deps:
@@ -420,6 +410,14 @@ def all_zinc_rsc_invalid_dep_keys(invalid_deps):
# Rely on the results of zinc compiles for zinc-compatible targets
yield self._key_for_target_as_dep(tgt, tgt_rsc_cc.workflow)

def make_cache_doublecheck_job(dep_keys):
# As in JvmCompile.create_compile_jobs, we create a cache-double-check job that all "real" work
# depends on. It depends on completion of the same dependencies as the rsc job in order to run
# as late as possible, while still running before rsc or zinc.
return Job(cache_doublecheck_key,
functools.partial(self._default_double_check_cache_for_vts, ivts),
dependencies=list(dep_keys))

def make_rsc_job(target, dep_targets):
return Job(
key=self._rsc_key_for_target(target),
@@ -432,7 +430,7 @@ def make_rsc_job(target, dep_targets):
),
# The rsc jobs depend on other rsc jobs, and on zinc jobs for targets that are not
# processed by rsc.
dependencies=list(all_zinc_rsc_invalid_dep_keys(dep_targets)),
dependencies=[cache_doublecheck_key] + list(all_zinc_rsc_invalid_dep_keys(dep_targets)),
size=self._size_estimator(rsc_compile_context.sources),
)

@@ -453,7 +451,7 @@ def make_zinc_job(target, input_product_key, output_products, dep_keys):
counter,
compile_contexts,
CompositeProductAdder(*output_products)),
dependencies=list(dep_keys),
dependencies=[cache_doublecheck_key] + list(dep_keys),
size=self._size_estimator(zinc_compile_context.sources),
)

@@ -470,6 +468,19 @@ def record(k, v):
record('workflow', workflow.value)
record('execution_strategy', self.execution_strategy)

# Create the cache doublecheck job.
workflow.resolve_for_enum_variant({
'zinc-only': lambda: cache_doublecheck_jobs.append(
make_cache_doublecheck_job(list(all_zinc_rsc_invalid_dep_keys(invalid_dependencies)))
),
'zinc-java': lambda: cache_doublecheck_jobs.append(
make_cache_doublecheck_job(list(only_zinc_invalid_dep_keys(invalid_dependencies)))
),
'rsc-and-zinc': lambda: cache_doublecheck_jobs.append(
make_cache_doublecheck_job(list(all_zinc_rsc_invalid_dep_keys(invalid_dependencies)))
),
})()

# Create the rsc job.
# Currently, rsc only supports outlining scala.
workflow.resolve_for_enum_variant({
@@ -519,25 +530,19 @@ def record(k, v):
)),
})()

all_jobs = rsc_jobs + zinc_jobs

if all_jobs:
write_to_cache_job = Job(
key=self._write_to_cache_key_for_target(compile_target),
fn=functools.partial(
work_for_vts_write_to_cache,
ivts,
rsc_compile_context,
),
dependencies=[job.key for job in all_jobs],
run_asap=True,
# If compilation and analysis work succeeds, validate the vts.
# Otherwise, fail it.
on_success=ivts.update,
on_failure=ivts.force_invalidate)
all_jobs.append(write_to_cache_job)

return all_jobs
compile_jobs = rsc_jobs + zinc_jobs

# Create a job that depends on all real work having completed that will eagerly write to the
# cache by calling `vt.update()`.
write_to_cache_job = Job(
key=self._write_to_cache_key_for_target(compile_target),
fn=ivts.update,
dependencies=[job.key for job in compile_jobs],
run_asap=True,
on_failure=ivts.force_invalidate)

all_jobs = cache_doublecheck_jobs + rsc_jobs + zinc_jobs + [write_to_cache_job]
return (all_jobs, len(compile_jobs))
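For the rsc/zinc case, a hedged sketch of the dependency edges this method now builds for a single rsc-and-zinc target (the rsc/zinc key formats and the example spec below are illustrative, and the exact dependency key sets follow the workflow resolution above): the cache-doublecheck job gates all real work, and the write-to-cache job waits for all of it.

```python
def job_edges_for_rsc_and_zinc_target(spec, rsc_dep_keys, zinc_dep_keys):
    """Return {job_key: [dependency_keys]} for one rsc-and-zinc target."""
    cache_key = 'double_check_cache({})'.format(spec)
    rsc_key = 'rsc({})'.format(spec)      # illustrative stand-in for _rsc_key_for_target
    zinc_key = 'zinc({})'.format(spec)    # illustrative stand-in for _zinc_key_for_target
    write_key = 'write_to_cache({})'.format(spec)
    return {
        # Depends only on upstream work, so it runs as late as possible but
        # still before any real work for this target.
        cache_key: list(rsc_dep_keys),
        # Both real jobs additionally wait for the cache double-check.
        rsc_key: [cache_key] + list(rsc_dep_keys),
        zinc_key: [cache_key] + list(zinc_dep_keys),
        # Runs eagerly (run_asap) once all real work for the target is done,
        # writing to the cache via ivts.update().
        write_key: [rsc_key, zinc_key],
    }


edges = job_edges_for_rsc_and_zinc_target(
    'example/scala:lib',                       # hypothetical target spec
    rsc_dep_keys=['rsc(example/scala:dep)'],   # hypothetical dependency keys
    zinc_dep_keys=['rsc(example/scala:dep)'])
for key, deps in edges.items():
    print(key, '<-', deps)
```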

class RscZincMergedCompileContexts(datatype([
('rsc_cc', RscCompileContext),