
Commit

[BEAM-14334] Fix leakage of SparkContext in Spark runner tests to remove forkEvery 1 (apache#17406)

* [BEAM-14334] Fix leakage of SparkContext in Spark runner tests to remove forkEvery 1 and set provided SparkContext via SparkContextFactory to avoid losing it during a serde roundtrip in TestPipeline.
Moritz Mack authored May 12, 2022
1 parent a167424 commit a6ee885
Showing 20 changed files with 474 additions and 372 deletions.
14 changes: 3 additions & 11 deletions runners/spark/spark_runner.gradle
@@ -94,7 +94,6 @@ if (copySourceBase) {
}

test {
systemProperty "beam.spark.test.reuseSparkContext", "true"
systemProperty "spark.sql.shuffle.partitions", "4"
systemProperty "spark.ui.enabled", "false"
systemProperty "spark.ui.showConsoleProgress", "false"
@@ -113,17 +112,14 @@ test {
jvmArgs System.getProperty("beamSurefireArgline")
}

// Only one SparkContext may be running in a JVM (SPARK-2243)
forkEvery 1
maxParallelForks 4
useJUnit {
excludeCategories "org.apache.beam.runners.spark.StreamingTest"
excludeCategories "org.apache.beam.runners.spark.UsesCheckpointRecovery"
}
filter {
// BEAM-11653 MetricsSinkTest is failing with Spark 3
excludeTestsMatching 'org.apache.beam.runners.spark.aggregators.metrics.sink.SparkMetricsSinkTest'
}

// easily re-run all tests (to deal with flaky tests / SparkContext leaks)
if(project.hasProperty("rerun-tests")) { outputs.upToDateWhen {false} }
}

dependencies {
@@ -291,10 +287,6 @@ def validatesRunnerStreaming = tasks.register("validatesRunnerStreaming", Test)
useJUnit {
includeCategories 'org.apache.beam.runners.spark.StreamingTest'
}
filter {
// BEAM-11653 MetricsSinkTest is failing with Spark 3
excludeTestsMatching 'org.apache.beam.runners.spark.aggregators.metrics.sink.SparkMetricsSinkTest'
}
}

tasks.register("validatesStructuredStreamingRunnerBatch", Test) {
org/apache/beam/runners/spark/SparkContextOptions.java
@@ -37,6 +37,12 @@
* which link to Spark dependencies, won't be scanned by {@link PipelineOptions} reflective
* instantiation. Note that {@link SparkContextOptions} is not registered with {@link
* SparkRunnerRegistrar}.
*
 * <p>Note: For testing, it is recommended to use {@link
 * org.apache.beam.runners.spark.translation.SparkContextFactory#setProvidedSparkContext(JavaSparkContext)}
 * instead of {@link SparkContextOptions#setProvidedSparkContext(JavaSparkContext)}. When using
 * {@link org.apache.beam.sdk.testing.TestPipeline}, any {@link JavaSparkContext} provided via
 * {@link SparkContextOptions} is dropped.
*/
public interface SparkContextOptions extends SparkPipelineOptions {

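For reference, a minimal sketch of the setup the note above recommends. The fixture class, master URL, and app name are illustrative assumptions; setProvidedSparkContext, clearProvidedSparkContext, and setUsesProvidedSparkContext are the calls added or referenced by this change.

import org.apache.beam.runners.spark.SparkContextOptions;
import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

// Hypothetical fixture; only the SparkContextFactory/SparkContextOptions calls come from this change.
public class ProvidedContextSetupSketch {

  public static void main(String[] args) {
    JavaSparkContext jsc =
        new JavaSparkContext(new SparkConf().setMaster("local[2]").setAppName("beam-test"));
    try {
      // Register the externally managed context with the factory so it survives the
      // options serde roundtrip that TestPipeline performs.
      SparkContextFactory.setProvidedSparkContext(jsc);

      SparkContextOptions options = PipelineOptionsFactory.as(SparkContextOptions.class);
      // The factory only hands out the provided context when this flag is set.
      options.setUsesProvidedSparkContext(true);

      // ... create a (Test)Pipeline with these options and run it ...
    } finally {
      SparkContextFactory.clearProvidedSparkContext();
      jsc.stop();
    }
  }
}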
org/apache/beam/runners/spark/SparkRunnerDebugger.java
@@ -21,6 +21,7 @@
import java.util.concurrent.TimeoutException;
import org.apache.beam.runners.core.construction.SplittableParDo;
import org.apache.beam.runners.spark.translation.EvaluationContext;
import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
import org.apache.beam.runners.spark.translation.TransformTranslator;
import org.apache.beam.runners.spark.translation.streaming.StreamingTransformTranslator;
@@ -86,7 +87,8 @@ public SparkPipelineResult run(Pipeline pipeline) {
SplittableParDo.convertReadBasedSplittableDoFnsToPrimitiveReadsIfNecessary(pipeline);
}

JavaSparkContext jsc = new JavaSparkContext("local[1]", "Debug_Pipeline");
JavaSparkContext jsc =
SparkContextFactory.getSparkContext(pipeline.getOptions().as(SparkPipelineOptions.class));
JavaStreamingContext jssc =
new JavaStreamingContext(jsc, new org.apache.spark.streaming.Duration(1000));

@@ -107,7 +109,7 @@ public SparkPipelineResult run(Pipeline pipeline) {

pipeline.traverseTopologically(visitor);

jsc.stop();
SparkContextFactory.stopSparkContext(jsc);

String debugString = visitor.getDebugString();
LOG.info("Translated Native Spark pipeline:\n" + debugString);
org/apache/beam/runners/spark/translation/SparkContextFactory.java
@@ -17,6 +17,9 @@
*/
package org.apache.beam.runners.spark.translation;

import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;

import javax.annotation.Nullable;
import org.apache.beam.runners.spark.SparkContextOptions;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.coders.SparkRunnerKryoRegistrator;
@@ -25,80 +28,121 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** The Spark context factory. */
@SuppressWarnings({
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public final class SparkContextFactory {
private static final Logger LOG = LoggerFactory.getLogger(SparkContextFactory.class);

/**
* If the property {@code beam.spark.test.reuseSparkContext} is set to {@code true} then the Spark
* context will be reused for beam pipelines. This property should only be enabled for tests.
*
 * @deprecated This will leak your SparkContext; any attempt to create a new SparkContext later
* will fail. Please use {@link #setProvidedSparkContext(JavaSparkContext)} / {@link
* #clearProvidedSparkContext()} instead to properly control the lifecycle of your context.
* Alternatively you may also provide a SparkContext using {@link
* SparkContextOptions#setUsesProvidedSparkContext(boolean)} together with {@link
* SparkContextOptions#setProvidedSparkContext(JavaSparkContext)} and close that one
* appropriately. Tests of this module should use {@code SparkContextRule}.
*/
@Deprecated
public static final String TEST_REUSE_SPARK_CONTEXT = "beam.spark.test.reuseSparkContext";

// Spark allows only one context for JVM so this can be static.
private static JavaSparkContext sparkContext;
private static String sparkMaster;
private static boolean usesProvidedSparkContext;
private static @Nullable JavaSparkContext sparkContext;

// Remember spark master if TEST_REUSE_SPARK_CONTEXT is enabled.
private static @Nullable String reusableSparkMaster;

// SparkContext is provided by the user instead of simply reused using TEST_REUSE_SPARK_CONTEXT
private static boolean hasProvidedSparkContext;

private SparkContextFactory() {}

/**
* Set an externally managed {@link JavaSparkContext} that will be used if {@link
* SparkContextOptions#getUsesProvidedSparkContext()} is set to {@code true}.
*
* <p>A Spark context can also be provided using {@link
* SparkContextOptions#setProvidedSparkContext(JavaSparkContext)}. However, it will be dropped
 * during serialization, potentially leading to confusing behavior. This is particularly the case
* when used in tests with {@link org.apache.beam.sdk.testing.TestPipeline}.
*/
public static synchronized void setProvidedSparkContext(JavaSparkContext providedSparkContext) {
sparkContext = checkNotNull(providedSparkContext);
hasProvidedSparkContext = true;
reusableSparkMaster = null;
}

public static synchronized void clearProvidedSparkContext() {
hasProvidedSparkContext = false;
sparkContext = null;
}

public static synchronized JavaSparkContext getSparkContext(SparkPipelineOptions options) {
SparkContextOptions contextOptions = options.as(SparkContextOptions.class);
usesProvidedSparkContext = contextOptions.getUsesProvidedSparkContext();
// reuse should be ignored if the context is provided.
if (Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT) && !usesProvidedSparkContext) {

// if the context is null or stopped for some reason, re-create it.
if (sparkContext == null || sparkContext.sc().isStopped()) {
sparkContext = createSparkContext(contextOptions);
sparkMaster = options.getSparkMaster();
} else if (!options.getSparkMaster().equals(sparkMaster)) {
throw new IllegalArgumentException(
if (contextOptions.getUsesProvidedSparkContext()) {
JavaSparkContext jsc = contextOptions.getProvidedSparkContext();
if (jsc != null) {
setProvidedSparkContext(jsc);
} else if (hasProvidedSparkContext) {
jsc = sparkContext;
}
if (jsc == null) {
throw new IllegalStateException(
"No Spark context was provided. Use SparkContextFactor.setProvidedSparkContext to do so.");
} else if (jsc.sc().isStopped()) {
LOG.error("The provided Spark context " + jsc + " was already stopped.");
throw new IllegalStateException("The provided Spark context was already stopped");
}
LOG.info("Using a provided Spark Context");
return jsc;
} else if (Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT)) {
// This is highly discouraged as it leaks the SparkContext without any way to close it.
// Attempting to create any new SparkContext later will fail.
// If the context is null or stopped for some reason, re-create it.
@Nullable JavaSparkContext jsc = sparkContext;
if (jsc == null || jsc.sc().isStopped()) {
sparkContext = jsc = createSparkContext(contextOptions);
reusableSparkMaster = options.getSparkMaster();
hasProvidedSparkContext = false;
} else if (hasProvidedSparkContext) {
throw new IllegalStateException(
"Usage of provided Spark context is disabled in SparkPipelineOptions.");
} else if (!options.getSparkMaster().equals(reusableSparkMaster)) {
throw new IllegalStateException(
String.format(
"Cannot reuse spark context "
+ "with different spark master URL. Existing: %s, requested: %s.",
sparkMaster, options.getSparkMaster()));
reusableSparkMaster, options.getSparkMaster()));
}
return sparkContext;
return jsc;
} else {
return createSparkContext(contextOptions);
JavaSparkContext jsc = createSparkContext(contextOptions);
clearProvidedSparkContext(); // any provided context can't be valid anymore
return jsc;
}
}

public static synchronized void stopSparkContext(JavaSparkContext context) {
if (!Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT) && !usesProvidedSparkContext) {
if (!Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT) && !hasProvidedSparkContext) {
context.stop();
}
}

private static JavaSparkContext createSparkContext(SparkContextOptions contextOptions) {
if (usesProvidedSparkContext) {
LOG.info("Using a provided Spark Context");
JavaSparkContext jsc = contextOptions.getProvidedSparkContext();
if (jsc == null || jsc.sc().isStopped()) {
LOG.error("The provided Spark context " + jsc + " was not created or was stopped");
throw new RuntimeException("The provided Spark context was not created or was stopped");
}
return jsc;
} else {
LOG.info("Creating a brand new Spark Context.");
SparkConf conf = new SparkConf();
if (!conf.contains("spark.master")) {
// set master if not set.
conf.setMaster(contextOptions.getSparkMaster());
}

if (contextOptions.getFilesToStage() != null && !contextOptions.getFilesToStage().isEmpty()) {
conf.setJars(contextOptions.getFilesToStage().toArray(new String[0]));
}
private static JavaSparkContext createSparkContext(SparkPipelineOptions options) {
LOG.info("Creating a brand new Spark Context.");
SparkConf conf = new SparkConf();
if (!conf.contains("spark.master")) {
// set master if not set.
conf.setMaster(options.getSparkMaster());
}

conf.setAppName(contextOptions.getAppName());
// register immutable collections serializers because the SDK uses them.
conf.set("spark.kryo.registrator", SparkRunnerKryoRegistrator.class.getName());
return new JavaSparkContext(conf);
if (options.getFilesToStage() != null && !options.getFilesToStage().isEmpty()) {
conf.setJars(options.getFilesToStage().toArray(new String[0]));
}

conf.setAppName(options.getAppName());
// register immutable collections serializers because the SDK uses them.
conf.set("spark.kryo.registrator", SparkRunnerKryoRegistrator.class.getName());
return new JavaSparkContext(conf);
}
}
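To make the factory's new contract concrete, here is a hedged sketch of the intended lifecycle: callers obtain the context from getSparkContext and hand it back through stopSparkContext, which only stops contexts the factory created itself. The options wiring and master URL are assumptions.

import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.spark.api.java.JavaSparkContext;

public class ContextLifecycleSketch {

  public static void main(String[] args) {
    SparkPipelineOptions options = PipelineOptionsFactory.as(SparkPipelineOptions.class);
    options.setSparkMaster("local[2]"); // illustrative master URL

    // Depending on the options this returns the provided context, the (deprecated)
    // reusable test context, or a brand new one.
    JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
    try {
      // ... translate and run a pipeline on jsc ...
    } finally {
      // A no-op for provided or reused contexts; stops only factory-created ones.
      SparkContextFactory.stopSparkContext(jsc);
    }
  }
}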
org/apache/beam/runners/spark/CacheTest.java
@@ -25,11 +25,9 @@
import java.util.List;
import org.apache.beam.runners.spark.translation.Dataset;
import org.apache.beam.runners.spark.translation.EvaluationContext;
import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.runners.spark.translation.TransformTranslator;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Create.Values;
@@ -39,7 +37,7 @@
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.ClassRule;
import org.junit.Test;

/** Tests of {@link Dataset#cache(String, Coder)}} scenarios. */
@@ -48,13 +46,15 @@
})
public class CacheTest {

@ClassRule public static SparkContextRule contextRule = new SparkContextRule();

/**
* Test checks how the cache candidates map is populated by the runner when evaluating the
* pipeline.
*/
@Test
public void cacheCandidatesUpdaterTest() {
SparkPipelineOptions options = createOptions();
SparkPipelineOptions options = contextRule.createPipelineOptions();
Pipeline pipeline = Pipeline.create(options);
PCollection<String> pCollection = pipeline.apply(Create.of("foo", "bar"));

@@ -80,8 +80,8 @@ public void processElement(ProcessContext processContext) {
})
.withSideInputs(view));

JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
EvaluationContext ctxt = new EvaluationContext(jsc, pipeline, options);
EvaluationContext ctxt =
new EvaluationContext(contextRule.getSparkContext(), pipeline, options);
SparkRunner.CacheVisitor cacheVisitor =
new SparkRunner.CacheVisitor(new TransformTranslator.Translator(), ctxt);
pipeline.traverseTopologically(cacheVisitor);
@@ -91,15 +91,15 @@ public void processElement(ProcessContext processContext) {

@Test
public void shouldCacheTest() {
SparkPipelineOptions options = createOptions();
SparkPipelineOptions options = contextRule.createPipelineOptions();
options.setCacheDisabled(true);
Pipeline pipeline = Pipeline.create(options);

Values<String> valuesTransform = Create.of("foo", "bar");
PCollection pCollection = mock(PCollection.class);

JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
EvaluationContext ctxt = new EvaluationContext(jsc, pipeline, options);
EvaluationContext ctxt =
new EvaluationContext(contextRule.getSparkContext(), pipeline, options);
ctxt.getCacheCandidates().put(pCollection, 2L);

assertFalse(ctxt.shouldCache(valuesTransform, pCollection));
@@ -110,11 +110,4 @@ public void shouldCacheTest() {
GroupByKey<String, String> gbkTransform = GroupByKey.create();
assertFalse(ctxt.shouldCache(gbkTransform, pCollection));
}

private SparkPipelineOptions createOptions() {
SparkPipelineOptions options =
PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class);
options.setRunner(TestSparkRunner.class);
return options;
}
}
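The deprecation note above steers tests toward the new SparkContextRule, whose surface CacheTest illustrates: a @ClassRule exposing getSparkContext() and createPipelineOptions(). Below is a sketch of a minimal test along those lines; the package of SparkContextRule, the test class, and the applied transform are assumptions beyond what the diff shows, and createPipelineOptions() is assumed to return options already wired to the shared context, as in CacheTest.

import static org.junit.Assert.assertNotNull;

import org.apache.beam.runners.spark.SparkContextRule; // assumed location of the new rule
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.Create;
import org.junit.ClassRule;
import org.junit.Test;

// Hypothetical test using the shared-context pattern shown in CacheTest above.
public class SharedContextExampleTest {

  // One Spark context for the whole test class; no forkEvery 1 required.
  @ClassRule public static SparkContextRule contextRule = new SparkContextRule();

  @Test
  public void pipelineUsesSharedContext() {
    SparkPipelineOptions options = contextRule.createPipelineOptions();
    Pipeline pipeline = Pipeline.create(options);
    pipeline.apply(Create.of("foo", "bar"));

    assertNotNull(contextRule.getSparkContext());
    pipeline.run().waitUntilFinish();
  }
}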
(Diffs for the remaining 15 changed files are not shown.)
