diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index f301d233cb0a0..56c47d564a3b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -256,6 +256,16 @@ class SymmetricHashJoinStateManager( return null } + /** + * Scan from the last index down to stopIndex (inclusive) and return the + * first index whose stored value is non-null, or None if all of them are null. + */ + private def getRightMostNonNullIndex(stopIndex: Long): Option[Long] = { + (numValues - 1 to stopIndex by -1).find { idx => + keyWithIndexToValue.get(currentKey, idx) != null + } + } + override def getNext(): KeyToValuePair = { val currentValue = findNextValueForIndex() @@ -272,12 +282,33 @@ class SymmetricHashJoinStateManager( if (index != numValues - 1) { val valuePairAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1) if (valuePairAtMaxIndex != null) { + // Likely case: the last element is non-null, so copy it into the slot at index. keyWithIndexToValue.put(currentKey, index, valuePairAtMaxIndex.value, valuePairAtMaxIndex.matched) } else { - val projectedKey = getInternalRowOfKeyWithIndex(currentKey) - logWarning(s"`keyWithIndexToValue` returns a null value for index ${numValues - 1} " + - s"at current key $projectedKey.") + // Find the rightmost non-null index after index and copy its value into index; + // if there is none, nonNullIndex falls back to index and nothing is copied. + val nonNullIndex = getRightMostNonNullIndex(index + 1).getOrElse(index) + if (nonNullIndex != index) { + val valuePair = keyWithIndexToValue.get(currentKey, nonNullIndex) + keyWithIndexToValue.put(currentKey, index, valuePair.value, + valuePair.matched) + } + + // If nulls were found at the end, log a warning for the range of null indices. 
+ if (nonNullIndex != numValues - 1) { + logWarning(s"`keyWithIndexToValue` returns a null value for indices " + + s"with range from startIndex=${nonNullIndex + 1} " + + s"and endIndex=${numValues - 1}.") + } + + // Remove all null values from nonNullIndex + 1 onwards. + // The entry at nonNullIndex itself is then removed as the last entry below, + // the same way as when the last element is non-null. + (numValues - 1 to nonNullIndex + 1 by -1).foreach { removeIndex => + keyWithIndexToValue.remove(currentKey, removeIndex) + numValues -= 1 + } } } keyWithIndexToValue.remove(currentKey, numValues - 1) @@ -324,6 +355,15 @@ class SymmetricHashJoinStateManager( ) } + /** + * Overwrite the stored number of values for a key. + * NOTE: this function is only intended for use in unit tests + * to simulate null values. + */ + private[state] def updateNumValuesTestOnly(key: UnsafeRow, numValues: Long): Unit = { + keyToNumValues.put(key, numValues) + } + /* ===================================================== Private methods and inner classes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index 8a03d46d00007..deeebe1fc42bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -46,6 +46,12 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter } } + SymmetricHashJoinStateManager.supportedVersions.foreach { version => + test(s"StreamingJoinStateManager V${version} - all operations with nulls") { + testAllOperationsWithNulls(version) + } + } + SymmetricHashJoinStateManager.supportedVersions.foreach { version => test(s"SPARK-35689: StreamingJoinStateManager V${version} - " + "printable key of 
keyWithIndexToValue") { @@ -68,7 +74,6 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter } } - private def testAllOperations(stateFormatVersion: Int): Unit = { withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion) { manager => implicit val mgr = manager @@ -99,11 +104,6 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter assert(get(30) === Seq.empty) // should remove 30 assert(numRows === 0) - def appendAndTest(key: Int, values: Int*): Unit = { - values.foreach { value => append(key, value)} - require(get(key) === values) - } - appendAndTest(40, 100, 200, 300) appendAndTest(50, 125) appendAndTest(60, 275) // prepare for testing removeByValue @@ -130,6 +130,43 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter assert(numRows === 0) } } + + /* Test removeByValue with nulls simulated by updating numValues on the state manager */ + private def testAllOperationsWithNulls(stateFormatVersion: Int): Unit = { + withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion) { manager => + implicit val mgr = manager + + appendAndTest(40, 100, 200, 300) + appendAndTest(50, 125) + appendAndTest(60, 275) // prepare for testing removeByValue + assert(numRows === 5) + + updateNumValues(40, 5) // update total values to 5 to create 2 nulls + removeByValue(125) + assert(get(40) === Seq(200, 300)) + assert(get(50) === Seq.empty) + assert(get(60) === Seq(275)) // should remove only some values (not all), plus the nulls + assert(numRows === 3) + + append(40, 50) + assert(get(40) === Seq(50, 200, 300)) + assert(numRows === 4) + updateNumValues(40, 4) // update total values to 4 to create 1 null + + removeByValue(200) + assert(get(40) === Seq(300)) + assert(get(60) === Seq(275)) // should remove only some values (not all), plus the nulls + assert(numRows === 2) + updateNumValues(40, 2) // update total values to simulate nulls + updateNumValues(60, 4) + + removeByValue(300) + 
assert(get(40) === Seq.empty) + assert(get(60) === Seq.empty) // should remove all values now, including nulls + assert(numRows === 0) + } + } + val watermarkMetadata = new MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build() val inputValueSchema = new StructType() .add(StructField("time", IntegerType, metadata = watermarkMetadata)) @@ -157,6 +194,17 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter manager.append(toJoinKeyRow(key), toInputValue(value), matched = false) } + def appendAndTest(key: Int, values: Int*) + (implicit manager: SymmetricHashJoinStateManager): Unit = { + values.foreach { value => append(key, value)} + require(get(key) === values) + } + + def updateNumValues(key: Int, numValues: Long) + (implicit manager: SymmetricHashJoinStateManager): Unit = { + manager.updateNumValuesTestOnly(toJoinKeyRow(key), numValues) + } + def get(key: Int)(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = { manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted }