ev1) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
new file mode 100644
index 0000000000000..4ee6a82c0423e
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+/**
+ * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
+ *
+ * Within the long, the data is laid out as follows:
+ *
+ * [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ *
+ * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
+ * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
+ * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
+ * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
+ *
+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
+ * optimization to future work as it will require more careful design to ensure that addresses are
+ * properly aligned (e.g. by padding records).
+ */
+final class PackedRecordPointer {
+
+ static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes
+
+ /**
+ * The maximum partition identifier that can be encoded. Note that partition ids start from 0.
+ */
+ static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
+
+ /** Bit mask for the lower 40 bits of a long. */
+ private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
+
+ /** Bit mask for the upper 24 bits of a long. */
+ private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS;
+
+ /** Bit mask for the lower 27 bits of a long. */
+ private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1;
+
+ /** Bit mask for the lower 51 bits of a long. */
+ private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1;
+
+ /** Bit mask for the upper 13 bits of a long. */
+ private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
+
+ /**
+ * Pack a record address and partition id into a single word.
+ *
+ * @param recordPointer a record pointer encoded by TaskMemoryManager.
+ * @param partitionId a shuffle partition id (maximum value of 2^24 - 1).
+ * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class.
+ */
+ public static long packPointer(long recordPointer, int partitionId) {
+ assert (partitionId <= MAXIMUM_PARTITION_ID);
+ // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page.
+ // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses.
+ final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24;
+ final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS);
+ return (((long) partitionId) << 40) | compressedAddress;
+ }
+
+ private long packedRecordPointer;
+
+ public void set(long packedRecordPointer) {
+ this.packedRecordPointer = packedRecordPointer;
+ }
+
+ public int getPartitionId() {
+ return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
+ }
+
+ public long getRecordPointer() {
+ final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS;
+ final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS;
+ return pageNumber | offsetInPage;
+ }
+
+}
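Illustrative sketch (not part of the patch): a round trip through PackedRecordPointer, assuming the TaskMemoryManager address layout described in the javadoc above (13-bit page number in the upper bits of the address, offset in the lower bits). Because the class and packPointer() are package-private, a snippet like this would have to live in org.apache.spark.shuffle.unsafe; the example class name is hypothetical.

package org.apache.spark.shuffle.unsafe;

// Hypothetical example class, for illustration only.
public class PackedRecordPointerExample {
  public static void main(String[] args) {
    final long pageNumber = 3L;
    final long offsetInPage = 64L;  // must stay below 2^27 bytes (128 megabytes)
    // Assumed TaskMemoryManager-style encoding: page number in the upper 13 bits of the address.
    final long recordPointer = (pageNumber << 51) | offsetInPage;

    final long packed = PackedRecordPointer.packPointer(recordPointer, /* partitionId = */ 5);

    final PackedRecordPointer decoded = new PackedRecordPointer();
    decoded.set(packed);
    // The partition id occupies the upper 24 bits; the 40-bit compressed address round-trips.
    assert decoded.getPartitionId() == 5;
    assert decoded.getRecordPointer() == recordPointer;
  }
}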
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
new file mode 100644
index 0000000000000..7bac0dc0bbeb6
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.File;
+
+import org.apache.spark.storage.TempShuffleBlockId;
+
+/**
+ * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}.
+ */
+final class SpillInfo {
+ final long[] partitionLengths;
+ final File file;
+ final TempShuffleBlockId blockId;
+
+ public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
+ this.partitionLengths = new long[numPartitions];
+ this.file = file;
+ this.blockId = blockId;
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
new file mode 100644
index 0000000000000..9e9ed94b7890c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.LinkedList;
+
+import scala.Tuple2;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+/**
+ * An external sorter that is specialized for sort-based shuffle.
+ *
+ * Incoming records are appended to data pages. When all records have been inserted (or when the
+ * current thread's shuffle memory limit is reached), the in-memory records are sorted according to
+ * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then
+ * written to a single output file (or multiple files, if we've spilled). The format of the output
+ * files is the same as the format of the final output file written by
+ * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
+ * written as a single serialized, compressed stream that can be read with a new decompression and
+ * deserialization stream.
+ *
+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its
+ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
+ * specialized merge procedure that avoids extra serialization/deserialization.
+ */
+final class UnsafeShuffleExternalSorter {
+
+ private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
+
+ private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
+ @VisibleForTesting
+ static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+ @VisibleForTesting
+ static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;
+
+ private final int initialSize;
+ private final int numPartitions;
+ private final TaskMemoryManager memoryManager;
+ private final ShuffleMemoryManager shuffleMemoryManager;
+ private final BlockManager blockManager;
+ private final TaskContext taskContext;
+ private final ShuffleWriteMetrics writeMetrics;
+
+ /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+ private final int fileBufferSizeBytes;
+
+ /**
+ * Memory pages that hold the records being sorted. The pages in this list are freed when
+ * spilling, although in principle we could recycle these pages across spills (on the other hand,
+ * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
+ * itself).
+ */
+ private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+
+ private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
+
+ // These variables are reset after spilling:
+ private UnsafeShuffleInMemorySorter sorter;
+ private MemoryBlock currentPage = null;
+ private long currentPagePosition = -1;
+ private long freeSpaceInCurrentPage = 0;
+
+ public UnsafeShuffleExternalSorter(
+ TaskMemoryManager memoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ BlockManager blockManager,
+ TaskContext taskContext,
+ int initialSize,
+ int numPartitions,
+ SparkConf conf,
+ ShuffleWriteMetrics writeMetrics) throws IOException {
+ this.memoryManager = memoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
+ this.blockManager = blockManager;
+ this.taskContext = taskContext;
+ this.initialSize = initialSize;
+ this.numPartitions = numPartitions;
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+ this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+
+ this.writeMetrics = writeMetrics;
+ initializeForWriting();
+ }
+
+ /**
+ * Allocates new sort data structures. Called when creating the sorter and after each spill.
+ */
+ private void initializeForWriting() throws IOException {
+ // TODO: move this sizing calculation logic into a static method of sorter:
+ final long memoryRequested = initialSize * 8L;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+ if (memoryAcquired != memoryRequested) {
+ shuffleMemoryManager.release(memoryAcquired);
+ throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+ }
+
+ this.sorter = new UnsafeShuffleInMemorySorter(initialSize);
+ }
+
+ /**
+ * Sorts the in-memory records and writes the sorted records to an on-disk file.
+ * This method does not free the sort data structures.
+ *
+ * @param isLastFile if true, this indicates that we're writing the final output file and that the
+ * bytes written should be counted towards shuffle write metrics rather than
+ * shuffle spill metrics.
+ */
+ private void writeSortedFile(boolean isLastFile) throws IOException {
+
+ final ShuffleWriteMetrics writeMetricsToUse;
+
+ if (isLastFile) {
+ // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
+ writeMetricsToUse = writeMetrics;
+ } else {
+ // We're spilling, so bytes written should be counted towards spill rather than write.
+ // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
+ // them towards shuffle bytes written.
+ writeMetricsToUse = new ShuffleWriteMetrics();
+ }
+
+ // This call performs the actual sort.
+ final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
+ sorter.getSortedIterator();
+
+ // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
+ // after SPARK-5581 is fixed.
+ BlockObjectWriter writer;
+
+ // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+ // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+ // data through a byte array. This array does not need to be large enough to hold a single
+ // record, since records are copied to the output in chunks of at most DISK_WRITE_BUFFER_SIZE bytes.
+ final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+ // Because this output will be read during shuffle, its compression codec must be controlled by
+ // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+ // createTempShuffleBlock here; see SPARK-3426 for more details.
+ final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = spilledFileInfo._2();
+ final TempShuffleBlockId blockId = spilledFileInfo._1();
+ final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);
+
+ // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+ // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ // around this, we pass a dummy no-op serializer.
+ final SerializerInstance ser = DummySerializerInstance.INSTANCE;
+
+ writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+
+ int currentPartition = -1;
+ while (sortedRecords.hasNext()) {
+ sortedRecords.loadNext();
+ final int partition = sortedRecords.packedRecordPointer.getPartitionId();
+ assert (partition >= currentPartition);
+ if (partition != currentPartition) {
+ // Switch to the new partition
+ if (currentPartition != -1) {
+ writer.commitAndClose();
+ spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+ }
+ currentPartition = partition;
+ writer =
+ blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+ }
+
+ final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
+ final Object recordPage = memoryManager.getPage(recordPointer);
+ final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer);
+ int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage);
+ long recordReadPosition = recordOffsetInPage + 4; // skip over record length
+ while (dataRemaining > 0) {
+ final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
+ PlatformDependent.copyMemory(
+ recordPage,
+ recordReadPosition,
+ writeBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ toTransfer);
+ writer.write(writeBuffer, 0, toTransfer);
+ recordReadPosition += toTransfer;
+ dataRemaining -= toTransfer;
+ }
+ writer.recordWritten();
+ }
+
+ if (writer != null) {
+ writer.commitAndClose();
+ // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
+ // then the file might be empty. Note that it might be better to avoid calling
+ // writeSortedFile() in that case.
+ if (currentPartition != -1) {
+ spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+ spills.add(spillInfo);
+ }
+ }
+
+ if (!isLastFile) { // i.e. this is a spill file
+ // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
+ // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
+ // relies on its `recordWritten()` method being called in order to trigger periodic updates to
+ // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that
+ // counter at a higher-level, then the in-progress metrics for records written and bytes
+ // written would get out of sync.
+ //
+ // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
+ // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those
+ // metrics to the true write metrics here. The reason for performing this copying is so that
+ // we can avoid reporting spilled bytes as shuffle write bytes.
+ //
+ // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
+ // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
+ // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
+ writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten());
+ taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten());
+ }
+ }
+
+ /**
+ * Sort and spill the current records in response to memory pressure.
+ */
+ @VisibleForTesting
+ void spill() throws IOException {
+ logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
+ Thread.currentThread().getId(),
+ Utils.bytesToString(getMemoryUsage()),
+ spills.size(),
+ spills.size() > 1 ? "times" : "time");
+
+ writeSortedFile(false);
+ final long sorterMemoryUsage = sorter.getMemoryUsage();
+ sorter = null;
+ shuffleMemoryManager.release(sorterMemoryUsage);
+ final long spillSize = freeMemory();
+ taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+
+ initializeForWriting();
+ }
+
+ private long getMemoryUsage() {
+ return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE);
+ }
+
+ private long freeMemory() {
+ long memoryFreed = 0;
+ for (MemoryBlock block : allocatedPages) {
+ memoryManager.freePage(block);
+ shuffleMemoryManager.release(block.size());
+ memoryFreed += block.size();
+ }
+ allocatedPages.clear();
+ currentPage = null;
+ currentPagePosition = -1;
+ freeSpaceInCurrentPage = 0;
+ return memoryFreed;
+ }
+
+ /**
+ * Force all memory and spill files to be deleted; called by shuffle error-handling code.
+ */
+ public void cleanupAfterError() {
+ freeMemory();
+ for (SpillInfo spill : spills) {
+ if (spill.file.exists() && !spill.file.delete()) {
+ logger.error("Unable to delete spill file {}", spill.file.getPath());
+ }
+ }
+ if (sorter != null) {
+ shuffleMemoryManager.release(sorter.getMemoryUsage());
+ sorter = null;
+ }
+ }
+
+ /**
+ * Checks whether there is enough space to insert a new record into the sorter.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+ *
+ * @return true if the record can be inserted without requiring more allocations, false otherwise.
+ */
+ private boolean haveSpaceForRecord(int requiredSpace) {
+ assert (requiredSpace > 0);
+ return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+ }
+
+ /**
+ * Allocates more memory in order to insert an additional record. This will request additional
+ * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+ * obtained.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+ */
+ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+ if (!sorter.hasSpaceForAnotherRecord()) {
+ logger.debug("Attempting to expand sort pointer array");
+ final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+ final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+ if (memoryAcquired < memoryToGrowPointerArray) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ } else {
+ sorter.expandPointerArray();
+ shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+ }
+ }
+ if (requiredSpace > freeSpaceInCurrentPage) {
+ logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
+ freeSpaceInCurrentPage);
+ // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
+ // without using the free space at the end of the current page. We should also do this for
+ // BytesToBytesMap.
+ if (requiredSpace > PAGE_SIZE) {
+ throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
+ PAGE_SIZE + ")");
+ } else {
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+ if (memoryAcquired < PAGE_SIZE) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+ if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
+ shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
+ }
+ }
+ currentPage = memoryManager.allocatePage(PAGE_SIZE);
+ currentPagePosition = currentPage.getBaseOffset();
+ freeSpaceInCurrentPage = PAGE_SIZE;
+ allocatedPages.add(currentPage);
+ }
+ }
+ }
+
+ /**
+ * Write a record to the shuffle sorter.
+ */
+ public void insertRecord(
+ Object recordBaseObject,
+ long recordBaseOffset,
+ int lengthInBytes,
+ int partitionId) throws IOException {
+ // Need 4 bytes to store the record length.
+ final int totalSpaceRequired = lengthInBytes + 4;
+ if (!haveSpaceForRecord(totalSpaceRequired)) {
+ allocateSpaceForRecord(totalSpaceRequired);
+ }
+
+ final long recordAddress =
+ memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+ final Object dataPageBaseObject = currentPage.getBaseObject();
+ PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
+ currentPagePosition += 4;
+ freeSpaceInCurrentPage -= 4;
+ PlatformDependent.copyMemory(
+ recordBaseObject,
+ recordBaseOffset,
+ dataPageBaseObject,
+ currentPagePosition,
+ lengthInBytes);
+ currentPagePosition += lengthInBytes;
+ freeSpaceInCurrentPage -= lengthInBytes;
+ sorter.insertRecord(recordAddress, partitionId);
+ }
+
+ /**
+ * Close the sorter, causing any buffered data to be sorted and written out to disk.
+ *
+ * @return metadata for the spill files written by this sorter. If no records were ever inserted
+ * into this sorter, then this will return an empty array.
+ * @throws IOException
+ */
+ public SpillInfo[] closeAndGetSpills() throws IOException {
+ try {
+ if (sorter != null) {
+ // Do not count the final file towards the spill count.
+ writeSortedFile(true);
+ freeMemory();
+ }
+ return spills.toArray(new SpillInfo[spills.size()]);
+ } catch (IOException e) {
+ cleanupAfterError();
+ throw e;
+ }
+ }
+
+}
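Illustrative sketch (not part of the patch): the in-page record layout that insertRecord() writes and writeSortedFile() reads back, a 4-byte length prefix followed by the record bytes, shown here against a plain byte[] standing in for a data page. Only PlatformDependent calls that already appear in the sorter above are used; the example class name is hypothetical.

import org.apache.spark.unsafe.PlatformDependent;

// Hypothetical example class, for illustration only.
public class RecordLayoutExample {
  public static void main(String[] args) {
    final byte[] page = new byte[64];               // stand-in for a MemoryBlock data page
    final byte[] record = "hello".getBytes();

    // Write path (mirrors insertRecord()): length prefix, then the payload.
    long writePos = PlatformDependent.BYTE_ARRAY_OFFSET;
    PlatformDependent.UNSAFE.putInt(page, writePos, record.length);
    writePos += 4;
    PlatformDependent.copyMemory(
      record, PlatformDependent.BYTE_ARRAY_OFFSET, page, writePos, record.length);

    // Read path (mirrors the inner loop of writeSortedFile()): length prefix, then the bytes.
    long readPos = PlatformDependent.BYTE_ARRAY_OFFSET;
    final int length = PlatformDependent.UNSAFE.getInt(page, readPos);
    readPos += 4;
    final byte[] out = new byte[length];
    PlatformDependent.copyMemory(
      page, readPos, out, PlatformDependent.BYTE_ARRAY_OFFSET, length);
    assert new String(out).equals("hello");
  }
}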
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
new file mode 100644
index 0000000000000..5bab501da9364
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.util.Comparator;
+
+import org.apache.spark.util.collection.Sorter;
+
+final class UnsafeShuffleInMemorySorter {
+
+ private final Sorter<PackedRecordPointer, long[]> sorter;
+ private static final class SortComparator implements Comparator<PackedRecordPointer> {
+ @Override
+ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
+ return left.getPartitionId() - right.getPartitionId();
+ }
+ }
+ private static final SortComparator SORT_COMPARATOR = new SortComparator();
+
+ /**
+ * An array of record pointers and partition ids that have been encoded by
+ * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
+ * records.
+ */
+ private long[] pointerArray;
+
+ /**
+ * The position in the pointer array where new records can be inserted.
+ */
+ private int pointerArrayInsertPosition = 0;
+
+ public UnsafeShuffleInMemorySorter(int initialSize) {
+ assert (initialSize > 0);
+ this.pointerArray = new long[initialSize];
+ this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
+ }
+
+ public void expandPointerArray() {
+ final long[] oldArray = pointerArray;
+ // Guard against overflow:
+ final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
+ pointerArray = new long[newLength];
+ System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+ }
+
+ public boolean hasSpaceForAnotherRecord() {
+ return pointerArrayInsertPosition + 1 < pointerArray.length;
+ }
+
+ public long getMemoryUsage() {
+ return pointerArray.length * 8L;
+ }
+
+ /**
+ * Inserts a record to be sorted.
+ *
+ * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to
+ * certain pointer compression techniques used by the sorter, the sort can
+ * only operate on pointers that point to locations in the first
+ * {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page.
+ * @param partitionId the partition id, which must be less than or equal to
+ * {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}.
+ */
+ public void insertRecord(long recordPointer, int partitionId) {
+ if (!hasSpaceForAnotherRecord()) {
+ if (pointerArray.length == Integer.MAX_VALUE) {
+ throw new IllegalStateException("Sort pointer array has reached maximum size");
+ } else {
+ expandPointerArray();
+ }
+ }
+ pointerArray[pointerArrayInsertPosition] =
+ PackedRecordPointer.packPointer(recordPointer, partitionId);
+ pointerArrayInsertPosition++;
+ }
+
+ /**
+ * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
+ */
+ public static final class UnsafeShuffleSorterIterator {
+
+ private final long[] pointerArray;
+ private final int numRecords;
+ final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
+ private int position = 0;
+
+ public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
+ this.numRecords = numRecords;
+ this.pointerArray = pointerArray;
+ }
+
+ public boolean hasNext() {
+ return position < numRecords;
+ }
+
+ public void loadNext() {
+ packedRecordPointer.set(pointerArray[position]);
+ position++;
+ }
+ }
+
+ /**
+ * Return an iterator over record pointers in sorted order.
+ */
+ public UnsafeShuffleSorterIterator getSortedIterator() {
+ sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
+ return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+ }
+}
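Illustrative sketch (not part of the patch): inserting a few packed pointers and iterating them back in partition order, which is all UnsafeShuffleExternalSorter needs from this class. The record pointer values are arbitrary placeholders, same-package access is assumed since the sorter is package-private, and the example class name is hypothetical.

package org.apache.spark.shuffle.unsafe;

// Hypothetical example class, for illustration only.
public class InMemorySorterExample {
  public static void main(String[] args) {
    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
    sorter.insertRecord(/* recordPointer = */ 100L, /* partitionId = */ 2);
    sorter.insertRecord(200L, 0);
    sorter.insertRecord(300L, 1);

    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
    int previousPartition = -1;
    while (iter.hasNext()) {
      iter.loadNext();
      final int partition = iter.packedRecordPointer.getPartitionId();
      assert partition >= previousPartition;  // records come back ordered by partition id
      previousPartition = partition;
    }
  }
}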
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
new file mode 100644
index 0000000000000..a66d74ee44782
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import org.apache.spark.util.collection.SortDataFormat;
+
+final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
+
+ public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
+
+ private UnsafeShuffleSortDataFormat() { }
+
+ @Override
+ public PackedRecordPointer getKey(long[] data, int pos) {
+ // Since we re-use keys, this method shouldn't be called.
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public PackedRecordPointer newKey() {
+ return new PackedRecordPointer();
+ }
+
+ @Override
+ public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
+ reuse.set(data[pos]);
+ return reuse;
+ }
+
+ @Override
+ public void swap(long[] data, int pos0, int pos1) {
+ final long temp = data[pos0];
+ data[pos0] = data[pos1];
+ data[pos1] = temp;
+ }
+
+ @Override
+ public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+ dst[dstPos] = src[srcPos];
+ }
+
+ @Override
+ public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+ System.arraycopy(src, srcPos, dst, dstPos, length);
+ }
+
+ @Override
+ public long[] allocate(int length) {
+ return new long[length];
+ }
+
+}
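Illustrative sketch (not part of the patch): sorting a raw long[] of packed pointers with Spark's Sorter and this format, which is what UnsafeShuffleInMemorySorter.getSortedIterator() does internally; the key-reuse overloads above let the sort run without allocating a PackedRecordPointer per comparison. Same-package access is assumed, and the example class name is hypothetical.

package org.apache.spark.shuffle.unsafe;

import java.util.Comparator;

import org.apache.spark.util.collection.Sorter;

// Hypothetical example class, for illustration only.
public class SortDataFormatExample {
  public static void main(String[] args) {
    final long[] data = new long[] {
      PackedRecordPointer.packPointer(0L, 7),
      PackedRecordPointer.packPointer(0L, 1),
      PackedRecordPointer.packPointer(0L, 4)
    };
    final Sorter<PackedRecordPointer, long[]> sorter =
      new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
    sorter.sort(data, 0, data.length, new Comparator<PackedRecordPointer>() {
      @Override
      public int compare(PackedRecordPointer left, PackedRecordPointer right) {
        return left.getPartitionId() - right.getPartitionId();
      }
    });
    // data is now ordered by partition id: 1, 4, 7.
  }
}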
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
new file mode 100644
index 0000000000000..ad7eb04afcd8c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -0,0 +1,438 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.*;
+import java.nio.channels.FileChannel;
+import java.util.Iterator;
+import javax.annotation.Nullable;
+
+import scala.Option;
+import scala.Product2;
+import scala.collection.JavaConversions;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.*;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.TimeTrackingOutputStream;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+@Private
+public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
+
+ private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
+
+ private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
+
+ @VisibleForTesting
+ static final int INITIAL_SORT_BUFFER_SIZE = 4096;
+
+ private final BlockManager blockManager;
+ private final IndexShuffleBlockResolver shuffleBlockResolver;
+ private final TaskMemoryManager memoryManager;
+ private final ShuffleMemoryManager shuffleMemoryManager;
+ private final SerializerInstance serializer;
+ private final Partitioner partitioner;
+ private final ShuffleWriteMetrics writeMetrics;
+ private final int shuffleId;
+ private final int mapId;
+ private final TaskContext taskContext;
+ private final SparkConf sparkConf;
+ private final boolean transferToEnabled;
+
+ private MapStatus mapStatus = null;
+ private UnsafeShuffleExternalSorter sorter = null;
+
+ /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
+ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
+ public MyByteArrayOutputStream(int size) { super(size); }
+ public byte[] getBuf() { return buf; }
+ }
+
+ private MyByteArrayOutputStream serBuffer;
+ private SerializationStream serOutputStream;
+
+ /**
+ * Are we in the process of stopping? Because map tasks can call stop() with success = true
+ * and then call stop() with success = false if they get an exception, we want to make sure
+ * we don't try deleting files, etc. twice.
+ */
+ private boolean stopping = false;
+
+ public UnsafeShuffleWriter(
+ BlockManager blockManager,
+ IndexShuffleBlockResolver shuffleBlockResolver,
+ TaskMemoryManager memoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ UnsafeShuffleHandle<K, V> handle,
+ int mapId,
+ TaskContext taskContext,
+ SparkConf sparkConf) throws IOException {
+ final int numPartitions = handle.dependency().partitioner().numPartitions();
+ if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) {
+ throw new IllegalArgumentException(
+ "UnsafeShuffleWriter can only be used for shuffles with at most " +
+ UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions");
+ }
+ this.blockManager = blockManager;
+ this.shuffleBlockResolver = shuffleBlockResolver;
+ this.memoryManager = memoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
+ this.mapId = mapId;
+ final ShuffleDependency<K, V, V> dep = handle.dependency();
+ this.shuffleId = dep.shuffleId();
+ this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
+ this.partitioner = dep.partitioner();
+ this.writeMetrics = new ShuffleWriteMetrics();
+ taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+ this.taskContext = taskContext;
+ this.sparkConf = sparkConf;
+ this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
+ open();
+ }
+
+ /**
+ * This convenience method should only be called in test code.
+ */
+ @VisibleForTesting
+ public void write(Iterator<Product2<K, V>> records) throws IOException {
+ write(JavaConversions.asScalaIterator(records));
+ }
+
+ @Override
+ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+ boolean success = false;
+ try {
+ while (records.hasNext()) {
+ insertRecordIntoSorter(records.next());
+ }
+ closeAndWriteOutput();
+ success = true;
+ } finally {
+ if (!success) {
+ sorter.cleanupAfterError();
+ }
+ }
+ }
+
+ private void open() throws IOException {
+ assert (sorter == null);
+ sorter = new UnsafeShuffleExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ INITIAL_SORT_BUFFER_SIZE,
+ partitioner.numPartitions(),
+ sparkConf,
+ writeMetrics);
+ serBuffer = new MyByteArrayOutputStream(1024 * 1024);
+ serOutputStream = serializer.serializeStream(serBuffer);
+ }
+
+ @VisibleForTesting
+ void closeAndWriteOutput() throws IOException {
+ serBuffer = null;
+ serOutputStream = null;
+ final SpillInfo[] spills = sorter.closeAndGetSpills();
+ sorter = null;
+ final long[] partitionLengths;
+ try {
+ partitionLengths = mergeSpills(spills);
+ } finally {
+ for (SpillInfo spill : spills) {
+ if (spill.file.exists() && !spill.file.delete()) {
+ logger.error("Error while deleting spill file {}", spill.file.getPath());
+ }
+ }
+ }
+ shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ }
+
+ @VisibleForTesting
+ void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
+ final K key = record._1();
+ final int partitionId = partitioner.getPartition(key);
+ serBuffer.reset();
+ serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
+ serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
+ serOutputStream.flush();
+
+ final int serializedRecordSize = serBuffer.size();
+ assert (serializedRecordSize > 0);
+
+ sorter.insertRecord(
+ serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+ }
+
+ @VisibleForTesting
+ void forceSorterToSpill() throws IOException {
+ assert (sorter != null);
+ sorter.spill();
+ }
+
+ /**
+ * Merge zero or more spill files together, choosing the fastest merging strategy based on the
+ * number of spills and the IO compression codec.
+ *
+ * @return the partition lengths in the merged file.
+ */
+ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
+ final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+ final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
+ final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
+ final boolean fastMergeEnabled =
+ sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
+ final boolean fastMergeIsSupported =
+ !compressionEnabled || compressionCodec instanceof LZFCompressionCodec;
+ try {
+ if (spills.length == 0) {
+ new FileOutputStream(outputFile).close(); // Create an empty file
+ return new long[partitioner.numPartitions()];
+ } else if (spills.length == 1) {
+ // Here, we don't need to perform any metrics updates because the bytes written to this
+ // output file would have already been counted as shuffle bytes written.
+ Files.move(spills[0].file, outputFile);
+ return spills[0].partitionLengths;
+ } else {
+ final long[] partitionLengths;
+ // There are multiple spills to merge, so none of these spill files' lengths were counted
+ // towards our shuffle write count or shuffle write time. If we use the slow merge path,
+ // then the final output file's size won't necessarily be equal to the sum of the spill
+ // files' sizes. To guard against this case, we look at the output file's actual size when
+ // computing shuffle bytes written.
+ //
+ // We allow the individual merge methods to report their own IO times since different merge
+ // strategies use different IO techniques. We count IO during merge towards the
+ // shuffle write time, which appears to be consistent with the "not bypassing merge-sort"
+ // branch in ExternalSorter.
+ if (fastMergeEnabled && fastMergeIsSupported) {
+ // Compression is disabled or we are using an IO compression codec that supports
+ // decompression of concatenated compressed streams, so we can perform a fast spill merge
+ // that doesn't need to interpret the spilled bytes.
+ if (transferToEnabled) {
+ logger.debug("Using transferTo-based fast merge");
+ partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
+ } else {
+ logger.debug("Using fileStream-based fast merge");
+ partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
+ }
+ } else {
+ logger.debug("Using slow merge");
+ partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+ }
+ // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
+ // in-memory records, we write out the in-memory records to a file but do not count that
+ // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
+ // to be counted as shuffle write, but this will lead to double-counting of the final
+ // SpillInfo's bytes.
+ writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length());
+ writeMetrics.incShuffleBytesWritten(outputFile.length());
+ return partitionLengths;
+ }
+ } catch (IOException e) {
+ if (outputFile.exists() && !outputFile.delete()) {
+ logger.error("Unable to delete output file {}", outputFile.getPath());
+ }
+ throw e;
+ }
+ }
+
+ /**
+ * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
+ * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
+ * cases where the IO compression codec does not support concatenation of compressed data, or in
+ * cases where users have explicitly disabled use of {@code transferTo} in order to work around
+ * kernel bugs.
+ *
+ * @param spills the spills to merge.
+ * @param outputFile the file to write the merged data to.
+ * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
+ * @return the partition lengths in the merged file.
+ */
+ private long[] mergeSpillsWithFileStream(
+ SpillInfo[] spills,
+ File outputFile,
+ @Nullable CompressionCodec compressionCodec) throws IOException {
+ assert (spills.length >= 2);
+ final int numPartitions = partitioner.numPartitions();
+ final long[] partitionLengths = new long[numPartitions];
+ final InputStream[] spillInputStreams = new FileInputStream[spills.length];
+ OutputStream mergedFileOutputStream = null;
+
+ boolean threwException = true;
+ try {
+ for (int i = 0; i < spills.length; i++) {
+ spillInputStreams[i] = new FileInputStream(spills[i].file);
+ }
+ for (int partition = 0; partition < numPartitions; partition++) {
+ final long initialFileLength = outputFile.length();
+ mergedFileOutputStream =
+ new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
+ if (compressionCodec != null) {
+ mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
+ }
+
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+ if (partitionLengthInSpill > 0) {
+ InputStream partitionInputStream =
+ new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
+ if (compressionCodec != null) {
+ partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+ }
+ ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+ }
+ }
+ mergedFileOutputStream.flush();
+ mergedFileOutputStream.close();
+ partitionLengths[partition] = (outputFile.length() - initialFileLength);
+ }
+ threwException = false;
+ } finally {
+ // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+ // throw exceptions during cleanup if threwException == false.
+ for (InputStream stream : spillInputStreams) {
+ Closeables.close(stream, threwException);
+ }
+ Closeables.close(mergedFileOutputStream, threwException);
+ }
+ return partitionLengths;
+ }
+
+ /**
+ * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
+ * This is only safe when the IO compression codec and serializer support concatenation of
+ * serialized streams.
+ *
+ * @return the partition lengths in the merged file.
+ */
+ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
+ assert (spills.length >= 2);
+ final int numPartitions = partitioner.numPartitions();
+ final long[] partitionLengths = new long[numPartitions];
+ final FileChannel[] spillInputChannels = new FileChannel[spills.length];
+ final long[] spillInputChannelPositions = new long[spills.length];
+ FileChannel mergedFileOutputChannel = null;
+
+ boolean threwException = true;
+ try {
+ for (int i = 0; i < spills.length; i++) {
+ spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
+ }
+ // This file needs to be opened in append mode in order to work around a Linux kernel bug that
+ // affects transferTo; see SPARK-3948 for more details.
+ mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
+
+ long bytesWrittenToMergedFile = 0;
+ for (int partition = 0; partition < numPartitions; partition++) {
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+ long bytesToTransfer = partitionLengthInSpill;
+ final FileChannel spillInputChannel = spillInputChannels[i];
+ final long writeStartTime = System.nanoTime();
+ while (bytesToTransfer > 0) {
+ final long actualBytesTransferred = spillInputChannel.transferTo(
+ spillInputChannelPositions[i],
+ bytesToTransfer,
+ mergedFileOutputChannel);
+ spillInputChannelPositions[i] += actualBytesTransferred;
+ bytesToTransfer -= actualBytesTransferred;
+ }
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+ bytesWrittenToMergedFile += partitionLengthInSpill;
+ partitionLengths[partition] += partitionLengthInSpill;
+ }
+ }
+ // Check the position after transferTo loop to see if it is in the right position and raise an
+ // exception if it is incorrect. The position will not be increased to the expected length
+ // after calling transferTo in kernel version 2.6.32. This issue is described at
+ // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
+ if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
+ throw new IOException(
+ "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
+ "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
+ " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
+ "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
+ "to disable this NIO feature."
+ );
+ }
+ threwException = false;
+ } finally {
+ // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+ // throw exceptions during cleanup if threwException == false.
+ for (int i = 0; i < spills.length; i++) {
+ assert(spillInputChannelPositions[i] == spills[i].file.length());
+ Closeables.close(spillInputChannels[i], threwException);
+ }
+ Closeables.close(mergedFileOutputChannel, threwException);
+ }
+ return partitionLengths;
+ }
+
+ @Override
+ public Option<MapStatus> stop(boolean success) {
+ try {
+ if (stopping) {
+ return Option.apply(null);
+ } else {
+ stopping = true;
+ if (success) {
+ if (mapStatus == null) {
+ throw new IllegalStateException("Cannot call stop(true) without having called write()");
+ }
+ return Option.apply(mapStatus);
+ } else {
+ // The map task failed, so delete our output data.
+ shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
+ return Option.apply(null);
+ }
+ }
+ } finally {
+ if (sorter != null) {
+ // If sorter is non-null, then this implies that we called stop() in response to an error,
+ // so we need to clean up memory and spill files created by the sorter
+ sorter.cleanupAfterError();
+ }
+ }
+ }
+}
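Illustrative sketch (not part of the patch): the core of the transferTo-based fast merge used in mergeSpillsWithTransferTo(), concatenating one file's bytes onto an output channel opened in append mode and looping because transferTo() may copy fewer bytes than requested. The file names are placeholders and the example class name is hypothetical.

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.channels.FileChannel;

// Hypothetical example class, for illustration only.
public class TransferToMergeExample {
  public static void main(String[] args) throws IOException {
    final File spill = new File("spill0.tmp");      // placeholder spill file
    final File output = new File("merged.data");    // placeholder merged output
    try (FileChannel in = new FileInputStream(spill).getChannel();
         FileChannel out = new FileOutputStream(output, true).getChannel()) {
      long position = 0;
      long remaining = spill.length();
      while (remaining > 0) {
        // transferTo() may transfer fewer bytes than requested, so advance by the actual count.
        final long transferred = in.transferTo(position, remaining, out);
        position += transferred;
        remaining -= transferred;
      }
    }
  }
}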
diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
new file mode 100644
index 0000000000000..dc2aa30466cc6
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage;
+
+import java.io.IOException;
+import java.io.OutputStream;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+
+/**
+ * Intercepts write calls and tracks total time spent writing in order to update shuffle write
+ * metrics. Not thread safe.
+ */
+@Private
+public final class TimeTrackingOutputStream extends OutputStream {
+
+ private final ShuffleWriteMetrics writeMetrics;
+ private final OutputStream outputStream;
+
+ public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) {
+ this.writeMetrics = writeMetrics;
+ this.outputStream = outputStream;
+ }
+
+ @Override
+ public void write(int b) throws IOException {
+ final long startTime = System.nanoTime();
+ outputStream.write(b);
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+ }
+
+ @Override
+ public void write(byte[] b) throws IOException {
+ final long startTime = System.nanoTime();
+ outputStream.write(b);
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+ }
+
+ @Override
+ public void write(byte[] b, int off, int len) throws IOException {
+ final long startTime = System.nanoTime();
+ outputStream.write(b, off, len);
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+ }
+
+ @Override
+ public void flush() throws IOException {
+ final long startTime = System.nanoTime();
+ outputStream.flush();
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+ }
+
+ @Override
+ public void close() throws IOException {
+ final long startTime = System.nanoTime();
+ outputStream.close();
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+ }
+}
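Illustrative sketch (not part of the patch): wrapping a plain FileOutputStream so that time spent in write/flush/close is accumulated into a ShuffleWriteMetrics instance, assuming the metrics object can be constructed directly as the writer above does. The file name is a placeholder and the example class name is hypothetical.

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;

import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.storage.TimeTrackingOutputStream;

// Hypothetical example class, for illustration only.
public class TimeTrackingExample {
  public static void main(String[] args) throws IOException {
    final ShuffleWriteMetrics metrics = new ShuffleWriteMetrics();
    final OutputStream out =
      new TimeTrackingOutputStream(metrics, new FileOutputStream("example.bin"));
    try {
      out.write(new byte[] {1, 2, 3});
      out.flush();
    } finally {
      out.close();
    }
    // The accumulated write time (in nanoseconds) is exposed through the metrics object.
    System.out.println("shuffle write time (ns): " + metrics.shuffleWriteTime());
  }
}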
diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
index acf2d93b718b2..c55f752620dfd 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
@@ -20,7 +20,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
[Diff of the vendored, minified dagre-d3 bundle elided. The recoverable functional change is in the bundled createClusters: each cluster <g> now receives the stage identifier through its class attribute, .attr("class", makeClusterIdentifier), instead of its id, .attr("id", makeClusterIdentifier), in addition to the generic "cluster" class; the CSS and JavaScript changes below select clusters and RDD nodes by class accordingly.]
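A hedged sketch (not part of the patch) of the bundled dagre-d3 change summarized above, since the rest of this diff depends on it. Names mirror the bundle's createClusters; the sample datum and d3 usage are illustrative only, assuming d3 v3 as shipped with this UI.

    // Each cluster <g> now carries its identifier as a CSS class rather than an id,
    // so several rendered DAGs on one page cannot collide on duplicate element ids.
    var makeClusterIdentifier = function(v) {
      return "cluster_" + v.replace(/^cluster/, "");
    };
    var svgClusters = d3.select("body").append("svg")
      .selectAll("g.cluster")
      .data(["clusterstage_3"], function(v) { return v; });   // sample datum, made up
    svgClusters.enter().append("g")
      .attr("class", makeClusterIdentifier)   // was .attr("id", makeClusterIdentifier)
      .classed("cluster", true)               // resulting class attr: "cluster_stage_3 cluster"
      .append("rect");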
diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css
index 18c72694f3e2d..eedefb44b96fc 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css
@@ -44,6 +44,10 @@
stroke-width: 1px;
}
+#dag-viz-graph div#empty-dag-viz-message {
+ margin: 15px;
+}
+
/* Job page specific styles */
#dag-viz-graph svg.job marker#marker-arrow path {
@@ -57,7 +61,7 @@
stroke-width: 1px;
}
-#dag-viz-graph svg.job g.cluster[id*="stage"] rect {
+#dag-viz-graph svg.job g.cluster[class*="stage"] rect {
fill: #FFFFFF;
stroke: #FF99AC;
stroke-width: 1px;
@@ -79,7 +83,7 @@
stroke-width: 1px;
}
-#dag-viz-graph svg.stage g.cluster[id*="stage"] rect {
+#dag-viz-graph svg.stage g.cluster[class*="stage"] rect {
fill: #FFFFFF;
stroke: #FFA6B6;
stroke-width: 1px;
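For context on the two rules above (a hypothetical check, not part of the patch): [class*="stage"] is a substring match on the class attribute, so it still finds a cluster after dagre-d3 also tags it with the generic "cluster" class, while the old [id*="stage"] form no longer matches because the element has no id at all.

    // Element and class values are made up for illustration.
    var g = document.createElementNS("http://www.w3.org/2000/svg", "g");
    g.setAttribute("class", "cluster_stage_3 cluster");
    console.log(g.matches('g.cluster[class*="stage"]'));   // true  (new rule applies)
    console.log(g.matches('g.cluster[id*="stage"]'));      // false (old rule no longer applies)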
diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
index f7d0d3c61457c..8138eb0d4f390 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
@@ -86,7 +86,7 @@ function toggleDagViz(forJob) {
$(arrowSelector).toggleClass('arrow-open');
var shouldShow = $(arrowSelector).hasClass("arrow-open");
if (shouldShow) {
- var shouldRender = graphContainer().select("svg").empty();
+ var shouldRender = graphContainer().select("*").empty();
if (shouldRender) {
renderDagViz(forJob);
}
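A note on the select("*") change above, with an illustrative sketch: after this patch the container may end up holding either the rendered <svg> or the empty-message <div> added further down, so the guard must treat any existing child as "already rendered"; checking only for an <svg> would re-append the message on every toggle.

    // Sketch only; graphContainer() and renderDagViz() are the functions in this file.
    // Possible container contents after a render attempt:
    //   <div id="dag-viz-graph"><svg>...</svg></div>                            (graph rendered)
    //   <div id="dag-viz-graph"><div id="empty-dag-viz-message">...</div></div> (no metadata)
    var shouldRender = graphContainer().select("*").empty();   // neither child present yet
    if (shouldRender) { renderDagViz(forJob); }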
@@ -108,7 +108,7 @@ function toggleDagViz(forJob) {
* Output DOM hierarchy:
* div#dag-viz-graph >
* svg >
- * g#cluster_stage_[stageId]
+ * g.cluster_stage_[stageId]
*
* Note that the input metadata is populated by o.a.s.ui.UIUtils.showDagViz.
* Any changes in the input format here must be reflected there.
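An illustrative sketch of the hierarchy described in the comment above, under the new class-based scheme (the stage id is a made-up example; d3 v3 as bundled with this UI is assumed):

    var svg = d3.select("#dag-viz-graph").append("svg");
    svg.append("g").attr("class", "cluster_stage_3 cluster");   // previously g#cluster_stage_3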
@@ -117,10 +117,18 @@ function renderDagViz(forJob) {
// If there is not a dot file to render, fail fast and report error
var jobOrStage = forJob ? "job" : "stage";
- if (metadataContainer().empty()) {
- graphContainer()
- .append("div")
- .text("No visualization information available for this " + jobOrStage);
+ if (metadataContainer().empty() ||
+ metadataContainer().selectAll("div").empty()) {
+ var message =
+ "No visualization information available for this " + jobOrStage + "! " +
+ "If this is an old " + jobOrStage + ", its visualization metadata may have been " +
+ "cleaned up over time. You may consider increasing the value of ";
+ if (forJob) {
+ message += "spark.ui.retainedJobs and spark.ui.retainedStages .";
+ } else {
+ message += "spark.ui.retainedStages ";
+ }
+ graphContainer().append("div").attr("id", "empty-dag-viz-message").html(message);
return;
}
@@ -137,7 +145,7 @@ function renderDagViz(forJob) {
// Find cached RDDs and mark them as such
metadataContainer().selectAll(".cached-rdd").each(function(v) {
var nodeId = VizConstants.nodePrefix + d3.select(this).text();
- svg.selectAll("#" + nodeId).classed("cached", true);
+ svg.selectAll("g." + nodeId).classed("cached", true);
});
resizeSvg(svg);
@@ -192,14 +200,10 @@ function renderDagVizForJob(svgContainer) {
if (i > 0) {
var existingStages = svgContainer
.selectAll("g.cluster")
- .filter("[id*=\"" + VizConstants.stageClusterPrefix + "\"]");
+ .filter("[class*=\"" + VizConstants.stageClusterPrefix + "\"]");
if (!existingStages.empty()) {
var lastStage = d3.select(existingStages[0].pop());
- var lastStageId = lastStage.attr("id");
- var lastStageWidth = toFloat(svgContainer
- .select("#" + lastStageId)
- .select("rect")
- .attr("width"));
+ var lastStageWidth = toFloat(lastStage.select("rect").attr("width"));
var lastStagePosition = getAbsolutePosition(lastStage);
var offset = lastStagePosition.x + lastStageWidth + VizConstants.stageSep;
container.attr("transform", "translate(" + offset + ", 0)");
@@ -372,14 +376,14 @@ function getAbsolutePosition(d3selection) {
function connectRDDs(fromRDDId, toRDDId, edgesContainer, svgContainer) {
var fromNodeId = VizConstants.nodePrefix + fromRDDId;
var toNodeId = VizConstants.nodePrefix + toRDDId;
- var fromPos = getAbsolutePosition(svgContainer.select("#" + fromNodeId));
- var toPos = getAbsolutePosition(svgContainer.select("#" + toNodeId));
+ var fromPos = getAbsolutePosition(svgContainer.select("g." + fromNodeId));
+ var toPos = getAbsolutePosition(svgContainer.select("g." + toNodeId));
// On the job page, RDDs are rendered as dots (circles). When rendering the path,
// we need to account for the radii of these circles. Otherwise the arrow heads
// will bleed into the circle itself.
var delta = toFloat(svgContainer
- .select("g.node#" + toNodeId)
+ .select("g.node." + toNodeId)
.select("circle")
.attr("r"));
if (fromPos.x < toPos.x) {
@@ -431,10 +435,35 @@ function addTooltipsForRDDs(svgContainer) {
node.select("circle")
.attr("data-toggle", "tooltip")
.attr("data-placement", "bottom")
- .attr("title", tooltipText)
+ .attr("title", tooltipText);
}
+ // Link tooltips for all nodes that belong to the same RDD
+ node.on("mouseenter", function() { triggerTooltipForRDD(node, true); });
+ node.on("mouseleave", function() { triggerTooltipForRDD(node, false); });
});
- $("[data-toggle=tooltip]").tooltip({container: "body"});
+
+ $("[data-toggle=tooltip]")
+ .filter("g.node circle")
+ .tooltip({ container: "body", trigger: "manual" });
+}
+
+/*
+ * (Job page only) Helper function to show or hide tooltips for all nodes
+ * in the graph that refer to the same RDD the specified node represents.
+ */
+function triggerTooltipForRDD(d3node, show) {
+ var classes = d3node.node().classList;
+ for (var i = 0; i < classes.length; i++) {
+ var clazz = classes[i];
+ var isRDDClass = clazz.indexOf(VizConstants.nodePrefix) == 0;
+ if (isRDDClass) {
+ graphContainer().selectAll("g." + clazz).each(function() {
+ var circle = d3.select(this).select("circle").node();
+ var showOrHide = show ? "show" : "hide";
+ $(circle).tooltip(showOrHide);
+ });
+ }
+ }
}
/* Helper function to convert attributes to numeric values. */
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 0c4d28f786edd..a5d831c7e68ad 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -313,7 +313,8 @@ object SparkEnv extends Logging {
// Let the user specify short names for shuffle managers
val shortShuffleMgrNames = Map(
"hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
- "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
+ "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager",
+ "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager")
val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
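
Since the hunk above only registers a short name, the new shuffle path stays opt-in and is selected purely through configuration. A minimal Scala sketch of how an application might enable it (the app name is hypothetical; KryoSerializer is chosen because the unsafe path requires a serializer that supports relocation of serialized objects):

    import org.apache.spark.{SparkConf, SparkContext}

    // Sketch only: select the tungsten-sort shuffle manager registered above.
    val conf = new SparkConf()
      .setAppName("tungsten-sort-demo") // hypothetical application name
      .set("spark.shuffle.manager", "tungsten-sort")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val sc = new SparkContext(conf)

Passing the fully qualified class name org.apache.spark.shuffle.unsafe.UnsafeShuffleManager should also work, since names that are not in the short-name map fall through getOrElse and are treated as class names.
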
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 02a94baf372d9..f7fa37e4cdcdc 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1524,7 +1524,7 @@ abstract class RDD[T: ClassTag](
* doCheckpoint() is called recursively on the parent RDDs.
*/
private[spark] def doCheckpoint(): Unit = {
- RDDOperationScope.withScope(sc, "checkpoint", false, true) {
+ RDDOperationScope.withScope(sc, "checkpoint", allowNesting = false, ignoreParent = true) {
if (!doCheckpointCalled) {
doCheckpointCalled = true
if (checkpointData.isDefined) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala
index 66df1ebd4d5b0..b3dd4d757df3e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala
@@ -96,7 +96,7 @@ private[spark] object RDDOperationScope {
sc: SparkContext,
allowNesting: Boolean = false)(body: => T): T = {
val callerMethodName = Thread.currentThread.getStackTrace()(3).getMethodName
- withScope[T](sc, callerMethodName, allowNesting)(body)
+ withScope[T](sc, callerMethodName, allowNesting, ignoreParent = false)(body)
}
/**
@@ -116,7 +116,7 @@ private[spark] object RDDOperationScope {
sc: SparkContext,
name: String,
allowNesting: Boolean,
- ignoreParent: Boolean = false)(body: => T): T = {
+ ignoreParent: Boolean)(body: => T): T = {
// Save the old scope to restore it later
val scopeKey = SparkContext.RDD_SCOPE_KEY
val noOverrideKey = SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index dfbde7c8a1b0d..698d1384d580d 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -121,6 +121,8 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true)
+ protected def this() = this(new SparkConf()) // For deserialization only
+
override def newInstance(): SerializerInstance = {
val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index 6ad427bcac7f9..6c3b3080d2605 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -76,7 +76,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
private val consolidateShuffleFiles =
conf.getBoolean("spark.shuffle.consolidateFiles", false)
- // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
index f6e6fe5defe09..4cc4ef5f1886e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -17,14 +17,17 @@
package org.apache.spark.shuffle
+import java.io.IOException
+
import org.apache.spark.scheduler.MapStatus
/**
* Obtained inside a map task to write out records to the shuffle system.
*/
-private[spark] trait ShuffleWriter[K, V] {
+private[spark] abstract class ShuffleWriter[K, V] {
/** Write a sequence of records to this task's output */
- def write(records: Iterator[_ <: Product2[K, V]]): Unit
+ @throws[IOException]
+ def write(records: Iterator[Product2[K, V]]): Unit
/** Close this writer, passing along whether the map completed */
def stop(success: Boolean): Option[MapStatus]
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 897f0a5dc5bcc..eb87cee15903c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -49,7 +49,7 @@ private[spark] class HashShuffleWriter[K, V](
writeMetrics)
/** Write a bunch of records to this task's output */
- override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+ override def write(records: Iterator[Product2[K, V]]): Unit = {
val iter = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
dep.aggregator.get.combineValuesByKey(records, context)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 15842941daaab..d7fab351ca3b8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -72,7 +72,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
true
}
- override def shuffleBlockResolver: IndexShuffleBlockResolver = {
+ override val shuffleBlockResolver: IndexShuffleBlockResolver = {
indexShuffleBlockResolver
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index add2656294ca2..c9dd6bfc4c219 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -48,7 +48,7 @@ private[spark] class SortShuffleWriter[K, V, C](
context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics)
/** Write a bunch of records to this task's output */
- override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+ override def write(records: Iterator[Product2[K, V]]): Unit = {
if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
sorter = new ExternalSorter[K, V, C](
diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
new file mode 100644
index 0000000000000..f2bfef376d3ca
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import java.util.Collections
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark._
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.sort.SortShuffleManager
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle.
+ */
+private[spark] class UnsafeShuffleHandle[K, V](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, V])
+ extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+}
+
+private[spark] object UnsafeShuffleManager extends Logging {
+
+ /**
+ * The maximum number of shuffle output partitions that UnsafeShuffleManager supports.
+ */
+ val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
+
+ /**
+ * Helper method for determining whether a shuffle should use the optimized unsafe shuffle
+ * path or whether it should fall back to the original sort-based shuffle.
+ */
+ def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = {
+ val shufId = dependency.shuffleId
+ val serializer = Serializer.getSerializer(dependency.serializer)
+ if (!serializer.supportsRelocationOfSerializedObjects) {
+ log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " +
+ s"${serializer.getClass.getName}, does not support object relocation")
+ false
+ } else if (dependency.aggregator.isDefined) {
+ log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
+ false
+ } else if (dependency.keyOrdering.isDefined) {
+ log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined")
+ false
+ } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) {
+ log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " +
+ s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions")
+ false
+ } else {
+ log.debug(s"Can use UnsafeShuffle for shuffle $shufId")
+ true
+ }
+ }
+}
+
+/**
+ * A shuffle implementation that uses directly-managed memory to implement several performance
+ * optimizations for certain types of shuffles. In cases where the new performance optimizations
+ * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those
+ * shuffles.
+ *
+ * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold:
+ *
+ * - The shuffle dependency specifies no aggregation or output ordering.
+ * - The shuffle serializer supports relocation of serialized values (this is currently supported
+ * by KryoSerializer and Spark SQL's custom serializers).
+ * - The shuffle produces fewer than 16777216 output partitions.
+ * - No individual record is larger than 128 MB when serialized.
+ *
+ * In addition, extra spill-merging optimizations are automatically applied when the shuffle
+ * compression codec supports concatenation of serialized streams. This is currently supported by
+ * Spark's LZF compression codec.
+ *
+ * At a high level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager.
+ * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
+ * written to a single map output file. Reducers fetch contiguous regions of this file in order to
+ * read their portion of the map output. In cases where the map output data is too large to fit in
+ * memory, sorted subsets of the output can be spilled to disk and those on-disk files are merged
+ * to produce the final output file.
+ *
+ * UnsafeShuffleManager optimizes this process in several ways:
+ *
+ * - Its sort operates on serialized binary data rather than Java objects, which reduces memory
+ * consumption and GC overheads. This optimization requires the record serializer to have certain
+ * properties to allow serialized records to be re-ordered without requiring deserialization.
+ * See SPARK-4550, where this optimization was first proposed and implemented, for more details.
+ *
+ * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts
+ * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
+ * record in the sorting array, this fits more of the array into cache.
+ *
+ * - The spill merging procedure operates on blocks of serialized records that belong to the same
+ * partition and does not need to deserialize records during the merge.
+ *
+ * - When the spill compression codec supports concatenation of compressed data, the spill merge
+ * simply concatenates the serialized and compressed spill partitions to produce the final output
+ * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used
+ * and avoids the need to allocate decompression or copying buffers during the merge.
+ *
+ * For more details on UnsafeShuffleManager's design, see SPARK-7081.
+ */
+private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+
+ if (!conf.getBoolean("spark.shuffle.spill", true)) {
+ logWarning(
+ "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " +
+ "manager; its optimized shuffles will continue to spill to disk when necessary.")
+ }
+
+ private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+ private[this] val shufflesThatFellBackToSortShuffle =
+ Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
+ private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()
+
+ /**
+ * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+ */
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) {
+ new UnsafeShuffleHandle[K, V](
+ shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ } else {
+ new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ }
+ }
+
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext): ShuffleReader[K, C] = {
+ sortShuffleManager.getReader(handle, startPartition, endPartition, context)
+ }
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Int,
+ context: TaskContext): ShuffleWriter[K, V] = {
+ handle match {
+ case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] =>
+ numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
+ val env = SparkEnv.get
+ new UnsafeShuffleWriter(
+ env.blockManager,
+ shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+ context.taskMemoryManager(),
+ env.shuffleMemoryManager,
+ unsafeShuffleHandle,
+ mapId,
+ context,
+ env.conf)
+ case other =>
+ shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
+ sortShuffleManager.getWriter(handle, mapId, context)
+ }
+ }
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ override def unregisterShuffle(shuffleId: Int): Boolean = {
+ if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
+ sortShuffleManager.unregisterShuffle(shuffleId)
+ } else {
+ Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
+ (0 until numMaps).foreach { mapId =>
+ shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+ }
+ }
+ true
+ }
+ }
+
+ override val shuffleBlockResolver: IndexShuffleBlockResolver = {
+ sortShuffleManager.shuffleBlockResolver
+ }
+
+ /** Shut down this ShuffleManager. */
+ override def stop(): Unit = {
+ sortShuffleManager.stop()
+ }
+}
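
To make the eligibility rules in the scaladoc above concrete, here is a hedged sketch of two shuffles under this manager; the data and partition counts are illustrative only, and sc is assumed to be configured with the tungsten-sort manager and a relocation-friendly serializer such as KryoSerializer:

    // Sketch only: a hypothetical RDD of (Int, Int) key-value pairs.
    val pairs = sc.parallelize(1 to 100000).map(i => (i % 100, i))

    // No aggregator, no key ordering, far fewer than 2^24 partitions:
    // canUseUnsafeShuffle returns true and registerShuffle hands back an UnsafeShuffleHandle.
    val repartitioned = pairs.partitionBy(new org.apache.spark.HashPartitioner(8))

    // reduceByKey defines an aggregator on the shuffle dependency, so canUseUnsafeShuffle
    // returns false and this shuffle falls back to SortShuffleManager.
    val counts = pairs.reduceByKey(_ + _)
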
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 8bc4e205bc3c6..a33f22ef52687 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -86,16 +86,6 @@ private[spark] class DiskBlockObjectWriter(
extends BlockObjectWriter(blockId)
with Logging
{
- /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
- private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
- override def write(i: Int): Unit = callWithTiming(out.write(i))
- override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b))
- override def write(b: Array[Byte], off: Int, len: Int): Unit = {
- callWithTiming(out.write(b, off, len))
- }
- override def close(): Unit = out.close()
- override def flush(): Unit = out.flush()
- }
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
@@ -136,7 +126,7 @@ private[spark] class DiskBlockObjectWriter(
throw new IllegalStateException("Writer already closed. Cannot be reopened.")
}
fos = new FileOutputStream(file, true)
- ts = new TimeTrackingOutputStream(fos)
+ ts = new TimeTrackingOutputStream(writeMetrics, fos)
channel = fos.getChannel()
bs = compressStream(new BufferedOutputStream(ts, bufferSize))
objOut = serializerInstance.serializeStream(bs)
@@ -150,9 +140,9 @@ private[spark] class DiskBlockObjectWriter(
if (syncWrites) {
// Force outstanding writes to disk and track how long it takes
objOut.flush()
- callWithTiming {
- fos.getFD.sync()
- }
+ val start = System.nanoTime()
+ fos.getFD.sync()
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
}
} {
objOut.close()
@@ -251,12 +241,6 @@ private[spark] class DiskBlockObjectWriter(
reportedPosition = pos
}
- private def callWithTiming(f: => Unit) = {
- val start = System.nanoTime()
- f
- writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
- }
-
// For testing
private[spark] override def flush() {
objOut.flush()
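
The inner TimeTrackingOutputStream deleted above is replaced by a shared class that receives writeMetrics in its constructor, and the sync path now times itself inline. As a rough sketch of the underlying pattern (not the shared class itself), a wrapper that charges each write to a caller-supplied callback could look like this; recordNanos is an assumed stand-in for ShuffleWriteMetrics.incShuffleWriteTime:

    import java.io.OutputStream

    // Sketch only: report the wall-clock nanoseconds spent inside each write call.
    class TimedOutputStream(out: OutputStream, recordNanos: Long => Unit) extends OutputStream {
      private def timed(f: => Unit): Unit = {
        val start = System.nanoTime()
        f
        recordNanos(System.nanoTime() - start)
      }
      override def write(i: Int): Unit = timed(out.write(i))
      override def write(b: Array[Byte]): Unit = timed(out.write(b))
      override def write(b: Array[Byte], off: Int, len: Int): Unit = timed(out.write(b, off, len))
      override def flush(): Unit = out.flush()
      override def close(): Unit = out.close()
    }
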
diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala
index 2884a49f31122..3b77a1e12cc45 100644
--- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala
@@ -27,19 +27,29 @@ import org.apache.spark.ui.SparkUI
* A SparkListener that constructs a DAG of RDD operations.
*/
private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListener {
- private val jobIdToStageIds = new mutable.HashMap[Int, Seq[Int]]
- private val stageIdToGraph = new mutable.HashMap[Int, RDDOperationGraph]
- private val stageIds = new mutable.ArrayBuffer[Int]
+ private[ui] val jobIdToStageIds = new mutable.HashMap[Int, Seq[Int]]
+ private[ui] val stageIdToGraph = new mutable.HashMap[Int, RDDOperationGraph]
+
+ // Keep track of the order in which these are inserted so we can remove old ones
+ private[ui] val jobIds = new mutable.ArrayBuffer[Int]
+ private[ui] val stageIds = new mutable.ArrayBuffer[Int]
// How many jobs or stages to retain graph metadata for
+ private val retainedJobs =
+ conf.getInt("spark.ui.retainedJobs", SparkUI.DEFAULT_RETAINED_JOBS)
private val retainedStages =
conf.getInt("spark.ui.retainedStages", SparkUI.DEFAULT_RETAINED_STAGES)
/** Return the graph metadata for the given job, or an empty Seq if no such information exists. */
def getOperationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = {
- jobIdToStageIds.get(jobId)
- .map { sids => sids.flatMap { sid => stageIdToGraph.get(sid) } }
- .getOrElse { Seq.empty }
+ val stageIds = jobIdToStageIds.get(jobId).getOrElse { Seq.empty }
+ val graphs = stageIds.flatMap { sid => stageIdToGraph.get(sid) }
+ // If the metadata for some stages have been removed, do not bother rendering this job
+ if (stageIds.size != graphs.size) {
+ Seq.empty
+ } else {
+ graphs
+ }
}
/** Return the graph metadata for the given stage, or None if no such information exists. */
@@ -50,15 +60,22 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen
/** On job start, construct a RDDOperationGraph for each stage in the job for display later. */
override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized {
val jobId = jobStart.jobId
- val stageInfos = jobStart.stageInfos
+ jobIds += jobId
+ jobIdToStageIds(jobId) = jobStart.stageInfos.map(_.stageId).sorted
- stageInfos.foreach { stageInfo =>
- stageIds += stageInfo.stageId
- stageIdToGraph(stageInfo.stageId) = RDDOperationGraph.makeOperationGraph(stageInfo)
+ // Remove state for old jobs
+ if (jobIds.size >= retainedJobs) {
+ val toRemove = math.max(retainedJobs / 10, 1)
+ jobIds.take(toRemove).foreach { id => jobIdToStageIds.remove(id) }
+ jobIds.trimStart(toRemove)
}
- jobIdToStageIds(jobId) = stageInfos.map(_.stageId).sorted
+ }
- // Remove graph metadata for old stages
+ /** Remove graph metadata for old stages */
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized {
+ val stageInfo = stageSubmitted.stageInfo
+ stageIds += stageInfo.stageId
+ stageIdToGraph(stageInfo.stageId) = RDDOperationGraph.makeOperationGraph(stageInfo)
if (stageIds.size >= retainedStages) {
val toRemove = math.max(retainedStages / 10, 1)
stageIds.take(toRemove).foreach { id => stageIdToGraph.remove(id) }
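
The listener now bounds its bookkeeping with the same spark.ui.retainedJobs / spark.ui.retainedStages settings that the empty-DAG message in spark-dag-viz.js tells users to raise. A hedged example of raising them for an application whose older jobs keep losing their visualization metadata (the values are arbitrary):

    import org.apache.spark.SparkConf

    // Sketch only: retain DAG metadata for more jobs and stages than the defaults.
    val conf = new SparkConf()
      .set("spark.ui.retainedJobs", "2000")   // illustrative value
      .set("spark.ui.retainedStages", "2000") // illustrative value

When either limit is reached, roughly the oldest tenth of the tracked ids (math.max(retained / 10, 1)) is dropped in a single pass, so cleanup cost stays amortized rather than per-event.
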
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index b850973145077..df2d6ad3b41a4 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -90,7 +90,7 @@ class ExternalAppendOnlyMap[K, V, C](
// Number of bytes spilled in total
private var _diskBytesSpilled = 0L
- // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize =
sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 7d5cf7b61e56a..3b9d14f9372b6 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -110,7 +110,7 @@ private[spark] class ExternalSorter[K, V, C](
private val conf = SparkEnv.get.conf
private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
- // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
new file mode 100644
index 0000000000000..db9e82759090a
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*;
+
+public class PackedRecordPointerSuite {
+
+ @Test
+ public void heap() {
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock page0 = memoryManager.allocatePage(100);
+ final MemoryBlock page1 = memoryManager.allocatePage(100);
+ final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+ page1.getBaseOffset() + 42);
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+ assertEquals(360, packedPointer.getPartitionId());
+ final long recordPointer = packedPointer.getRecordPointer();
+ assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+ assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+ assertEquals(addressInPage1, recordPointer);
+ memoryManager.cleanUpAllAllocatedMemory();
+ }
+
+ @Test
+ public void offHeap() {
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+ final MemoryBlock page0 = memoryManager.allocatePage(100);
+ final MemoryBlock page1 = memoryManager.allocatePage(100);
+ final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+ page1.getBaseOffset() + 42);
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+ assertEquals(360, packedPointer.getPartitionId());
+ final long recordPointer = packedPointer.getRecordPointer();
+ assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+ assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+ assertEquals(addressInPage1, recordPointer);
+ memoryManager.cleanUpAllAllocatedMemory();
+ }
+
+ @Test
+ public void maximumPartitionIdCanBeEncoded() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID));
+ assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId());
+ }
+
+ @Test
+ public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ try {
+ // Partition ids greater than the maximum partition id will overflow or trigger an assertion error
+ packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
+ assertFalse(MAXIMUM_PARTITION_ID + 1 == packedPointer.getPartitionId());
+ } catch (AssertionError e) {
+ // pass
+ }
+ }
+
+ @Test
+ public void maximumOffsetInPageCanBeEncoded() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1);
+ packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+ assertEquals(address, packedPointer.getRecordPointer());
+ }
+
+ @Test
+ public void offsetsPastMaxOffsetInPageWillOverflow() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES);
+ packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+ assertEquals(0, packedPointer.getRecordPointer());
+ }
+}
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
new file mode 100644
index 0000000000000..8fa72597db24d
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.util.Arrays;
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class UnsafeShuffleInMemorySorterSuite {
+
+ private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
+ final byte[] strBytes = new byte[strLength];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset,
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET, strLength);
+ return new String(strBytes);
+ }
+
+ @Test
+ public void testSortingEmptyInput() {
+ final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100);
+ final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+ assert(!iter.hasNext());
+ }
+
+ @Test
+ public void testBasicSorting() throws Exception {
+ final String[] dataToSort = new String[] {
+ "Boba",
+ "Pearls",
+ "Tapioca",
+ "Taho",
+ "Condensed Milk",
+ "Jasmine",
+ "Milk Tea",
+ "Lychee",
+ "Mango"
+ };
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+ final Object baseObject = dataPage.getBaseObject();
+ final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+ final HashPartitioner hashPartitioner = new HashPartitioner(4);
+
+ // Write the records into the data page and store pointers into the sorter
+ long position = dataPage.getBaseOffset();
+ for (String str : dataToSort) {
+ final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
+ final byte[] strBytes = str.getBytes("utf-8");
+ PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length);
+ position += 4;
+ PlatformDependent.copyMemory(
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ baseObject,
+ position,
+ strBytes.length);
+ position += strBytes.length;
+ sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
+ }
+
+ // Sort the records
+ final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+ int prevPartitionId = -1;
+ Arrays.sort(dataToSort);
+ for (int i = 0; i < dataToSort.length; i++) {
+ Assert.assertTrue(iter.hasNext());
+ iter.loadNext();
+ final int partitionId = iter.packedRecordPointer.getPartitionId();
+ Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
+ Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId,
+ partitionId >= prevPartitionId);
+ final long recordAddress = iter.packedRecordPointer.getRecordPointer();
+ final int recordLength = PlatformDependent.UNSAFE.getInt(
+ memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
+ final String str = getStringFromDataPage(
+ memoryManager.getPage(recordAddress),
+ memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length
+ recordLength);
+ Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1);
+ }
+ Assert.assertFalse(iter.hasNext());
+ }
+
+ @Test
+ public void testSortingManyNumbers() throws Exception {
+ UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+ int[] numbersToSort = new int[128000];
+ Random random = new Random(16);
+ for (int i = 0; i < numbersToSort.length; i++) {
+ numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
+ sorter.insertRecord(0, numbersToSort[i]);
+ }
+ Arrays.sort(numbersToSort);
+ int[] sorterResult = new int[numbersToSort.length];
+ UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+ int j = 0;
+ while (iter.hasNext()) {
+ iter.loadNext();
+ sorterResult[j] = iter.packedRecordPointer.getPartitionId();
+ j += 1;
+ }
+ Assert.assertArrayEquals(numbersToSort, sorterResult);
+ }
+}
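
The basic-sorting test above stores each record as a 4-byte length followed by its bytes and then sorts only the packed (partition id, pointer) words. A self-contained sketch of that length-prefixed layout, using a plain ByteBuffer instead of Spark's TaskMemoryManager purely for illustration:

    import java.nio.ByteBuffer
    import java.nio.charset.StandardCharsets

    // Sketch only: write length-prefixed records into a flat page and remember their offsets.
    val page = ByteBuffer.allocate(2048)
    val offsets = Seq("Boba", "Pearls", "Tapioca").map { s =>
      val offset = page.position()
      val bytes = s.getBytes(StandardCharsets.UTF_8)
      page.putInt(bytes.length) // record length header
      page.put(bytes)           // record payload
      offset                    // stands in for the encoded record pointer
    }

    // Read one record back from its offset: length header first, then the payload.
    val len = page.getInt(offsets.head)
    val payload = new Array[Byte](len)
    (0 until len).foreach(i => payload(i) = page.get(offsets.head + 4 + i))
    assert(new String(payload, StandardCharsets.UTF_8) == "Boba")
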
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
new file mode 100644
index 0000000000000..730d265c87f88
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.*;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import scala.*;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+import scala.runtime.AbstractFunction1;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.io.ByteStreams;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.*;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZ4CompressionCodec;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.io.SnappyCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.serializer.*;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeShuffleWriterSuite {
+
+ static final int NUM_PARTITITONS = 4;
+ final TaskMemoryManager taskMemoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
+ File mergedOutputFile;
+ File tempDir;
+ long[] partitionSizesInMergedFile;
+ final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+ SparkConf conf;
+ final Serializer serializer = new KryoSerializer(new SparkConf());
+ TaskMetrics taskMetrics;
+
+ @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+ @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
+ @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+ @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
+
+ private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+ @Override
+ public OutputStream apply(OutputStream stream) {
+ if (conf.getBoolean("spark.shuffle.compress", true)) {
+ return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
+ } else {
+ return stream;
+ }
+ }
+ }
+
+ @After
+ public void tearDown() {
+ Utils.deleteRecursively(tempDir);
+ final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
+ if (leakedMemory != 0) {
+ fail("Test leaked " + leakedMemory + " bytes of managed memory");
+ }
+ }
+
+ @Before
+ @SuppressWarnings("unchecked")
+ public void setUp() throws IOException {
+ MockitoAnnotations.initMocks(this);
+ tempDir = Utils.createTempDir("test", "test");
+ mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
+ partitionSizesInMergedFile = null;
+ spillFilesCreated.clear();
+ conf = new SparkConf();
+ taskMetrics = new TaskMetrics();
+
+ when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+
+ when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+ when(blockManager.getDiskWriter(
+ any(BlockId.class),
+ any(File.class),
+ any(SerializerInstance.class),
+ anyInt(),
+ any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+ @Override
+ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+ Object[] args = invocationOnMock.getArguments();
+
+ return new DiskBlockObjectWriter(
+ (BlockId) args[0],
+ (File) args[1],
+ (SerializerInstance) args[2],
+ (Integer) args[3],
+ new CompressStream(),
+ false,
+ (ShuffleWriteMetrics) args[4]
+ );
+ }
+ });
+ when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
+ new Answer<InputStream>() {
+ @Override
+ public InputStream answer(InvocationOnMock invocation) throws Throwable {
+ assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+ InputStream is = (InputStream) invocation.getArguments()[1];
+ if (conf.getBoolean("spark.shuffle.compress", true)) {
+ return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
+ } else {
+ return is;
+ }
+ }
+ }
+ );
+
+ when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
+ new Answer<OutputStream>() {
+ @Override
+ public OutputStream answer(InvocationOnMock invocation) throws Throwable {
+ assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+ OutputStream os = (OutputStream) invocation.getArguments()[1];
+ if (conf.getBoolean("spark.shuffle.compress", true)) {
+ return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
+ } else {
+ return os;
+ }
+ }
+ }
+ );
+
+ when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
+ doAnswer(new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+ partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+ return null;
+ }
+ }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+
+ when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+ new Answer<Tuple2<TempShuffleBlockId, File>>() {
+ @Override
+ public Tuple2<TempShuffleBlockId, File> answer(
+ InvocationOnMock invocationOnMock) throws Throwable {
+ TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
+ File file = File.createTempFile("spillFile", ".spill", tempDir);
+ spillFilesCreated.add(file);
+ return Tuple2$.MODULE$.apply(blockId, file);
+ }
+ });
+
+ when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+
+ when(shuffleDep.serializer()).thenReturn(Option.apply(serializer));
+ when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+ }
+
+ private UnsafeShuffleWriter createWriter(
+ boolean transferToEnabled) throws IOException {
+ conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
+ return new UnsafeShuffleWriter(
+ blockManager,
+ shuffleBlockResolver,
+ taskMemoryManager,
+ shuffleMemoryManager,
+ new UnsafeShuffleHandle(0, 1, shuffleDep),
+ 0, // map id
+ taskContext,
+ conf
+ );
+ }
+
+ private void assertSpillFilesWereCleanedUp() {
+ for (File spillFile : spillFilesCreated) {
+ assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+ spillFile.exists());
+ }
+ }
+
+ private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
+ final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
+ long startOffset = 0;
+ for (int i = 0; i < NUM_PARTITITONS; i++) {
+ final long partitionSize = partitionSizesInMergedFile[i];
+ if (partitionSize > 0) {
+ InputStream in = new FileInputStream(mergedOutputFile);
+ ByteStreams.skipFully(in, startOffset);
+ in = new LimitedInputStream(in, partitionSize);
+ if (conf.getBoolean("spark.shuffle.compress", true)) {
+ in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
+ }
+ DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
+ Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
+ while (records.hasNext()) {
+ Tuple2<Object, Object> record = records.next();
+ assertEquals(i, hashPartitioner.getPartition(record._1()));
+ recordsList.add(record);
+ }
+ recordsStream.close();
+ startOffset += partitionSize;
+ }
+ }
+ return recordsList;
+ }
+
+ @Test(expected=IllegalStateException.class)
+ public void mustCallWriteBeforeSuccessfulStop() throws IOException {
+ createWriter(false).stop(true);
+ }
+
+ @Test
+ public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
+ createWriter(false).stop(false);
+ }
+
+ @Test
+ public void writeEmptyIterator() throws Exception {
+ final UnsafeShuffleWriter writer = createWriter(true);
+ writer.write(Collections.<Product2<Object, Object>>emptyIterator());
+ final Option mapStatus = writer.stop(true);
+ assertTrue(mapStatus.isDefined());
+ assertTrue(mergedOutputFile.exists());
+ assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile);
+ assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten());
+ assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten());
+ assertEquals(0, taskMetrics.diskBytesSpilled());
+ assertEquals(0, taskMetrics.memoryBytesSpilled());
+ }
+
+ @Test
+ public void writeWithoutSpilling() throws Exception {
+ // In this example, each partition should have exactly one record:
+ final ArrayList<Product2<Object, Object>> dataToWrite =
+ new ArrayList<Product2<Object, Object>>();
+ for (int i = 0; i < NUM_PARTITITONS; i++) {
+ dataToWrite.add(new Tuple2<Object, Object>(i, i));
+ }
+ final UnsafeShuffleWriter writer = createWriter(true);
+ writer.write(dataToWrite.iterator());
+ final Option mapStatus = writer.stop(true);
+ assertTrue(mapStatus.isDefined());
+ assertTrue(mergedOutputFile.exists());
+
+ long sumOfPartitionSizes = 0;
+ for (long size: partitionSizesInMergedFile) {
+ // All partitions should be the same size:
+ assertEquals(partitionSizesInMergedFile[0], size);
+ sumOfPartitionSizes += size;
+ }
+ assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+ assertEquals(
+ HashMultiset.create(dataToWrite),
+ HashMultiset.create(readRecordsFromFile()));
+ assertSpillFilesWereCleanedUp();
+ ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+ assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+ assertEquals(0, taskMetrics.diskBytesSpilled());
+ assertEquals(0, taskMetrics.memoryBytesSpilled());
+ assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+ }
+
+ private void testMergingSpills(
+ boolean transferToEnabled,
+ String compressionCodecName) throws IOException {
+ if (compressionCodecName != null) {
+ conf.set("spark.shuffle.compress", "true");
+ conf.set("spark.io.compression.codec", compressionCodecName);
+ } else {
+ conf.set("spark.shuffle.compress", "false");
+ }
+ final UnsafeShuffleWriter writer = createWriter(transferToEnabled);
+ final ArrayList<Product2<Object, Object>> dataToWrite =
+ new ArrayList<Product2<Object, Object>>();
+ for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
+ dataToWrite.add(new Tuple2<Object, Object>(i, i));
+ }
+ writer.insertRecordIntoSorter(dataToWrite.get(0));
+ writer.insertRecordIntoSorter(dataToWrite.get(1));
+ writer.insertRecordIntoSorter(dataToWrite.get(2));
+ writer.insertRecordIntoSorter(dataToWrite.get(3));
+ writer.forceSorterToSpill();
+ writer.insertRecordIntoSorter(dataToWrite.get(4));
+ writer.insertRecordIntoSorter(dataToWrite.get(5));
+ writer.closeAndWriteOutput();
+ final Option mapStatus = writer.stop(true);
+ assertTrue(mapStatus.isDefined());
+ assertTrue(mergedOutputFile.exists());
+ assertEquals(2, spillFilesCreated.size());
+
+ long sumOfPartitionSizes = 0;
+ for (long size: partitionSizesInMergedFile) {
+ sumOfPartitionSizes += size;
+ }
+ assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
+
+ assertEquals(
+ HashMultiset.create(dataToWrite),
+ HashMultiset.create(readRecordsFromFile()));
+ assertSpillFilesWereCleanedUp();
+ ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+ assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+ assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+ assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+ assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+ assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+ }
+
+ @Test
+ public void mergeSpillsWithTransferToAndLZF() throws Exception {
+ testMergingSpills(true, LZFCompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithFileStreamAndLZF() throws Exception {
+ testMergingSpills(false, LZFCompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithTransferToAndLZ4() throws Exception {
+ testMergingSpills(true, LZ4CompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
+ testMergingSpills(false, LZ4CompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithTransferToAndSnappy() throws Exception {
+ testMergingSpills(true, SnappyCompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
+ testMergingSpills(false, SnappyCompressionCodec.class.getName());
+ }
+
+ @Test
+ public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
+ testMergingSpills(true, null);
+ }
+
+ @Test
+ public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
+ testMergingSpills(false, null);
+ }
+
+ @Test
+ public void writeEnoughDataToTriggerSpill() throws Exception {
+ when(shuffleMemoryManager.tryToAcquire(anyLong()))
+ .then(returnsFirstArg()) // Allocate initial sort buffer
+ .then(returnsFirstArg()) // Allocate initial data page
+ .thenReturn(0L) // Deny request to allocate new data page
+ .then(returnsFirstArg()); // Grant new sort buffer and data page.
+ final UnsafeShuffleWriter writer = createWriter(false);
+ final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+ final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
+ for (int i = 0; i < 128 + 1; i++) {
+ dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
+ }
+ writer.write(dataToWrite.iterator());
+ verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+ assertEquals(2, spillFilesCreated.size());
+ writer.stop(true);
+ readRecordsFromFile();
+ assertSpillFilesWereCleanedUp();
+ ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+ assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+ assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+ assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+ assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+ assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+ }
+
+ @Test
+ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
+ when(shuffleMemoryManager.tryToAcquire(anyLong()))
+ .then(returnsFirstArg()) // Allocate initial sort buffer
+ .then(returnsFirstArg()) // Allocate initial data page
+ .thenReturn(0L) // Deny request to grow sort buffer
+ .then(returnsFirstArg()); // Grant new sort buffer and data page.
+ final UnsafeShuffleWriter writer = createWriter(false);
+ final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+ for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
+ dataToWrite.add(new Tuple2<Object, Object>(i, i));
+ }
+ writer.write(dataToWrite.iterator());
+ verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+ assertEquals(2, spillFilesCreated.size());
+ writer.stop(true);
+ readRecordsFromFile();
+ assertSpillFilesWereCleanedUp();
+ ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+ assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+ assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+ assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+ assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+ assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+ }
+
+ @Test
+ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
+ final UnsafeShuffleWriter writer = createWriter(false);
+ final ArrayList<Product2<Object, Object>> dataToWrite =
+ new ArrayList<Product2<Object, Object>>();
+ final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
+ new Random(42).nextBytes(bytes);
+ dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
+ writer.write(dataToWrite.iterator());
+ writer.stop(true);
+ assertEquals(
+ HashMultiset.create(dataToWrite),
+ HashMultiset.create(readRecordsFromFile()));
+ assertSpillFilesWereCleanedUp();
+ }
+
+ @Test
+ public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
+ // Use a custom serializer so that we have exact control over the size of serialized data.
+ final Serializer byteArraySerializer = new Serializer() {
+ @Override
+ public SerializerInstance newInstance() {
+ return new SerializerInstance() {
+ @Override
+ public SerializationStream serializeStream(final OutputStream s) {
+ return new SerializationStream() {
+ @Override
+ public void flush() { }
+
+ @Override
+ public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+ byte[] bytes = (byte[]) t;
+ try {
+ s.write(bytes);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ return this;
+ }
+
+ @Override
+ public void close() { }
+ };
+ }
+ public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) { return null; }
+ public DeserializationStream deserializeStream(InputStream s) { return null; }
+ public <T> T deserialize(ByteBuffer b, ClassLoader l, ClassTag<T> ev1) { return null; }
+ public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) { return null; }
+ };
+ }
+ };
+ when(shuffleDep.serializer()).thenReturn(Option.apply(byteArraySerializer));
+ final UnsafeShuffleWriter writer = createWriter(false);
+ // Insert a record and force a spill so that there's something to clean up:
+ writer.insertRecordIntoSorter(new Tuple2(new byte[1], new byte[1]));
+ writer.forceSorterToSpill();
+ // We should be able to write a record that's right _at_ the max record size
+ final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE];
+ new Random(42).nextBytes(atMaxRecordSize);
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[0], atMaxRecordSize));
+ writer.forceSorterToSpill();
+ // Inserting a record that's larger than the max record size should fail:
+ final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1];
+ new Random(42).nextBytes(exceedsMaxRecordSize);
+ final Product2<Object, Object> hugeRecord =
+ new Tuple2<Object, Object>(new byte[0], exceedsMaxRecordSize);
+ try {
+ // Here, we write through the public `write()` interface instead of the test-only
+ // `insertRecordIntoSorter` interface:
+ writer.write(Collections.singletonList(hugeRecord).iterator());
+ fail("Expected exception to be thrown");
+ } catch (IOException e) {
+ // Pass
+ }
+ assertSpillFilesWereCleanedUp();
+ }
+
+ @Test
+ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+ writer.forceSorterToSpill();
+ writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+ writer.stop(false);
+ assertSpillFilesWereCleanedUp();
+ }
+}
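
For reference, a minimal standalone Scala sketch of the acquire-or-spill behaviour that the writeEnoughRecordsToTriggerSortBufferExpansionAndSpill test above drives through its Mockito stubs: buffer growth is requested from a memory manager, and a denied request forces a spill. The MemoryGrants trait and Sorter class are illustrative stand-ins, not Spark classes.

object AcquireOrSpillSketch {

  // Hypothetical stand-in for ShuffleMemoryManager.tryToAcquire: returns the granted bytes (0 = denied).
  trait MemoryGrants { def tryToAcquire(bytes: Long): Long }

  final class Sorter(memory: MemoryGrants, initialCapacity: Int) {
    private var capacity = initialCapacity
    private var used = 0
    var spills = 0

    def insert(record: Int): Unit = {
      if (used == capacity) {
        val wanted = capacity.toLong      // ask for enough to double the buffer
        if (memory.tryToAcquire(wanted) == wanted) {
          capacity += wanted.toInt        // growth granted: keep everything in memory
        } else {
          spills += 1                     // growth denied: spill and start a fresh buffer
          used = 0
        }
      }
      used += 1
    }
  }

  def main(args: Array[String]): Unit = {
    // Deny the first growth request, grant the rest, mirroring the stubbed answers in the test.
    var calls = 0
    val grants = new MemoryGrants {
      def tryToAcquire(bytes: Long): Long = { calls += 1; if (calls == 1) 0L else bytes }
    }
    val sorter = new Sorter(grants, initialCapacity = 4)
    (1 to 16).foreach(sorter.insert)
    println(s"spills = ${sorter.spills}") // 1 spill, caused by the denied request
  }
}
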
diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
index 8c6035fb367fe..cf6a143537889 100644
--- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.io
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import com.google.common.io.ByteStreams
import org.scalatest.FunSuite
import org.apache.spark.SparkConf
@@ -62,6 +63,14 @@ class CompressionCodecSuite extends FunSuite {
testCodec(codec)
}
+ test("lz4 does not support concatenation of serialized streams") {
+ val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName)
+ assert(codec.getClass === classOf[LZ4CompressionCodec])
+ intercept[Exception] {
+ testConcatenationOfSerializedStreams(codec)
+ }
+ }
+
test("lzf compression codec") {
val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName)
assert(codec.getClass === classOf[LZFCompressionCodec])
@@ -74,6 +83,12 @@ class CompressionCodecSuite extends FunSuite {
testCodec(codec)
}
+ test("lzf supports concatenation of serialized streams") {
+ val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName)
+ assert(codec.getClass === classOf[LZFCompressionCodec])
+ testConcatenationOfSerializedStreams(codec)
+ }
+
test("snappy compression codec") {
val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName)
assert(codec.getClass === classOf[SnappyCompressionCodec])
@@ -86,9 +101,38 @@ class CompressionCodecSuite extends FunSuite {
testCodec(codec)
}
+ test("snappy does not support concatenation of serialized streams") {
+ val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName)
+ assert(codec.getClass === classOf[SnappyCompressionCodec])
+ intercept[Exception] {
+ testConcatenationOfSerializedStreams(codec)
+ }
+ }
+
test("bad compression codec") {
intercept[IllegalArgumentException] {
CompressionCodec.createCodec(conf, "foobar")
}
}
+
+ private def testConcatenationOfSerializedStreams(codec: CompressionCodec): Unit = {
+ val bytes1: Array[Byte] = {
+ val baos = new ByteArrayOutputStream()
+ val out = codec.compressedOutputStream(baos)
+ (0 to 64).foreach(out.write)
+ out.close()
+ baos.toByteArray
+ }
+ val bytes2: Array[Byte] = {
+ val baos = new ByteArrayOutputStream()
+ val out = codec.compressedOutputStream(baos)
+ (65 to 127).foreach(out.write)
+ out.close()
+ baos.toByteArray
+ }
+ val concatenatedBytes = codec.compressedInputStream(new ByteArrayInputStream(bytes1 ++ bytes2))
+ val decompressed: Array[Byte] = new Array[Byte](128)
+ ByteStreams.readFully(concatenatedBytes, decompressed)
+ assert(decompressed.toSeq === (0 to 127))
+ }
}
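
The three new codec tests above pin down the property the unsafe shuffle's fast merge path relies on: per-partition compressed segments can be concatenated byte-for-byte and still decompress as one stream (true for LZF here, not for LZ4 or Snappy). A hedged Scala sketch of that merge idea, using Guava's ByteStreams; the file and segment-length layout is hypothetical, not Spark's actual spill format.

import java.io.{File, FileInputStream, FileOutputStream}
import com.google.common.io.ByteStreams

object ConcatenatingMergeSketch {
  /** segmentLengths(i)(p) = byte length of partition p's compressed segment inside spill file i. */
  def mergeSpills(
      spills: Seq[File],
      segmentLengths: Seq[Array[Long]],
      numPartitions: Int,
      mergedOutput: File): Unit = {
    val out = new FileOutputStream(mergedOutput)
    try {
      for (p <- 0 until numPartitions; (spill, lengths) <- spills.zip(segmentLengths)) {
        val in = new FileInputStream(spill)
        try {
          ByteStreams.skipFully(in, lengths.take(p).sum)            // skip earlier partitions' segments
          ByteStreams.copy(ByteStreams.limit(in, lengths(p)), out)  // copy this partition's bytes verbatim
        } finally {
          in.close()
        }
      }
    } finally {
      out.close()
    }
  }
}

This only produces valid output when the compression codec tolerates concatenated streams, which is exactly what the tests above document per codec.
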
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala
index d75ecbf1f0b4d..db465a6a9eb55 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala
@@ -61,11 +61,11 @@ class RDDOperationScopeSuite extends FunSuite with BeforeAndAfter {
var rdd1: MyCoolRDD = null
var rdd2: MyCoolRDD = null
var rdd3: MyCoolRDD = null
- RDDOperationScope.withScope(sc, "scope1", allowNesting = false) {
+ RDDOperationScope.withScope(sc, "scope1", allowNesting = false, ignoreParent = false) {
rdd1 = new MyCoolRDD(sc)
- RDDOperationScope.withScope(sc, "scope2", allowNesting = false) {
+ RDDOperationScope.withScope(sc, "scope2", allowNesting = false, ignoreParent = false) {
rdd2 = new MyCoolRDD(sc)
- RDDOperationScope.withScope(sc, "scope3", allowNesting = false) {
+ RDDOperationScope.withScope(sc, "scope3", allowNesting = false, ignoreParent = false) {
rdd3 = new MyCoolRDD(sc)
}
}
@@ -84,11 +84,13 @@ class RDDOperationScopeSuite extends FunSuite with BeforeAndAfter {
var rdd1: MyCoolRDD = null
var rdd2: MyCoolRDD = null
var rdd3: MyCoolRDD = null
- RDDOperationScope.withScope(sc, "scope1", allowNesting = true) { // allow nesting here
+ // allow nesting here
+ RDDOperationScope.withScope(sc, "scope1", allowNesting = true, ignoreParent = false) {
rdd1 = new MyCoolRDD(sc)
- RDDOperationScope.withScope(sc, "scope2", allowNesting = false) { // stop nesting here
+ // stop nesting here
+ RDDOperationScope.withScope(sc, "scope2", allowNesting = false, ignoreParent = false) {
rdd2 = new MyCoolRDD(sc)
- RDDOperationScope.withScope(sc, "scope3", allowNesting = false) {
+ RDDOperationScope.withScope(sc, "scope3", allowNesting = false, ignoreParent = false) {
rdd3 = new MyCoolRDD(sc)
}
}
@@ -107,11 +109,11 @@ class RDDOperationScopeSuite extends FunSuite with BeforeAndAfter {
var rdd1: MyCoolRDD = null
var rdd2: MyCoolRDD = null
var rdd3: MyCoolRDD = null
- RDDOperationScope.withScope(sc, "scope1", allowNesting = true) {
+ RDDOperationScope.withScope(sc, "scope1", allowNesting = true, ignoreParent = false) {
rdd1 = new MyCoolRDD(sc)
- RDDOperationScope.withScope(sc, "scope2", allowNesting = true) {
+ RDDOperationScope.withScope(sc, "scope2", allowNesting = true, ignoreParent = false) {
rdd2 = new MyCoolRDD(sc)
- RDDOperationScope.withScope(sc, "scope3", allowNesting = true) {
+ RDDOperationScope.withScope(sc, "scope3", allowNesting = true, ignoreParent = false) {
rdd3 = new MyCoolRDD(sc)
}
}
diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
new file mode 100644
index 0000000000000..ed4d8ce632e16
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.apache.spark.SparkConf
+import org.scalatest.FunSuite
+
+class JavaSerializerSuite extends FunSuite {
+ test("JavaSerializer instances are serializable") {
+ val serializer = new JavaSerializer(new SparkConf())
+ val instance = serializer.newInstance()
+ instance.deserialize[JavaSerializer](instance.serialize(serializer))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
new file mode 100644
index 0000000000000..49a04a2a45280
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{FunSuite, Matchers}
+
+import org.apache.spark._
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
+
+/**
+ * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are
+ * performed in other suites.
+ */
+class UnsafeShuffleManagerSuite extends FunSuite with Matchers {
+
+ import UnsafeShuffleManager.canUseUnsafeShuffle
+
+ private class RuntimeExceptionAnswer extends Answer[Object] {
+ override def answer(invocation: InvocationOnMock): Object = {
+ throw new RuntimeException("Called non-stubbed method, " + invocation.getMethod.getName)
+ }
+ }
+
+ private def shuffleDep(
+ partitioner: Partitioner,
+ serializer: Option[Serializer],
+ keyOrdering: Option[Ordering[Any]],
+ aggregator: Option[Aggregator[Any, Any, Any]],
+ mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = {
+ val dep = mock(classOf[ShuffleDependency[Any, Any, Any]], new RuntimeExceptionAnswer())
+ doReturn(0).when(dep).shuffleId
+ doReturn(partitioner).when(dep).partitioner
+ doReturn(serializer).when(dep).serializer
+ doReturn(keyOrdering).when(dep).keyOrdering
+ doReturn(aggregator).when(dep).aggregator
+ doReturn(mapSideCombine).when(dep).mapSideCombine
+ dep
+ }
+
+ test("supported shuffle dependencies") {
+ val kryo = Some(new KryoSerializer(new SparkConf()))
+
+ assert(canUseUnsafeShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]])
+ when(rangePartitioner.numPartitions).thenReturn(2)
+ assert(canUseUnsafeShuffle(shuffleDep(
+ partitioner = rangePartitioner,
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ }
+
+ test("unsupported shuffle dependencies") {
+ val kryo = Some(new KryoSerializer(new SparkConf()))
+ val java = Some(new JavaSerializer(new SparkConf()))
+
+ // We only support serializers that support object relocation
+ assert(!canUseUnsafeShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = java,
+ keyOrdering = None,
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ // We do not support shuffles with more than 16 million output partitions
+ assert(!canUseUnsafeShuffle(shuffleDep(
+ partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1),
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = None,
+ mapSideCombine = false
+ )))
+
+ // We do not support shuffles that perform any kind of aggregation or sorting of keys
+ assert(!canUseUnsafeShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = Some(mock(classOf[Ordering[Any]])),
+ aggregator = None,
+ mapSideCombine = false
+ )))
+ assert(!canUseUnsafeShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+ mapSideCombine = false
+ )))
+ // We do not support shuffles that perform any kind of aggregation or sorting of keys
+ assert(!canUseUnsafeShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = Some(mock(classOf[Ordering[Any]])),
+ aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+ mapSideCombine = true
+ )))
+ }
+
+}
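
As a rough summary of what the two tests above exercise, here is a hedged, self-contained Scala sketch of such a fallback predicate. The Shuffle case class and its fields are illustrative stand-ins for ShuffleDependency, not Spark's real types.

object CanUseUnsafeShuffleSketch {
  // Partition ids must fit in the 24-bit field of the packed record pointer.
  val MaxPartitions: Int = 1 << 24

  final case class Shuffle(
      numPartitions: Int,
      serializerSupportsRelocation: Boolean,
      hasKeyOrdering: Boolean,
      hasAggregator: Boolean,
      mapSideCombine: Boolean)

  def canUseUnsafeShuffle(s: Shuffle): Boolean =
    s.serializerSupportsRelocation &&   // e.g. Kryo, but not Java serialization
      !s.hasKeyOrdering &&              // no sorting of keys
      !s.hasAggregator &&               // no aggregation
      !s.mapSideCombine &&
      s.numPartitions <= MaxPartitions

  def main(args: Array[String]): Unit = {
    val ok = Shuffle(2, serializerSupportsRelocation = true,
      hasKeyOrdering = false, hasAggregator = false, mapSideCombine = false)
    println(canUseUnsafeShuffle(ok))                                          // true
    println(canUseUnsafeShuffle(ok.copy(serializerSupportsRelocation = false))) // false: fall back
  }
}
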
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
new file mode 100644
index 0000000000000..6351539e91e97
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
+class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
+
+ // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
+
+ override def beforeAll() {
+ conf.set("spark.shuffle.manager", "tungsten-sort")
+ // UnsafeShuffleManager requires at least 128 MB of memory per task in order to be able to sort
+ // shuffle records.
+ conf.set("spark.shuffle.memoryFraction", "0.5")
+ }
+
+ test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
+ val tmpDir = Utils.createTempDir()
+ try {
+ val myConf = conf.clone()
+ .set("spark.local.dir", tmpDir.getAbsolutePath)
+ sc = new SparkContext("local", "test", myConf)
+ // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
+ val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+ val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+ .setSerializer(new KryoSerializer(myConf))
+ val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+ assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+ def getAllFiles: Set[File] =
+ FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+ val filesBeforeShuffle = getAllFiles
+ // Force the shuffle to be performed
+ shuffledRdd.count()
+ // Ensure that the shuffle actually created files that will need to be cleaned up
+ val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+ filesCreatedByShuffle.map(_.getName) should be (Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+ // Check that the cleanup actually removes the files
+ sc.env.blockManager.master.removeShuffle(0, blocking = true)
+ for (file <- filesCreatedByShuffle) {
+ assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+ }
+ } finally {
+ Utils.deleteRecursively(tmpDir)
+ }
+ }
+
+ test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
+ val tmpDir = Utils.createTempDir()
+ try {
+ val myConf = conf.clone()
+ .set("spark.local.dir", tmpDir.getAbsolutePath)
+ sc = new SparkContext("local", "test", myConf)
+ // Create a shuffled RDD and verify that it will actually use the old SortShuffle path
+ val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+ val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+ .setSerializer(new JavaSerializer(myConf))
+ val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+ assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+ def getAllFiles: Set[File] =
+ FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+ val filesBeforeShuffle = getAllFiles
+ // Force the shuffle to be performed
+ shuffledRdd.count()
+ // Ensure that the shuffle actually created files that will need to be cleaned up
+ val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+ filesCreatedByShuffle.map(_.getName) should be (Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+ // Check that the cleanup actually removes the files
+ sc.env.blockManager.master.removeShuffle(0, blocking = true)
+ for (file <- filesCreatedByShuffle) {
+ assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+ }
+ } finally {
+ Utils.deleteRecursively(tmpDir)
+ }
+ }
+}
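
For context, a minimal sketch of enabling this shuffle path in an application, mirroring the configuration applied in beforeAll above. Note that a shuffle with an aggregator (for example reduceByKey) still falls back to sort-based shuffle, as the manager suite shows.

import org.apache.spark.{SparkConf, SparkContext}

object TungstenSortShuffleExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("tungsten-sort-example")
      .set("spark.shuffle.manager", "tungsten-sort")
      // Leave enough execution memory for the sorter's 128 MB-per-task requirement.
      .set("spark.shuffle.memoryFraction", "0.5")
    val sc = new SparkContext(conf)
    try {
      val counts = sc.parallelize(1 to 1000).map(x => (x % 10, 1)).reduceByKey(_ + _)
      // reduceByKey introduces an aggregator, so this particular shuffle uses the old path;
      // an aggregation-free shuffle with Kryo would take the unsafe path instead.
      println(counts.collect().toSeq.sorted)
    } finally {
      sc.stop()
    }
  }
}
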
diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala
new file mode 100644
index 0000000000000..619b38ac02676
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.scope
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkConf
+import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerStageSubmitted, StageInfo}
+
+class RDDOperationGraphListenerSuite extends FunSuite {
+ private var jobIdCounter = 0
+ private var stageIdCounter = 0
+
+ /** Run a job with the specified number of stages. */
+ private def runOneJob(numStages: Int, listener: RDDOperationGraphListener): Unit = {
+ assert(numStages > 0, "I will not run a job with 0 stages for you.")
+ val stageInfos = (0 until numStages).map { _ =>
+ val stageInfo = new StageInfo(stageIdCounter, 0, "s", 0, Seq.empty, Seq.empty, "d")
+ listener.onStageSubmitted(new SparkListenerStageSubmitted(stageInfo))
+ stageIdCounter += 1
+ stageInfo
+ }
+ listener.onJobStart(new SparkListenerJobStart(jobIdCounter, 0, stageInfos))
+ jobIdCounter += 1
+ }
+
+ test("listener cleans up metadata") {
+
+ val conf = new SparkConf()
+ .set("spark.ui.retainedStages", "10")
+ .set("spark.ui.retainedJobs", "10")
+
+ val listener = new RDDOperationGraphListener(conf)
+ assert(listener.jobIdToStageIds.isEmpty)
+ assert(listener.stageIdToGraph.isEmpty)
+ assert(listener.jobIds.isEmpty)
+ assert(listener.stageIds.isEmpty)
+
+ // Run a few jobs, but not enough for clean up yet
+ runOneJob(1, listener)
+ runOneJob(2, listener)
+ runOneJob(3, listener)
+ assert(listener.jobIdToStageIds.size === 3)
+ assert(listener.stageIdToGraph.size === 6)
+ assert(listener.jobIds.size === 3)
+ assert(listener.stageIds.size === 6)
+
+ // Run a few more, but this time the stages should be cleaned up, but not the jobs
+ runOneJob(5, listener)
+ runOneJob(100, listener)
+ assert(listener.jobIdToStageIds.size === 5)
+ assert(listener.stageIdToGraph.size === 9)
+ assert(listener.jobIds.size === 5)
+ assert(listener.stageIds.size === 9)
+
+ // Run a few more, but this time both jobs and stages should be cleaned up
+ (1 to 100).foreach { _ =>
+ runOneJob(1, listener)
+ }
+ assert(listener.jobIdToStageIds.size === 9)
+ assert(listener.stageIdToGraph.size === 9)
+ assert(listener.jobIds.size === 9)
+ assert(listener.stageIds.size === 9)
+
+ // Ensure we clean up old jobs and stages, not arbitrary ones
+ assert(!listener.jobIdToStageIds.contains(0))
+ assert(!listener.stageIdToGraph.contains(0))
+ assert(!listener.stageIds.contains(0))
+ assert(!listener.jobIds.contains(0))
+ }
+
+}
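
A hedged sketch of the retention behaviour these assertions check: ids are kept in arrival order and the oldest entries are evicted once the retained limit is exceeded. This is a simplified model, not the listener's implementation.

import scala.collection.mutable

object BoundedRetentionSketch {
  final class Retained[K, V](limit: Int) {
    private val order = mutable.Queue.empty[K]
    private val data = mutable.Map.empty[K, V]

    def add(k: K, v: V): Unit = {
      order.enqueue(k)
      data(k) = v
      while (order.size > limit) {
        data.remove(order.dequeue()) // evict the oldest id, never an arbitrary one
      }
    }

    def contains(k: K): Boolean = data.contains(k)
    def size: Int = data.size
  }

  def main(args: Array[String]): Unit = {
    val stages = new Retained[Int, String](limit = 10)
    (0 until 15).foreach(i => stages.add(i, s"graph-$i"))
    println(stages.size)        // 10
    println(stages.contains(0)) // false: old stage ids were cleaned up first
  }
}
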
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
index 43c1b865b64a1..93afe50c2134f 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
@@ -18,15 +18,18 @@
package org.apache.spark.streaming.flume
import java.net.InetSocketAddress
-import java.util.concurrent.{Callable, ExecutorCompletionService, Executors}
+import java.util.concurrent._
import scala.collection.JavaConversions._
import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import scala.concurrent.duration._
+import scala.language.postfixOps
import org.apache.flume.Context
import org.apache.flume.channel.MemoryChannel
import org.apache.flume.conf.Configurables
import org.apache.flume.event.EventBuilder
+import org.scalatest.concurrent.Eventually._
import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -57,11 +60,11 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging
before(beforeFunction())
- ignore("flume polling test") {
+ test("flume polling test") {
testMultipleTimes(testFlumePolling)
}
- ignore("flume polling test multiple hosts") {
+ test("flume polling test multiple hosts") {
testMultipleTimes(testFlumePollingMultipleHost)
}
@@ -100,18 +103,8 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging
Configurables.configure(sink, context)
sink.setChannel(channel)
sink.start()
- // Set up the streaming context and input streams
- val ssc = new StreamingContext(conf, batchDuration)
- val flumeStream: ReceiverInputDStream[SparkFlumeEvent] =
- FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", sink.getPort())),
- StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1)
- val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]]
- with SynchronizedBuffer[Seq[SparkFlumeEvent]]
- val outputStream = new TestOutputStream(flumeStream, outputBuffer)
- outputStream.register()
- ssc.start()
- writeAndVerify(Seq(channel), ssc, outputBuffer)
+ writeAndVerify(Seq(sink), Seq(channel))
assertChannelIsEmpty(channel)
sink.stop()
channel.stop()
@@ -142,10 +135,22 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging
Configurables.configure(sink2, context)
sink2.setChannel(channel2)
sink2.start()
+ try {
+ writeAndVerify(Seq(sink, sink2), Seq(channel, channel2))
+ assertChannelIsEmpty(channel)
+ assertChannelIsEmpty(channel2)
+ } finally {
+ sink.stop()
+ sink2.stop()
+ channel.stop()
+ channel2.stop()
+ }
+ }
+ def writeAndVerify(sinks: Seq[SparkSink], channels: Seq[MemoryChannel]) {
// Set up the streaming context and input streams
val ssc = new StreamingContext(conf, batchDuration)
- val addresses = Seq(sink.getPort(), sink2.getPort()).map(new InetSocketAddress("localhost", _))
+ val addresses = sinks.map(sink => new InetSocketAddress("localhost", sink.getPort()))
val flumeStream: ReceiverInputDStream[SparkFlumeEvent] =
FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK,
eventsPerBatch, 5)
@@ -155,61 +160,49 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging
outputStream.register()
ssc.start()
- try {
- writeAndVerify(Seq(channel, channel2), ssc, outputBuffer)
- assertChannelIsEmpty(channel)
- assertChannelIsEmpty(channel2)
- } finally {
- sink.stop()
- sink2.stop()
- channel.stop()
- channel2.stop()
- }
- }
-
- def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext,
- outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]]) {
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val executor = Executors.newCachedThreadPool()
val executorCompletion = new ExecutorCompletionService[Void](executor)
- channels.map(channel => {
+
+ val latch = new CountDownLatch(batchCount * channels.size)
+ sinks.foreach(_.countdownWhenBatchReceived(latch))
+
+ channels.foreach(channel => {
executorCompletion.submit(new TxnSubmitter(channel, clock))
})
+
for (i <- 0 until channels.size) {
executorCompletion.take()
}
- val startTime = System.currentTimeMillis()
- while (outputBuffer.size < batchCount * channels.size &&
- System.currentTimeMillis() - startTime < 15000) {
- logInfo("output.size = " + outputBuffer.size)
- Thread.sleep(100)
- }
- val timeTaken = System.currentTimeMillis() - startTime
- assert(timeTaken < 15000, "Operation timed out after " + timeTaken + " ms")
- logInfo("Stopping context")
- ssc.stop()
- val flattenedBuffer = outputBuffer.flatten
- assert(flattenedBuffer.size === totalEventsPerChannel * channels.size)
- var counter = 0
- for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) {
- val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " +
- String.valueOf(i)).getBytes("utf-8"),
- Map[String, String]("test-" + i.toString -> "header"))
- var found = false
- var j = 0
- while (j < flattenedBuffer.size && !found) {
- val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8")
- if (new String(eventToVerify.getBody, "utf-8") == strToCompare &&
- eventToVerify.getHeaders.get("test-" + i.toString)
- .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) {
- found = true
- counter += 1
+ latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received.
+ clock.advance(batchDuration.milliseconds)
+
+ // The eventually block is required to ensure that all data in the batch has been processed.
+ eventually(timeout(10 seconds), interval(100 milliseconds)) {
+ val flattenedBuffer = outputBuffer.flatten
+ assert(flattenedBuffer.size === totalEventsPerChannel * channels.size)
+ var counter = 0
+ for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) {
+ val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " +
+ String.valueOf(i)).getBytes("utf-8"),
+ Map[String, String]("test-" + i.toString -> "header"))
+ var found = false
+ var j = 0
+ while (j < flattenedBuffer.size && !found) {
+ val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8")
+ if (new String(eventToVerify.getBody, "utf-8") == strToCompare &&
+ eventToVerify.getHeaders.get("test-" + i.toString)
+ .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) {
+ found = true
+ counter += 1
+ }
+ j += 1
}
- j += 1
}
+ assert(counter === totalEventsPerChannel * channels.size)
}
- assert(counter === totalEventsPerChannel * channels.size)
+ ssc.stop()
}
def assertChannelIsEmpty(channel: MemoryChannel): Unit = {
@@ -234,7 +227,6 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging
tx.commit()
tx.close()
Thread.sleep(500) // Allow some time for the events to reach
- clock.advance(batchDuration.milliseconds)
}
null
}
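
The restructured test above replaces fixed sleep-and-poll timing with a CountDownLatch, a manual clock advance, and eventually-style polling. A hedged, Spark-free Scala sketch of that coordination pattern follows; ManualClock here is a stand-in, not Spark's class.

import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit}

object LatchThenPollSketch {
  // Stand-in for a manual clock whose advance() would trigger the next streaming batch.
  final class ManualClock { @volatile var timeMs: Long = 0L; def advance(ms: Long): Unit = timeMs += ms }

  def main(args: Array[String]): Unit = {
    val batches = 5
    val latch = new CountDownLatch(batches)
    val received = new ConcurrentLinkedQueue[Int]()

    // Producer delivering batches on another thread, as the sinks do in the test.
    val producer = new Thread(new Runnable {
      def run(): Unit = (1 to batches).foreach { i => received.add(i); latch.countDown() }
    })
    producer.start()

    // 1) Wait until all data has been delivered.
    assert(latch.await(15, TimeUnit.SECONDS), "data was not delivered in time")
    // 2) Advance the clock; in the real test this is what triggers the batch.
    new ManualClock().advance(500)
    // 3) Poll with a deadline instead of sleeping a fixed amount (a plain stand-in for eventually).
    val deadline = System.currentTimeMillis() + 10000
    while (received.size < batches && System.currentTimeMillis() < deadline) Thread.sleep(100)
    assert(received.size == batches)
    producer.join()
  }
}
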
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index c9b3ff0172e2e..b381dc2cb0140 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -87,12 +87,17 @@ class NaiveBayesModel private[mllib] (
}
override def predict(testData: Vector): Double = {
+ val brzData = testData.toBreeze
modelType match {
case "Multinomial" =>
- labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
+ labels (brzArgmax (brzPi + brzTheta * brzData) )
case "Bernoulli" =>
+ if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
+ throw new SparkException(
+ s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
+ }
labels (brzArgmax (brzPi +
- (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
+ (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
@@ -293,12 +298,29 @@ class NaiveBayes private (
}
}
+ val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
+ val values = v match {
+ case SparseVector(size, indices, values) =>
+ values
+ case DenseVector(values) =>
+ values
+ }
+ if (!values.forall(v => v == 0.0 || v == 1.0)) {
+ throw new SparkException(
+ s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")
+ }
+ }
+
// Aggregates term frequencies per label.
// TODO: Calling combineByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
createCombiner = (v: Vector) => {
- requireNonnegativeValues(v)
+ if (modelType == "Bernoulli") {
+ requireZeroOneBernoulliValues(v)
+ } else {
+ requireNonnegativeValues(v)
+ }
(1L, v.toBreeze.toDenseVector)
},
mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
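
A hedged, standalone restatement of the validation added above: Bernoulli naive Bayes only accepts binary (0.0/1.0) feature values, both when combining training vectors and when predicting. The helper below is illustrative, not the MLlib code itself.

object BernoulliFeatureCheck {
  def requireZeroOneValues(values: Array[Double]): Unit = {
    if (!values.forall(v => v == 0.0 || v == 1.0)) {
      throw new IllegalArgumentException(
        s"Bernoulli Naive Bayes requires 0 or 1 feature values but found ${values.mkString("[", ",", "]")}.")
    }
  }

  def main(args: Array[String]): Unit = {
    requireZeroOneValues(Array(1.0, 0.0, 1.0))       // accepted
    try requireZeroOneValues(Array(2.0)) catch {     // rejected, as in the new tests below
      case e: IllegalArgumentException => println(e.getMessage)
    }
  }
}
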
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index ea89b17b7c08f..40a79a1f19bd9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -208,6 +208,39 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
}
}
+ test("detect non zero or one values in Bernoulli") {
+ val badTrain = Seq(
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0)))
+
+ intercept[SparkException] {
+ NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli")
+ }
+
+ val okTrain = Seq(
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0))
+ )
+
+ val badPredict = Seq(
+ Vectors.dense(1.0),
+ Vectors.dense(2.0),
+ Vectors.dense(1.0),
+ Vectors.dense(0.0))
+
+ val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli")
+ intercept[SparkException] {
+ model.predict(sc.makeRDD(badPredict, 2)).collect()
+ }
+ }
+
test("model save/load: 2.0 to 2.0") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
diff --git a/pom.xml b/pom.xml
index cf9279ea5a2a6..564a443466e5a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -669,7 +669,7 @@
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
- <version>1.9.0</version>
+ <version>1.9.5</version>
<scope>test</scope>
@@ -684,6 +684,18 @@
<version>4.10</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-core</artifactId>
+ <version>1.3</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-library</artifactId>
+ <version>1.3</version>
+ <scope>test</scope>
+ </dependency>
<dependency>
<groupId>com.novocode</groupId>
<artifactId>junit-interface</artifactId>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index f31f0e554eee9..487062a31f77f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -123,11 +123,20 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.parquet.ParquetTestData$"),
ProblemFilters.exclude[MissingClassProblem](
- "org.apache.spark.sql.parquet.TestGroupWriteSupport")
+ "org.apache.spark.sql.parquet.TestGroupWriteSupport"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager")
) ++ Seq(
// SPARK-7530 Added StreamingContext.getState()
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.streaming.StreamingContext.state_=")
+ ) ++ Seq(
+ // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some
+ // unnecessary type bounds in order to fix some compiler warnings that occurred when
+ // implementing this interface in Java. Note that ShuffleWriter is private[spark].
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem](
+ "org.apache.spark.shuffle.ShuffleWriter")
)
case v if v.startsWith("1.3") =>
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 8a009c4ac721f..96d29058a3781 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -17,17 +17,19 @@
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel
-from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
- HasRegParam
+from pyspark.ml.param.shared import *
+from pyspark.ml.regression import RandomForestParams
from pyspark.mllib.common import inherit_doc
-__all__ = ['LogisticRegression', 'LogisticRegressionModel']
+__all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier',
+ 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel',
+ 'RandomForestClassifier', 'RandomForestClassificationModel']
@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
- HasRegParam):
+ HasRegParam, HasTol, HasProbabilityCol):
"""
Logistic regression.
@@ -50,25 +52,49 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
TypeError: Method setParams forces keyword arguments.
"""
_java_class = "org.apache.spark.ml.classification.LogisticRegression"
+ # a placeholder to make it appear in the generated doc
+ elasticNetParam = \
+ Param(Params._dummy(), "elasticNetParam",
+ "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
+ "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
+ fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")
+ threshold = Param(Params._dummy(), "threshold",
+ "threshold in binary classification prediction, in range [0, 1].")
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
- maxIter=100, regParam=0.1):
+ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
+ threshold=0.5, probabilityCol="probability"):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
- maxIter=100, regParam=0.1)
+ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
+ threshold=0.5, probabilityCol="probability")
"""
super(LogisticRegression, self).__init__()
- self._setDefault(maxIter=100, regParam=0.1)
+ #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
+ # is an L2 penalty. For alpha = 1, it is an L1 penalty.
+ self.elasticNetParam = \
+ Param(self, "elasticNetParam",
+ "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " +
+ "is an L2 penalty. For alpha = 1, it is an L1 penalty.")
+ #: param for whether to fit an intercept term.
+ self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
+ #: param for threshold in binary classification prediction, in range [0, 1].
+ self.threshold = Param(self, "threshold",
+ "threshold in binary classification prediction, in range [0, 1].")
+ self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6,
+ fitIntercept=True, threshold=0.5)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
- maxIter=100, regParam=0.1):
+ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
+ threshold=0.5, probabilityCol="probability"):
"""
- setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
- maxIter=100, regParam=0.1)
+ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
+ threshold=0.5, probabilityCol="probability")
Sets params for logistic regression.
"""
kwargs = self.setParams._input_kwargs
@@ -77,6 +103,45 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return LogisticRegressionModel(java_model)
+ def setElasticNetParam(self, value):
+ """
+ Sets the value of :py:attr:`elasticNetParam`.
+ """
+ self.paramMap[self.elasticNetParam] = value
+ return self
+
+ def getElasticNetParam(self):
+ """
+ Gets the value of elasticNetParam or its default value.
+ """
+ return self.getOrDefault(self.elasticNetParam)
+
+ def setFitIntercept(self, value):
+ """
+ Sets the value of :py:attr:`fitIntercept`.
+ """
+ self.paramMap[self.fitIntercept] = value
+ return self
+
+ def getFitIntercept(self):
+ """
+ Gets the value of fitIntercept or its default value.
+ """
+ return self.getOrDefault(self.fitIntercept)
+
+ def setThreshold(self, value):
+ """
+ Sets the value of :py:attr:`threshold`.
+ """
+ self.paramMap[self.threshold] = value
+ return self
+
+ def getThreshold(self):
+ """
+ Gets the value of threshold or its default value.
+ """
+ return self.getOrDefault(self.threshold)
+
class LogisticRegressionModel(JavaModel):
"""
@@ -84,6 +149,399 @@ class LogisticRegressionModel(JavaModel):
"""
+class TreeClassifierParams(object):
+ """
+ Private class to track supported impurity measures.
+ """
+ supportedImpurities = ["entropy", "gini"]
+
+
+class GBTParams(object):
+ """
+ Private class to track supported GBT params.
+ """
+ supportedLossTypes = ["logistic"]
+
+
+@inherit_doc
+class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
+ DecisionTreeParams, HasCheckpointInterval):
+ """
+ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
+ learning algorithm for classification.
+ It supports both binary and multiclass labels, as well as both continuous and categorical
+ features.
+
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> from pyspark.ml.feature import StringIndexer
+ >>> df = sqlContext.createDataFrame([
+ ... (1.0, Vectors.dense(1.0)),
+ ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
+ >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
+ >>> si_model = stringIndexer.fit(df)
+ >>> td = si_model.transform(df)
+ >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
+ >>> model = dt.fit(td)
+ >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
+ >>> model.transform(test0).head().prediction
+ 0.0
+ >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
+ >>> model.transform(test1).head().prediction
+ 1.0
+ """
+
+ _java_class = "org.apache.spark.ml.classification.DecisionTreeClassifier"
+ # a placeholder to make it appear in the generated doc
+ impurity = Param(Params._dummy(), "impurity",
+ "Criterion used for information gain calculation (case-insensitive). " +
+ "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities))
+
+ @keyword_only
+ def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"):
+ """
+ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
+ """
+ super(DecisionTreeClassifier, self).__init__()
+ #: param for Criterion used for information gain calculation (case-insensitive).
+ self.impurity = \
+ Param(self, "impurity",
+ "Criterion used for information gain calculation (case-insensitive). " +
+ "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities))
+ self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
+ impurity="gini")
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
+ impurity="gini"):
+ """
+ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
+ impurity="gini")
+ Sets params for the DecisionTreeClassifier.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def _create_model(self, java_model):
+ return DecisionTreeClassificationModel(java_model)
+
+ def setImpurity(self, value):
+ """
+ Sets the value of :py:attr:`impurity`.
+ """
+ self.paramMap[self.impurity] = value
+ return self
+
+ def getImpurity(self):
+ """
+ Gets the value of impurity or its default value.
+ """
+ return self.getOrDefault(self.impurity)
+
+
+class DecisionTreeClassificationModel(JavaModel):
+ """
+ Model fitted by DecisionTreeClassifier.
+ """
+
+
+@inherit_doc
+class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
+ DecisionTreeParams, HasCheckpointInterval):
+ """
+ `http://en.wikipedia.org/wiki/Random_forest Random Forest`
+ learning algorithm for classification.
+ It supports both binary and multiclass labels, as well as both continuous and categorical
+ features.
+
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> from pyspark.ml.feature import StringIndexer
+ >>> df = sqlContext.createDataFrame([
+ ... (1.0, Vectors.dense(1.0)),
+ ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
+ >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
+ >>> si_model = stringIndexer.fit(df)
+ >>> td = si_model.transform(df)
+ >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed")
+ >>> model = rf.fit(td)
+ >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
+ >>> model.transform(test0).head().prediction
+ 0.0
+ >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
+ >>> model.transform(test1).head().prediction
+ 1.0
+ """
+
+ _java_class = "org.apache.spark.ml.classification.RandomForestClassifier"
+ # a placeholder to make it appear in the generated doc
+ impurity = Param(Params._dummy(), "impurity",
+ "Criterion used for information gain calculation (case-insensitive). " +
+ "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities))
+ subsamplingRate = Param(Params._dummy(), "subsamplingRate",
+ "Fraction of the training data used for learning each decision tree, " +
+ "in range (0, 1].")
+ numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1)")
+ featureSubsetStrategy = \
+ Param(Params._dummy(), "featureSubsetStrategy",
+ "The number of features to consider for splits at each tree node. Supported " +
+ "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies))
+
+ @keyword_only
+ def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
+ numTrees=20, featureSubsetStrategy="auto", seed=42):
+ """
+ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
+ numTrees=20, featureSubsetStrategy="auto", seed=42)
+ """
+ super(RandomForestClassifier, self).__init__()
+ #: param for Criterion used for information gain calculation (case-insensitive).
+ self.impurity = \
+ Param(self, "impurity",
+ "Criterion used for information gain calculation (case-insensitive). " +
+ "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities))
+ #: param for Fraction of the training data used for learning each decision tree,
+ # in range (0, 1]
+ self.subsamplingRate = Param(self, "subsamplingRate",
+ "Fraction of the training data used for learning each " +
+ "decision tree, in range (0, 1].")
+ #: param for Number of trees to train (>= 1)
+ self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1)")
+ #: param for The number of features to consider for splits at each tree node
+ self.featureSubsetStrategy = \
+ Param(self, "featureSubsetStrategy",
+ "The number of features to consider for splits at each tree node. Supported " +
+ "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies))
+ self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42,
+ impurity="gini", numTrees=20, featureSubsetStrategy="auto")
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42,
+ impurity="gini", numTrees=20, featureSubsetStrategy="auto"):
+ """
+ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42,
+ impurity="gini", numTrees=20, featureSubsetStrategy="auto")
+ Sets params for the RandomForestClassifier.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def _create_model(self, java_model):
+ return RandomForestClassificationModel(java_model)
+
+ def setImpurity(self, value):
+ """
+ Sets the value of :py:attr:`impurity`.
+ """
+ self.paramMap[self.impurity] = value
+ return self
+
+ def getImpurity(self):
+ """
+ Gets the value of impurity or its default value.
+ """
+ return self.getOrDefault(self.impurity)
+
+ def setSubsamplingRate(self, value):
+ """
+ Sets the value of :py:attr:`subsamplingRate`.
+ """
+ self.paramMap[self.subsamplingRate] = value
+ return self
+
+ def getSubsamplingRate(self):
+ """
+ Gets the value of subsamplingRate or its default value.
+ """
+ return self.getOrDefault(self.subsamplingRate)
+
+ def setNumTrees(self, value):
+ """
+ Sets the value of :py:attr:`numTrees`.
+ """
+ self.paramMap[self.numTrees] = value
+ return self
+
+ def getNumTrees(self):
+ """
+ Gets the value of numTrees or its default value.
+ """
+ return self.getOrDefault(self.numTrees)
+
+ def setFeatureSubsetStrategy(self, value):
+ """
+ Sets the value of :py:attr:`featureSubsetStrategy`.
+ """
+ self.paramMap[self.featureSubsetStrategy] = value
+ return self
+
+ def getFeatureSubsetStrategy(self):
+ """
+ Gets the value of featureSubsetStrategy or its default value.
+ """
+ return self.getOrDefault(self.featureSubsetStrategy)
+
+
+class RandomForestClassificationModel(JavaModel):
+ """
+ Model fitted by RandomForestClassifier.
+ """
+
+
+@inherit_doc
+class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
+ DecisionTreeParams, HasCheckpointInterval):
+ """
+ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)`
+ learning algorithm for classification.
+ It supports binary labels, as well as both continuous and categorical features.
+ Note: Multiclass labels are not currently supported.
+
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> from pyspark.ml.feature import StringIndexer
+ >>> df = sqlContext.createDataFrame([
+ ... (1.0, Vectors.dense(1.0)),
+ ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
+ >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
+ >>> si_model = stringIndexer.fit(df)
+ >>> td = si_model.transform(df)
+ >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed")
+ >>> model = gbt.fit(td)
+ >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
+ >>> model.transform(test0).head().prediction
+ 0.0
+ >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
+ >>> model.transform(test1).head().prediction
+ 1.0
+ """
+
+ _java_class = "org.apache.spark.ml.classification.GBTClassifier"
+ # a placeholder to make it appear in the generated doc
+ lossType = Param(Params._dummy(), "lossType",
+ "Loss function which GBT tries to minimize (case-insensitive). " +
+ "Supported options: " + ", ".join(GBTParams.supportedLossTypes))
+ subsamplingRate = Param(Params._dummy(), "subsamplingRate",
+ "Fraction of the training data used for learning each decision tree, " +
+ "in range (0, 1].")
+ stepSize = Param(Params._dummy(), "stepSize",
+ "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the " +
+ "contribution of each estimator")
+
+ @keyword_only
+ def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
+ maxIter=20, stepSize=0.1):
+ """
+ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
+ maxIter=20, stepSize=0.1)
+ """
+ super(GBTClassifier, self).__init__()
+ #: param for Loss function which GBT tries to minimize (case-insensitive).
+ self.lossType = Param(self, "lossType",
+ "Loss function which GBT tries to minimize (case-insensitive). " +
+ "Supported options: " + ", ".join(GBTParams.supportedLossTypes))
+ #: Fraction of the training data used for learning each decision tree, in range (0, 1].
+ self.subsamplingRate = Param(self, "subsamplingRate",
+ "Fraction of the training data used for learning each " +
+ "decision tree, in range (0, 1].")
+ #: Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of
+ # each estimator
+ self.stepSize = Param(self, "stepSize",
+ "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
+ "the contribution of each estimator")
+ self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
+ lossType="logistic", maxIter=20, stepSize=0.1)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
+ lossType="logistic", maxIter=20, stepSize=0.1):
+ """
+ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
+ lossType="logistic", maxIter=20, stepSize=0.1)
+ Sets params for Gradient Boosted Tree Classification.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def _create_model(self, java_model):
+ return GBTClassificationModel(java_model)
+
+ def setLossType(self, value):
+ """
+ Sets the value of :py:attr:`lossType`.
+ """
+ self.paramMap[self.lossType] = value
+ return self
+
+ def getLossType(self):
+ """
+ Gets the value of lossType or its default value.
+ """
+ return self.getOrDefault(self.lossType)
+
+ def setSubsamplingRate(self, value):
+ """
+ Sets the value of :py:attr:`subsamplingRate`.
+ """
+ self.paramMap[self.subsamplingRate] = value
+ return self
+
+ def getSubsamplingRate(self):
+ """
+ Gets the value of subsamplingRate or its default value.
+ """
+ return self.getOrDefault(self.subsamplingRate)
+
+ def setStepSize(self, value):
+ """
+ Sets the value of :py:attr:`stepSize`.
+ """
+ self.paramMap[self.stepSize] = value
+ return self
+
+ def getStepSize(self):
+ """
+ Gets the value of stepSize or its default value.
+ """
+ return self.getOrDefault(self.stepSize)
+
+
+class GBTClassificationModel(JavaModel):
+ """
+ Model fitted by GBTClassifier.
+ """
+
+
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 4a5cc6e64f023..6fa9b8c2cf367 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -109,6 +109,9 @@ def get$Name(self):
("featuresCol", "features column name", "'features'"),
("labelCol", "label column name", "'label'"),
("predictionCol", "prediction column name", "'prediction'"),
+ ("probabilityCol", "Column name for predicted class conditional probabilities. " +
+ "Note: Not all models output well-calibrated probability estimates! These probabilities " +
+ "should be treated as confidences, not precise probabilities.", "'probability'"),
("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", "'rawPrediction'"),
("inputCol", "input column name", None),
("inputCols", "input column names", None),
@@ -156,6 +159,7 @@ def __init__(self):
for name, doc in decisionTreeParams:
variable = paramTemplate.replace("$name", name).replace("$doc", doc)
dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n "
+ realParams += "#: param for " + doc + "\n "
realParams += "self." + variable.replace("$owner", "self") + "\n "
dtParamMethods += _gen_param_code(name, doc, None) + "\n"
code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders)
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 779cabe853f8e..b116f05a068d3 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -165,6 +165,35 @@ def getPredictionCol(self):
return self.getOrDefault(self.predictionCol)
+class HasProbabilityCol(Params):
+ """
+ Mixin for param probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities..
+ """
+
+ # a placeholder to make it appear in the generated doc
+ probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.")
+
+ def __init__(self):
+ super(HasProbabilityCol, self).__init__()
+ #: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.
+ self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.")
+ if 'probability' is not None:
+ self._setDefault(probabilityCol='probability')
+
+ def setProbabilityCol(self, value):
+ """
+ Sets the value of :py:attr:`probabilityCol`.
+ """
+ self.paramMap[self.probabilityCol] = value
+ return self
+
+ def getProbabilityCol(self):
+ """
+ Gets the value of probabilityCol or its default value.
+ """
+ return self.getOrDefault(self.probabilityCol)
+
+
class HasRawPredictionCol(Params):
"""
Mixin for param rawPredictionCol: raw prediction (a.k.a. confidence) column name.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index a13e2f36a1a1f..75a493b248f6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -23,6 +23,7 @@ import java.util.{Map => JavaMap}
import scala.collection.mutable.HashMap
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
similarity index 99%
rename from sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 1ec874f79617c..625c8d3a62125 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql
+package org.apache.spark.sql.catalyst
import java.beans.Introspector
import java.lang.{Iterable => JIterable}
@@ -24,10 +24,8 @@ import java.util.{Iterator => JIterator, Map => JMap}
import scala.language.existentials
import com.google.common.reflect.TypeToken
-
import org.apache.spark.sql.types._
-
/**
* Type-inference utilities for POJOs and Java collections.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala
index 05a92b06f9fd9..554fb4eb25eb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala
@@ -31,3 +31,39 @@ abstract class ParserDialect {
// this is the main function that will be implemented by sql parser.
def parse(sqlText: String): LogicalPlan
}
+
+/**
+ * Currently we support the default dialect named "sql", associated with the class
+ * [[DefaultParserDialect]].
+ *
+ * A custom SQL dialect can also be provided, for example in the Spark SQL CLI:
+ * {{{
+ *-- switch to "hiveql" dialect
+ * spark-sql>SET spark.sql.dialect=hiveql;
+ * spark-sql>SELECT * FROM src LIMIT 1;
+ *
+ *-- switch to "sql" dialect
+ * spark-sql>SET spark.sql.dialect=sql;
+ * spark-sql>SELECT * FROM src LIMIT 1;
+ *
+ *-- register the new SQL dialect
+ * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect;
+ * spark-sql> SELECT * FROM src LIMIT 1;
+ *
+ *-- register a non-existent SQL dialect
+ * spark-sql> SET spark.sql.dialect=NotExistedClass;
+ * spark-sql> SELECT * FROM src LIMIT 1;
+ *
+ *-- An exception will be thrown and the dialect will fall back to
+ *-- "sql" (for SQLContext) or
+ *-- "hiveql" (for HiveContext)
+ * }}}
+ */
+private[spark] class DefaultParserDialect extends ParserDialect {
+ @transient
+ protected val sqlParser = new SqlParser
+
+ override def parse(sqlText: String): LogicalPlan = {
+ sqlParser.parse(sqlText)
+ }
+}
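For readers unfamiliar with the dialect switch, the sketch below shows the equivalent calls from a Scala program rather than the CLI. It is illustrative only: it assumes an existing SQLContext named `sqlContext`, and `com.xxx.xxx.SQL99Dialect` is the same placeholder class name used in the scaladoc above.

    // Illustrative only -- mirrors the CLI session in the comment above.
    sqlContext.setConf("spark.sql.dialect", "sql")                       // stock DefaultParserDialect
    sqlContext.sql("SELECT * FROM src LIMIT 1")

    sqlContext.setConf("spark.sql.dialect", "com.xxx.xxx.SQL99Dialect")  // user-supplied ParserDialect subclass
    sqlContext.sql("SELECT * FROM src LIMIT 1")

    // Setting an unloadable class name causes sql() to throw a DialectException and the
    // dialect to be reset to the default, as described in the scaladoc above.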
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index adf941ab2a45f..d8cf2b2e32435 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
/** Cast the child expression to the target data type. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d17af0e7ff87e..ecb4c4b68f904 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -250,7 +250,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
case Cast(child @ DateType(), StringType) =>
child.castOrNull(c =>
q"""org.apache.spark.sql.types.UTF8String(
- org.apache.spark.sql.types.DateUtils.toString($c))""",
+ org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""",
StringType)
case Cast(child @ NumericType(), IntegerType) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 18cba4cc46707..5f8c7354aede1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
object Literal {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b163707cc9925..c2818d957cc79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -156,6 +156,11 @@ object ColumnPruning extends Rule[LogicalPlan] {
case Project(projectList, Limit(exp, child)) =>
Limit(exp, Project(projectList, child))
+ // push the projection below a Sort if the Sort only references attributes produced by the projection
+ case p @ Project(projectList, s @ Sort(_, _, grandChild))
+ if s.references.subsetOf(p.outputSet) =>
+ s.copy(child = Project(projectList, grandChild))
+
// Eliminate no-op Projects
case Project(projectList, child) if child.output == projectList => child
}
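To make the effect of the new ColumnPruning case concrete, the hedged sketch below restates the intended plan rewrite in DataFrame terms; `df` and its columns `a`/`b` are assumed, and the comments paraphrase the before/after logical plans.

    // Assumes a DataFrame `df` with columns a and b (illustrative only).
    //
    //   before:  Project [a]                after:  Sort [a ASC]
    //              Sort [a ASC]                       Project [a]
    //                Project [a, b]                     <df>
    //                  <df>
    //
    // i.e. the trailing select("a") is pushed beneath the sort, so column b is never
    // carried through (or sorted by) the Sort operator.
    val pruned = df.select("a", "b").sort("a").select("a")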
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala
similarity index 98%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala
index d36a49159b87f..3f92be4a55d7d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.types
+package org.apache.spark.sql.catalyst.util
import java.sql.Date
import java.text.SimpleDateFormat
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
index fc02ba6c9c43e..bc9c37bf2d5d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
@@ -19,15 +19,18 @@ package org.apache.spark.sql.types
import java.util.Arrays
+import org.apache.spark.annotation.DeveloperApi
+
/**
- * A UTF-8 String, as internal representation of StringType in SparkSQL
+ * :: DeveloperApi ::
+ * A UTF-8 String, as internal representation of StringType in SparkSQL
*
- * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison,
- * search, see http://en.wikipedia.org/wiki/UTF-8 for details.
+ * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison,
+ * search, see http://en.wikipedia.org/wiki/UTF-8 for details.
*
- * Note: This is not designed for general use cases, should not be used outside SQL.
+ * Note: This is not designed for general use cases, should not be used outside SQL.
*/
-
+@DeveloperApi
final class UTF8String extends Ordered[UTF8String] with Serializable {
private[this] var bytes: Array[Byte] = _
@@ -180,6 +183,10 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
}
}
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
object UTF8String {
// number of tailing bytes in a UTF8 sequence for a code point
// see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 04fd261d16aa3..5c4a1527c27c9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.mathfuncs._
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 0c428f7231b8e..be33cb9bb8eaa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
-import org.apache.spark.sql.catalyst.expressions.{Count, Explode}
+import org.apache.spark.sql.catalyst.expressions.{SortOrder, Ascending, Count, Explode}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.rules._
@@ -542,4 +542,38 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery)
}
+
+ test("push down project past sort") {
+ val x = testRelation.subquery('x)
+
+ // push-down is valid: the Sort only references 'a, which the outer projection keeps
+ val originalQuery = {
+ x.select('a, 'b)
+ .sortBy(SortOrder('a, Ascending))
+ .select('a)
+ }
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ x.select('a)
+ .sortBy(SortOrder('a, Ascending)).analyze
+
+ comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer))
+
+ // push-down is invalid: the Sort references 'a, which the outer projection drops
+ val originalQuery1 = {
+ x.select('a, 'b)
+ .sortBy(SortOrder('a, Ascending))
+ .select('b)
+ }
+
+ val optimized1 = Optimize.execute(originalQuery1.analyze)
+ val correctAnswer1 =
+ x.select('a, 'b)
+ .sortBy(SortOrder('a, Ascending))
+ .select('b).analyze
+
+ comparePlans(optimized1, analysis.EliminateSubQueries(correctAnswer1))
+
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 42f5bcda49cfb..8bf1320ccb71d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -346,6 +346,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* }}}
*
* @group expr_ops
+ * @since 1.4.0
*/
def when(condition: Column, value: Any):Column = this.expr match {
case CaseWhen(branches: Seq[Expression]) =>
@@ -374,6 +375,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* }}}
*
* @group expr_ops
+ * @since 1.4.0
*/
def otherwise(value: Any):Column = this.expr match {
case CaseWhen(branches: Seq[Expression]) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 0a148c7cd2d3b..521f3dc821795 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -33,6 +33,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.errors.DialectException
@@ -40,7 +41,6 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.ParserDialect
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, expressions}
import org.apache.spark.sql.execution.{Filter, _}
import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.json._
@@ -50,42 +50,6 @@ import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.{Partition, SparkContext}
-/**
- * Currently we support the default dialect named "sql", associated with the class
- * [[DefaultParserDialect]]
- *
- * And we can also provide custom SQL Dialect, for example in Spark SQL CLI:
- * {{{
- *-- switch to "hiveql" dialect
- * spark-sql>SET spark.sql.dialect=hiveql;
- * spark-sql>SELECT * FROM src LIMIT 1;
- *
- *-- switch to "sql" dialect
- * spark-sql>SET spark.sql.dialect=sql;
- * spark-sql>SELECT * FROM src LIMIT 1;
- *
- *-- register the new SQL dialect
- * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect;
- * spark-sql> SELECT * FROM src LIMIT 1;
- *
- *-- register the non-exist SQL dialect
- * spark-sql> SET spark.sql.dialect=NotExistedClass;
- * spark-sql> SELECT * FROM src LIMIT 1;
- *
- *-- Exception will be thrown and switch to dialect
- *-- "sql" (for SQLContext) or
- *-- "hiveql" (for HiveContext)
- * }}}
- */
-private[spark] class DefaultParserDialect extends ParserDialect {
- @transient
- protected val sqlParser = new catalyst.SqlParser
-
- override def parse(sqlText: String): LogicalPlan = {
- sqlParser.parse(sqlText)
- }
-}
-
/**
* The entry point for working with structured data (rows and columns) in Spark. Allows the
* creation of [[DataFrame]] objects as well as the execution of SQL queries.
@@ -1276,7 +1240,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
val projectSet = AttributeSet(projectList.flatMap(_.references))
val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
val filterCondition =
- prunePushedDownFilters(filterPredicates).reduceLeftOption(expressions.And)
+ prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And)
// Right now we still use a projection even if the only evaluation is applying an alias
// to a column. Since this is a no-op, it could be avoided. However, using this
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
similarity index 97%
rename from sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 18584c2dcf797..5fcc48a67948b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -15,18 +15,19 @@
* limitations under the License.
*/
-package org.apache.spark.sql
+package org.apache.spark.sql.execution
import java.util.concurrent.locks.ReentrantReadWriteLock
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.columnar.InMemoryRelation
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
/** Holds a cached logical plan and its data */
-private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
+private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
/**
* Provides support in a SQLContext for caching query results and automatically using these cached
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index c3d2c7019a54a..3e46596ecf6ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -17,17 +17,18 @@
package org.apache.spark.sql.execution
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.serializer.Serializer
-import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.util.MutablePair
object Exchange {
@@ -85,7 +86,9 @@ case class Exchange(
// corner-cases where a partitioner constructed with `numPartitions` partitions may output
// fewer partitions (like RangePartitioner, for example).
val conf = child.sqlContext.sparkContext.conf
- val sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+ val shuffleManager = SparkEnv.get.shuffleManager
+ val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] ||
+ shuffleManager.isInstanceOf[UnsafeShuffleManager]
val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true)
if (newOrdering.nonEmpty) {
@@ -93,11 +96,11 @@ case class Exchange(
// which requires a defensive copy.
true
} else if (sortBasedShuffleOn) {
- // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory.
- // However, there are two special cases where we can avoid the copy, described below:
- if (partitioner.numPartitions <= bypassMergeThreshold) {
- // If the number of output partitions is sufficiently small, then Spark will fall back to
- // the old hash-based shuffle write path which doesn't buffer deserialized records.
+ val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+ if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) {
+ // If we're using the original SortShuffleManager and the number of output partitions is
+ // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which
+ // doesn't buffer deserialized records.
// Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
false
} else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) {
@@ -105,9 +108,14 @@ case class Exchange(
// them. This optimization is guarded by a feature-flag and is only applied in cases where
// shuffle dependency does not specify an ordering and the record serializer has certain
// properties. If this optimization is enabled, we can safely avoid the copy.
+ //
+ // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081).
false
} else {
- // None of the special cases held, so we must copy.
+ // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code
+ // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls
+ // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In
+ // both cases, we must copy.
true
}
} else {
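The branching above is easy to lose track of inside a diff, so here is a stand-alone paraphrase of the copy decision as a pure function. The function name and explicit parameters are ours (the real code reads them from SparkEnv and the SQL conf), and the final hash-shuffle branch, which is not visible in this hunk, is assumed to return false because that path writes each record out as it arrives.

    // Paraphrase of the decision logic above (illustrative, not the actual Exchange code).
    def needsDefensiveCopy(
        hasKeyOrdering: Boolean,
        sortBasedShuffleOn: Boolean,          // SortShuffleManager or UnsafeShuffleManager
        bypassIsSupported: Boolean,           // only the original SortShuffleManager can bypass merge-sort
        numPartitions: Int,
        bypassMergeThreshold: Int,
        serializeMapOutputs: Boolean,
        serializerSupportsRelocation: Boolean): Boolean = {
      if (hasKeyOrdering) {
        true   // keys are sorted in ExternalSorter, which buffers records
      } else if (sortBasedShuffleOn) {
        if (bypassIsSupported && numPartitions <= bypassMergeThreshold) {
          false  // falls back to the hash-style write path, which does not buffer records
        } else if (serializeMapOutputs && serializerSupportsRelocation) {
          false  // records are serialized immediately and sorted in serialized form
        } else {
          true   // ExternalSorter buffers deserialized records, so reused rows must be copied
        }
      } else {
        false    // assumed: plain hash-based shuffle streams records straight to the writers
      }
    }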
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 3dbc3837950e0..65dd7ba020fa3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -19,20 +19,21 @@ package org.apache.spark.sql.execution
import java.util.{List => JList, Map => JMap}
-import org.apache.spark.rdd.RDD
-
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import net.razorvine.pickle.{Pickler, Unpickler}
+
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.{PythonBroadcast, PythonRDD}
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
import org.apache.spark.{Accumulator, Logging => SparkLogging}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 099e1d8f03272..4404ad8ad63a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -438,6 +438,7 @@ object functions {
* }}}
*
* @group normal_funcs
+ * @since 1.4.0
*/
def when(condition: Column, value: Any): Column = {
CaseWhen(Seq(condition.expr, lit(value).expr))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index a03ade3881f59..40483d3ec7701 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -25,9 +25,9 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow}
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources._
-import org.apache.spark.util.Utils
private[sql] object JDBCRDD extends Logging {
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index d6b3fb3291a2e..93e82549f213b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.sources._
@@ -129,7 +130,8 @@ private[sql] case class JDBCRelation(
parts: Array[Partition],
properties: Properties = new Properties())(@transient val sqlContext: SQLContext)
extends BaseRelation
- with PrunedFilteredScan {
+ with PrunedFilteredScan
+ with InsertableRelation {
override val needConversion: Boolean = false
@@ -148,4 +150,8 @@ private[sql] case class JDBCRelation(
filters,
parts)
}
+
+ override def insert(data: DataFrame, overwrite: Boolean): Unit = {
+ data.insertIntoJDBC(url, table, overwrite, properties)
+ }
}
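A short usage sketch may help here: with InsertableRelation mixed in, an INSERT statement against a registered JDBC temporary table is routed to insert(), which simply delegates to DataFrame.insertIntoJDBC. The URL, table and source-table names below are placeholders.

    // Illustrative only; url / table names are placeholders.
    sqlContext.sql(
      """CREATE TEMPORARY TABLE people
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url 'jdbc:h2:mem:testdb', dbtable 'TEST.PEOPLE')
      """.stripMargin)

    // Handled by JDBCRelation.insert(overwrite = false), i.e. DataFrame.insertIntoJDBC:
    sqlContext.sql("INSERT INTO TABLE people SELECT name, id FROM some_registered_table")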
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index a8e69ae61174f..81611513582a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -26,6 +26,7 @@ import com.fasterxml.jackson.core._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index f62973d5fcfab..4c32710a17bc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
import org.apache.spark.Logging
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index ec0e76cde6f7c..8cdbe076cbd85 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index f3ce8e66460e5..0800eded443de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -43,6 +43,29 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
conn1 = DriverManager.getConnection(url1, properties)
conn1.prepareStatement("create schema test").executeUpdate()
+ conn1.prepareStatement("drop table if exists test.people").executeUpdate()
+ conn1.prepareStatement(
+ "create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
+ conn1.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate()
+ conn1.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate()
+ conn1.prepareStatement("drop table if exists test.people1").executeUpdate()
+ conn1.prepareStatement(
+ "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
+ conn1.commit()
+
+ TestSQLContext.sql(
+ s"""
+ |CREATE TEMPORARY TABLE PEOPLE
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
+ """.stripMargin.replaceAll("\n", " "))
+
+ TestSQLContext.sql(
+ s"""
+ |CREATE TEMPORARY TABLE PEOPLE1
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass')
+ """.stripMargin.replaceAll("\n", " "))
}
after {
@@ -114,5 +137,17 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
df2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", true)
}
}
-
+
+ test("INSERT to JDBC Datasource") {
+ TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+ assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).count)
+ assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+ }
+
+ test("INSERT to JDBC Datasource with overwrite") {
+ TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+ TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
+ assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).count)
+ assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 263fafba930ce..b06e3385980f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -24,6 +24,7 @@ import com.fasterxml.jackson.core.JsonFactory
import org.scalactic.Tolerance._
import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.functions._
import org.apache.spark.sql.json.InferSchema.compatibleType
import org.apache.spark.sql.sources.LogicalRelation
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index 7c371dbc7d3c9..008443df216aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -35,6 +35,7 @@ import parquet.schema.{MessageType, MessageTypeParser}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 74ae984f34866..7c7666f6e4b7c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types
import org.apache.spark.sql.types._
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index b69312f0f8717..0b6f7a334a715 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -35,7 +35,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.Logging
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.DateUtils
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.util.Utils
/**
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 1d6393a3fec85..eaa9d6aad1f31 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -17,8 +17,10 @@
package org.apache.spark.sql.hive.execution
+import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
@@ -26,7 +28,6 @@ import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation}
import org.apache.spark.sql.parquet.FSBasedParquetRelation
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{AnalysisException, DefaultParserDialect, QueryTest, Row, SQLConf}
case class Nested1(f1: Nested2)
case class Nested2(f2: Nested3)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 7f747a8bd4712..b49b3ad7d6f3a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -134,7 +134,7 @@ class StreamingContext private[streaming] (
if (sc_ != null) {
sc_
} else if (isCheckpointPresent) {
- new SparkContext(cp_.createSparkConf())
+ SparkContext.getOrCreate(cp_.createSparkConf())
} else {
throw new SparkException("Cannot create StreamingContext without a SparkContext")
}
@@ -759,53 +759,6 @@ object StreamingContext extends Logging {
checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc())
}
- /**
- * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
- * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
- * recreated from the checkpoint data. If the data does not exist, then the StreamingContext
- * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note
- * that the SparkConf configuration in the checkpoint data will not be restored as the
- * SparkContext has already been created.
- *
- * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
- * @param creatingFunc Function to create a new StreamingContext using the given SparkContext
- * @param sparkContext SparkContext using which the StreamingContext will be created
- */
- def getOrCreate(
- checkpointPath: String,
- creatingFunc: SparkContext => StreamingContext,
- sparkContext: SparkContext
- ): StreamingContext = {
- getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false)
- }
-
- /**
- * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
- * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
- * recreated from the checkpoint data. If the data does not exist, then the StreamingContext
- * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note
- * that the SparkConf configuration in the checkpoint data will not be restored as the
- * SparkContext has already been created.
- *
- * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
- * @param creatingFunc Function to create a new StreamingContext using the given SparkContext
- * @param sparkContext SparkContext using which the StreamingContext will be created
- * @param createOnError Whether to create a new StreamingContext if there is an
- * error in reading checkpoint data. By default, an exception will be
- * thrown on error.
- */
- def getOrCreate(
- checkpointPath: String,
- creatingFunc: SparkContext => StreamingContext,
- sparkContext: SparkContext,
- createOnError: Boolean
- ): StreamingContext = {
- val checkpointOption = CheckpointReader.read(
- checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError)
- checkpointOption.map(new StreamingContext(sparkContext, _, null))
- .getOrElse(creatingFunc(sparkContext))
- }
-
/**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
* their JARs to StreamingContext.
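With the SparkContext-taking overloads gone, the single remaining getOrCreate covers both cases: when recovering from a checkpoint, the StreamingContext now obtains its SparkContext through SparkContext.getOrCreate (see the constructor change above), so an already-running SparkContext is picked up automatically. A minimal usage sketch, with placeholder paths, app name and batch interval:

    // Illustrative only; checkpoint path, app name and batch interval are placeholders.
    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    def createContext(): StreamingContext = {
      val conf = new SparkConf().setAppName("checkpointed-app")
      val ssc = new StreamingContext(conf, Seconds(1))
      // ... define DStreams and output operations here ...
      ssc.checkpoint("/tmp/checkpoint")
      ssc
    }

    // Recovers from /tmp/checkpoint if checkpoint data exists, otherwise calls createContext().
    val ssc = StreamingContext.getOrCreate("/tmp/checkpoint", createContext _)
    ssc.start()
    ssc.awaitTermination()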
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index d8fbed2c50644..b639b94d5ca47 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -804,51 +804,6 @@ object JavaStreamingContext {
new JavaStreamingContext(ssc)
}
- /**
- * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
- * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
- * recreated from the checkpoint data. If the data does not exist, then the provided factory
- * will be used to create a JavaStreamingContext.
- *
- * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
- * @param creatingFunc Function to create a new JavaStreamingContext
- * @param sparkContext SparkContext using which the StreamingContext will be created
- */
- def getOrCreate(
- checkpointPath: String,
- creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext],
- sparkContext: JavaSparkContext
- ): JavaStreamingContext = {
- val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => {
- creatingFunc.call(new JavaSparkContext(sparkContext)).ssc
- }, sparkContext.sc)
- new JavaStreamingContext(ssc)
- }
-
- /**
- * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
- * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
- * recreated from the checkpoint data. If the data does not exist, then the provided factory
- * will be used to create a JavaStreamingContext.
- *
- * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
- * @param creatingFunc Function to create a new JavaStreamingContext
- * @param sparkContext SparkContext using which the StreamingContext will be created
- * @param createOnError Whether to create a new JavaStreamingContext if there is an
- * error in reading checkpoint data.
- */
- def getOrCreate(
- checkpointPath: String,
- creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext],
- sparkContext: JavaSparkContext,
- createOnError: Boolean
- ): JavaStreamingContext = {
- val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => {
- creatingFunc.call(new JavaSparkContext(sparkContext)).ssc
- }, sparkContext.sc, createOnError)
- new JavaStreamingContext(ssc)
- }
-
/**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
* their JARs to StreamingContext.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala
index c206f973b2c66..f153ee105a18e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala
@@ -19,7 +19,7 @@ package org.apache.spark.streaming.ui
import java.util.concurrent.TimeUnit
-object UIUtils {
+private[streaming] object UIUtils {
/**
* Return the short string for a `TimeUnit`.
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 2e00b980b9e44..1077b1b2cb7e3 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -1766,29 +1766,10 @@ public JavaStreamingContext call() {
Assert.assertTrue("old context not recovered", !newContextCreated.get());
ssc.stop();
- // Function to create JavaStreamingContext using existing JavaSparkContext
- // without any output operations (used to detect the new context)
- Function creatingFunc2 =
- new Function() {
- public JavaStreamingContext call(JavaSparkContext context) {
- newContextCreated.set(true);
- return new JavaStreamingContext(context, Seconds.apply(1));
- }
- };
-
- JavaSparkContext sc = new JavaSparkContext(conf);
- newContextCreated.set(false);
- ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc);
- Assert.assertTrue("new context not created", newContextCreated.get());
- ssc.stop(false);
-
newContextCreated.set(false);
- ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true);
- Assert.assertTrue("new context not created", newContextCreated.get());
- ssc.stop(false);
-
- newContextCreated.set(false);
- ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc);
+ JavaSparkContext sc = new JavaSparkContext(conf);
+ ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc,
+ new org.apache.hadoop.conf.Configuration());
Assert.assertTrue("old context not recovered", !newContextCreated.get());
ssc.stop();
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 5f93332896de1..4b12affbb0ddd 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -419,76 +419,16 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _)
assert(ssc != null, "no context created")
assert(!newContextCreated, "old context not recovered")
- assert(ssc.conf.get("someKey") === "someValue")
- }
- }
-
- test("getOrCreate with existing SparkContext") {
- val conf = new SparkConf().setMaster(master).setAppName(appName)
- sc = new SparkContext(conf)
-
- // Function to create StreamingContext that has a config to identify it to be new context
- var newContextCreated = false
- def creatingFunction(sparkContext: SparkContext): StreamingContext = {
- newContextCreated = true
- new StreamingContext(sparkContext, batchDuration)
- }
-
- // Call ssc.stop(stopSparkContext = false) after a body of cody
- def testGetOrCreate(body: => Unit): Unit = {
- newContextCreated = false
- try {
- body
- } finally {
- if (ssc != null) {
- ssc.stop(stopSparkContext = false)
- }
- ssc = null
- }
- }
-
- val emptyPath = Utils.createTempDir().getAbsolutePath()
-
- // getOrCreate should create new context with empty path
- testGetOrCreate {
- ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true)
- assert(ssc != null, "no context created")
- assert(newContextCreated, "new context not created")
- assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext")
+ assert(ssc.conf.get("someKey") === "someValue", "checkpointed config not recovered")
}
- val corrutedCheckpointPath = createCorruptedCheckpoint()
-
- // getOrCreate should throw exception with fake checkpoint file and createOnError = false
- intercept[Exception] {
- ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc)
- }
-
- // getOrCreate should throw exception with fake checkpoint file
- intercept[Exception] {
- ssc = StreamingContext.getOrCreate(
- corrutedCheckpointPath, creatingFunction _, sc, createOnError = false)
- }
-
- // getOrCreate should create new context with fake checkpoint file and createOnError = true
- testGetOrCreate {
- ssc = StreamingContext.getOrCreate(
- corrutedCheckpointPath, creatingFunction _, sc, createOnError = true)
- assert(ssc != null, "no context created")
- assert(newContextCreated, "new context not created")
- assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext")
- }
-
- val checkpointPath = createValidCheckpoint()
-
- // StreamingContext.getOrCreate should recover context with checkpoint path
+ // getOrCreate should recover StreamingContext with existing SparkContext
testGetOrCreate {
- ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc)
+ sc = new SparkContext(conf)
+ ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _)
assert(ssc != null, "no context created")
assert(!newContextCreated, "old context not recovered")
- assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext")
- assert(!ssc.conf.contains("someKey"),
- "recovered StreamingContext unexpectedly has old config")
+ assert(!ssc.conf.contains("someKey"), "checkpointed config unexpectedly recovered")
}
}
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index 5b0733206b2bc..9e151fc7a9141 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -42,6 +42,10 @@
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </dependency>
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 9224988e6ad69..2906ac8abad1a 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -19,6 +19,7 @@
import java.util.*;
+import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -47,10 +48,18 @@ public final class TaskMemoryManager {
private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
- /**
- * The number of entries in the page table.
- */
- private static final int PAGE_TABLE_SIZE = 1 << 13;
+ /** The number of bits used to address the page table. */
+ private static final int PAGE_NUMBER_BITS = 13;
+
+ /** The number of bits used to encode offsets in data pages. */
+ @VisibleForTesting
+ static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51
+
+ /** The number of entries in the page table. */
+ private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
+
+ /** Maximum supported data page size. */
+ private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS);
/** Bit mask for the lower 51 bits of a long. */
private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
@@ -101,11 +110,9 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) {
* intended for allocating large blocks of memory that will be shared between operators.
*/
public MemoryBlock allocatePage(long size) {
- if (logger.isTraceEnabled()) {
- logger.trace("Allocating {} byte page", size);
- }
- if (size >= (1L << 51)) {
- throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes");
+ if (size > MAXIMUM_PAGE_SIZE) {
+ throw new IllegalArgumentException(
+ "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes");
}
final int pageNumber;
@@ -120,8 +127,8 @@ public MemoryBlock allocatePage(long size) {
final MemoryBlock page = executorMemoryManager.allocate(size);
page.pageNumber = pageNumber;
pageTable[pageNumber] = page;
- if (logger.isDebugEnabled()) {
- logger.debug("Allocate page number {} ({} bytes)", pageNumber, size);
+ if (logger.isTraceEnabled()) {
+ logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
}
return page;
}
@@ -130,9 +137,6 @@ public MemoryBlock allocatePage(long size) {
* Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
*/
public void freePage(MemoryBlock page) {
- if (logger.isTraceEnabled()) {
- logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size());
- }
assert (page.pageNumber != -1) :
"Called freePage() on memory that wasn't allocated with allocatePage()";
executorMemoryManager.free(page);
@@ -140,8 +144,8 @@ public void freePage(MemoryBlock page) {
allocatedPages.clear(page.pageNumber);
}
pageTable[page.pageNumber] = null;
- if (logger.isDebugEnabled()) {
- logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size());
+ if (logger.isTraceEnabled()) {
+ logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
}
}
@@ -173,14 +177,36 @@ public void free(MemoryBlock memory) {
/**
* Given a memory page and offset within that page, encode this address into a 64-bit long.
* This address will remain valid as long as the corresponding page has not been freed.
+ *
+ * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}.
+ * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
+ * this should be the value that you would pass as the base offset into an
+ * UNSAFE call (e.g. page.baseOffset() + something).
+ * @return an encoded page address.
*/
public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
- if (inHeap) {
- assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
- return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
- } else {
- return offsetInPage;
+ if (!inHeap) {
+ // In off-heap mode, an offset is an absolute address that may require a full 64 bits to
+ // encode. Due to our page size limitation, though, we can convert this into an offset that's
+ // relative to the page's base offset; this relative offset will fit in 51 bits.
+ offsetInPage -= page.getBaseOffset();
}
+ return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
+ }
+
+ @VisibleForTesting
+ public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
+ assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
+ return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
+ }
+
+ @VisibleForTesting
+ public static int decodePageNumber(long pagePlusOffsetAddress) {
+ return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS);
+ }
+
+ private static long decodeOffset(long pagePlusOffsetAddress) {
+ return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
}
/**
@@ -189,7 +215,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
*/
public Object getPage(long pagePlusOffsetAddress) {
if (inHeap) {
- final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51);
+ final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
final Object page = pageTable[pageNumber].getBaseObject();
assert (page != null);
@@ -204,10 +230,15 @@ public Object getPage(long pagePlusOffsetAddress) {
* {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
*/
public long getOffsetInPage(long pagePlusOffsetAddress) {
+ final long offsetInPage = decodeOffset(pagePlusOffsetAddress);
if (inHeap) {
- return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
+ return offsetInPage;
} else {
- return pagePlusOffsetAddress;
+ // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we
+ // converted the absolute address into a relative address. Here, we invert that operation:
+ final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
+ assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+ return pageTable[pageNumber].getBaseOffset() + offsetInPage;
}
}
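As a recap of the encoding that the methods above implement, here is a stand-alone Scala sketch of the same bit manipulation. The constants mirror the Java fields; the object and its names are ours, for illustration only.

    // Illustrative only -- mirrors the constants and bit twiddling in TaskMemoryManager.
    object PageAddress {
      val PageNumberBits = 13
      val OffsetBits = 64 - PageNumberBits              // 51
      val MaskLower51Bits = (1L << OffsetBits) - 1      // 0x7FFFFFFFFFFFFL
      val MaskUpper13Bits = ~MaskLower51Bits

      def encode(pageNumber: Int, offsetInPage: Long): Long =
        (pageNumber.toLong << OffsetBits) | (offsetInPage & MaskLower51Bits)

      def decodePageNumber(address: Long): Int =
        ((address & MaskUpper13Bits) >>> OffsetBits).toInt

      def decodeOffset(address: Long): Long =
        address & MaskLower51Bits
    }

    // e.g. PageAddress.encode(3, 64) packs page 3 and offset 64 into one long.
    // In off-heap mode the offset that gets encoded is (absoluteAddress - page.getBaseOffset),
    // which is what keeps it within 51 bits; getOffsetInPage() adds the base offset back.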
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
index 932882f1ca248..06fb081183659 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
@@ -38,4 +38,27 @@ public void leakedPageMemoryIsDetected() {
Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
}
+ @Test
+ public void encodePageNumberAndOffsetOffHeap() {
+ final TaskMemoryManager manager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+ final MemoryBlock dataPage = manager.allocatePage(256);
+ // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
+ // encode. This test exercises that corner-case:
+ final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
+ final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset);
+ Assert.assertEquals(null, manager.getPage(encodedAddress));
+ Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress));
+ }
+
+ @Test
+ public void encodePageNumberAndOffsetOnHeap() {
+ final TaskMemoryManager manager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock dataPage = manager.allocatePage(256);
+ final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
+ Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
+ Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
+ }
+
}