[CORE-459] MapRDB spark connector shows incorrect results when Secondary Index enabled on fields (apache#694)

vvysotskyi authored and Egor Krivokon committed Oct 26, 2021
1 parent 8200c36 commit 46c4a9e
Showing 5 changed files with 53 additions and 14 deletions.
@@ -12,6 +12,7 @@ import com.mapr.db.spark.condition._
 import com.mapr.db.spark.configuration.SerializableConfiguration
 import com.mapr.db.spark.dbclient.DBClient
 import com.mapr.db.spark.impl.OJAIDocument
+import com.mapr.db.spark.sql.SingleFragmentOption
 import com.mapr.db.spark.utils.DefaultClass.DefaultType
 import com.mapr.db.spark.utils.MapRSpark
 import org.ojai.{Document, Value}
@@ -38,7 +39,10 @@ private[spark] class MapRDBTableScanRDD[T: ClassTag](

   @transient private lazy val table = DBClient().getTable(tableName, bufferWrites)
   @transient private lazy val tabletinfos =
-    if (condition == null || condition.condition.isEmpty) {
+    if (enforceSingleFragment) {
+      // no need to obtain tablet info for the case of single fragment
+      Seq.empty
+    } else if (condition == null || condition.condition.isEmpty) {
       DBClient().getTabletInfos(tableName, bufferWrites)
     } else DBClient().getTabletInfos(tableName, condition.condition, bufferWrites)
   @transient private lazy val getSplits: Seq[Value] = {
@@ -55,14 +59,17 @@ private[spark] class MapRDBTableScanRDD[T: ClassTag](

   private def getPartitioner: Partitioner = {
     if (getSplits.isEmpty) {
-      null
-    } else if (getSplits(0).getType == Value.Type.STRING) {
+      if (enforceSingleFragment) MapRDBPartitioner[String](Seq.empty) else null
+    } else if (getSplits.head.getType == Value.Type.STRING) {
       MapRDBPartitioner(getSplits.map(_.getString))
     } else {
       MapRDBPartitioner(getSplits.map(_.getBinary))
     }
   }

+  private def enforceSingleFragment =
+    queryOptions.getOrElse(SingleFragmentOption, "false").toBoolean
+
   def toDF[T <: Product: TypeTag](): DataFrame = maprspark[T]()

   def maprspark[T <: Product: TypeTag](): DataFrame = {
@@ -84,16 +91,26 @@ private[spark] class MapRDBTableScanRDD[T: ClassTag](
   override type Self = MapRDBTableScanRDD[T]

   override def getPartitions: Array[Partition] = {
-    val splits = tabletinfos.zipWithIndex.map(a => {
-      val tabcond = a._1.getCondition
-      MaprDBPartition(a._2,
-        tableName,
-        a._1.getLocations,
-        DBClient().getEstimatedSize(a._1),
-        DBQueryCondition(tabcond)).asInstanceOf[Partition]
-    })
-    logDebug("Partitions for the table:" + tableName + " are " + splits)
-    splits.toArray
+    if (enforceSingleFragment) {
+      val index = 0
+      val partition = MaprDBPartition(index,
+        tableName,
+        Seq.empty,
+        0,
+        DBQueryCondition(DBClient().newCondition().build()))
+      Array(partition)
+    } else {
+      val splits = tabletinfos.zipWithIndex.map(a => {
+        val tableCond = a._1.getCondition
+        MaprDBPartition(a._2,
+          tableName,
+          a._1.getLocations,
+          DBClient().getEstimatedSize(a._1),
+          DBQueryCondition(tableCond)).asInstanceOf[Partition]
+      })
+      logDebug("Partitions for the table:" + tableName + " are " + splits)
+      splits.toArray
+    }
   }

   override def getPreferredLocations(split: Partition): Seq[String] = {
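Net effect of the changes in this file: when the single-fragment option is set, the RDD skips DBClient().getTabletInfos entirely and exposes the whole table as one partition with an empty query condition, instead of one partition per tablet. A minimal, self-contained sketch of how that flag is resolved, using a plain Map as a stand-in for the RDD's queryOptions and the option key introduced later in this commit:

// Stand-in for the RDD's queryOptions; the key is the value of SingleFragmentOption,
// defined in this commit as "spark.maprdb.enforce_single_fragment".
val queryOptions: Map[String, String] =
  Map("spark.maprdb.enforce_single_fragment" -> "true")

// Mirrors the private enforceSingleFragment helper above.
val enforceSingleFragment =
  queryOptions.getOrElse("spark.maprdb.enforce_single_fragment", "false").toBoolean

// true: tabletinfos stays empty and getPartitions returns a single MaprDBPartition
// with an empty condition; false (the default): one partition per tablet, as before.
println(enforceSingleFragment)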
@@ -62,7 +62,8 @@ class DefaultSource
       parameters.get("ColumnProjection"),
       parameters.getOrElse("Operation", "InsertOrReplace"),
       parameters.getOrElse("FailOnConflict", "false"),
-      parameters.filterKeys(k => k.startsWith("ojai.mapr.query")).map(identity)
+      parameters.filterKeys(k =>
+        k.startsWith("ojai.mapr.query") || k.startsWith("spark.maprdb")).map(identity)
     )
   }
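A small sketch of the widened filter, with hypothetical option keys: only keys carrying one of the two recognised prefixes survive and are forwarded as query options, which is what lets the new spark.maprdb.enforce_single_fragment flag reach the table scan.

// Hypothetical reader parameters; only the prefixed keys pass the filter.
val parameters = Map(
  "FailOnConflict" -> "false",
  "ojai.mapr.query.timeout-ms" -> "10000", // hypothetical OJAI query option key
  "spark.maprdb.enforce_single_fragment" -> "true")

val queryOptions = parameters
  .filterKeys(k => k.startsWith("ojai.mapr.query") || k.startsWith("spark.maprdb"))
  .map(identity)

// queryOptions now contains only the two prefixed entries.
println(queryOptions)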

@@ -43,4 +43,21 @@ case class SparkSessionFunctions(@transient sparkSession: SparkSession,
       .build()
       .toDF[T](schema, sampleSize, bufferWrites)
   }
+
+  def loadFromMapRDBWithSingleFragment[T <: Product: TypeTag](
+      tableName: String,
+      schema: StructType = null,
+      sampleSize: Double = GenerateSchema.SAMPLE_SIZE): DataFrame = {
+
+    MapRSpark
+      .builder()
+      .sparkSession(sparkSession)
+      .configuration()
+      .setTable(tableName)
+      .setBufferWrites(bufferWrites)
+      .setHintUsingIndex(hintUsingIndex)
+      .setQueryOptions(queryOptions + (SingleFragmentOption -> "true"))
+      .build()
+      .toDF[T](schema, sampleSize, bufferWrites)
+  }
 }
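For callers, the new method mirrors the existing loadFromMapRDB just above it and differs only in adding SingleFragmentOption -> "true" to the query options before the scan is built. A minimal usage sketch, assuming a hypothetical table path "/apps/users" and that the implicit conversion from SparkSession to SparkSessionFunctions provided by this package is in scope:

import com.mapr.db.spark.sql._
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("maprdb-single-fragment").getOrCreate()

// Same call shape as loadFromMapRDB; schema and sampleSize keep their defaults,
// but the scan runs as a single fragment rather than one fragment per tablet.
val users = spark.loadFromMapRDBWithSingleFragment("/apps/users")
users.printSchema()
println(users.rdd.getNumPartitions) // expected: 1 for the MapR-DB scan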
@@ -7,6 +7,8 @@ import org.apache.spark.sql._

 package object sql {

+  val SingleFragmentOption = "spark.maprdb.enforce_single_fragment"
+
   implicit def toSparkSessionFunctions(sqlContext: SQLContext): SparkSessionFunctions = {
     toSparkSessionFunctions(sqlContext.sparkSession)
   }
@@ -206,6 +206,8 @@ case class MapRSpark(sparkSession: Option[SparkSession],
       .option("sampleSize", sampleSize)
       .option("bufferWrites", bufferWrites)

+    queryOptions.get.foreach(option => reader.option(option._1, option._2))
+
     if (cond.isDefined) {
       reader.option("QueryCondition",
         new String(
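A brief sketch of the propagation step added here, under the assumption that queryOptions is an Option[Map[String, String]]: every collected query option is copied onto the DataFrameReader, where DefaultSource recovers it again through the prefix filter shown earlier. The format and values below are placeholders for illustration only.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("query-option-propagation").getOrCreate()

// Stand-in for MapRSpark's queryOptions field.
val queryOptions: Option[Map[String, String]] =
  Some(Map("spark.maprdb.enforce_single_fragment" -> "true"))

// DataFrameReader accumulates options in place, so copying them one by one works;
// "json" is only a placeholder for the MapR-DB source in this sketch.
val reader = spark.read.format("json")
queryOptions.get.foreach(option => reader.option(option._1, option._2))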