Skip to content

Commit

Permalink
Fixes #93 - added ErrorHandlingFilterRowsWithErrors test file
Browse files Browse the repository at this point in the history
  • Loading branch information
TebaleloS committed Apr 6, 2023
1 parent 3efc1b5 commit 38ad8a6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package za.co.absa.spark.commons.errorhandling.implementations

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.{coalesce, col, collect_list}
import za.co.absa.spark.commons.errorhandling.{ErrorMessageSubmit}
import org.apache.spark.sql.functions.{coalesce, col, lit}
import za.co.absa.spark.commons.errorhandling.ErrorMessageSubmit
import za.co.absa.spark.commons.errorhandling.partials.ErrorHandlingCommon
import za.co.absa.spark.commons.errorhandling.types.ErrorColumn

Expand All @@ -23,23 +23,23 @@ object ErrorHandlingFilterRowsWithErrors extends ErrorHandlingCommon {
}

/**
* Evaluate the given column to check if it has errors
* @param errorMessageSubmit the object that has to be evaluated for error purposes
* @return returns the columns with error
* Checks if given column has errors or not
* @param errorMessageSubmit the object that defines the structure of the column
* @return a boolean column holding the literal `true`, used to flag the error
*/
override protected def evaluate(errorMessageSubmit: ErrorMessageSubmit): Column = {
  // The submitted message is not inspected here: the result is a constant boolean flag,
  // not the error-message column. The stale `errorMessageSubmit.errMsg.column` expression
  // was removed because its value was discarded.
  // NOTE(review): presumably the caller decides per row when this flag applies - confirm
  // against ErrorHandlingCommon.
  lit(true)
}

/**
* Checks for relationship of the provided clumn in the given dataframe.
* Checks for relationship of the provided column in the given dataframe.
* @param dataFrame the overall data structure that needs to be aggregated
* @param errCols the row to aggregate the dataframe with
* @param errCols the columns to aggregate the dataframe with
* @return Returns the aggregated dataset with errors.
*/
override protected def doTheAggregation(dataFrame: DataFrame, errCols: Column*): DataFrame = {
  // A sequence argument (`: _*`) must be the only argument bound to a repeated parameter,
  // so the `false` fallback is appended to the sequence before it is expanded.
  // `coalesce` then yields the first non-null error flag, or `false` when all are null.
  val errorFlags: Seq[Column] = errCols :+ lit(false)

  // NOTE(review): the expression passed to `agg` is not an aggregate function, and Spark
  // normally rejects non-aggregate expressions over non-grouping columns. Confirm that
  // grouping by "errCode" is intended here rather than a plain row-level filter.
  val aggregatedDF = dataFrame.groupBy("errCode")
    .agg(coalesce(errorFlags: _*) as "AggregatedError")

  // Keep only the rows whose combined error flag is false.
  aggregatedDF.filter(!col("AggregatedError"))
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package za.co.absa.spark.commons.errorhandling.implementations
import org.scalatest.funsuite.AnyFunSuite

/**
 * Test suite intended for [[ErrorHandlingFilterRowsWithErrors]].
 *
 * Placeholder only: no test cases are registered yet.
 */
class ErrorHandlingFilterRowsWithErrorsTest extends AnyFunSuite

0 comments on commit 38ad8a6

Please sign in to comment.