Skip to content

Commit

Permalink
feat: Add some combinators to BQSchema with test
Browse files Browse the repository at this point in the history
  • Loading branch information
hamnis committed Jun 14, 2023
1 parent 89f6b1c commit 8a53974
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 2 deletions.
4 changes: 3 additions & 1 deletion core/src/main/scala/no/nrk/bigquery/BQField.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ case class BQField(
/** see description in [[BQSchema.recursivelyNullable]] */
def recursivelyNullable: BQField =
copy(
mode = if (mode == Field.Mode.REQUIRED) Field.Mode.NULLABLE else mode,
mode = if (isRequired) Field.Mode.NULLABLE else mode,
subFields = subFields.map(_.recursivelyNullable)
)

def isRequired: Boolean = mode == Field.Mode.REQUIRED
}

object BQField {
Expand Down
25 changes: 24 additions & 1 deletion core/src/main/scala/no/nrk/bigquery/BQSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,30 @@ case class BQSchema(fields: List[BQField]) {
def recursivelyNullable: BQSchema =
copy(fields = fields.map(_.recursivelyNullable))

def extend(additional: BQSchema) = BQSchema(fields ::: additional.fields)
/** Return all fields that are required. If a struct has required fields, but is not itself required, returns that.
* This is a companion to [[recursivelyNullable]]
*/
def requiredFields: List[BQField] = {
def go(field: BQField, list: List[BQField]): List[BQField] = {
val children = field.subFields.flatMap(go(_, list))
if (children.nonEmpty || field.isRequired) {
field :: list
} else {
list
}
}

fields.flatMap(go(_, Nil))
}

def extend(additional: BQSchema): BQSchema = BQSchema(fields ::: additional.fields)

def extendWith(additional: BQField*): BQSchema = BQSchema(fields ::: additional.toList)

def filter(predicate: BQField => Boolean): BQSchema = BQSchema(fields.filter(predicate))

def filterNot(predicate: BQField => Boolean): BQSchema = BQSchema(fields.filterNot(predicate))

}

object BQSchema {
Expand Down
67 changes: 67 additions & 0 deletions core/src/test/scala/no/nrk/bigquery/BQSchemaTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package no.nrk.bigquery

import com.google.cloud.bigquery.Field.Mode
import com.google.cloud.bigquery.StandardSQLTypeName

class BQSchemaTest extends munit.FunSuite {

test("extend") {
val nameField = BQField.struct("name", Mode.REQUIRED)(BQField("id", StandardSQLTypeName.STRING, Mode.REQUIRED))
val partitionDate = BQField("partitionDate", StandardSQLTypeName.DATE, Mode.REQUIRED)
val baseSchema = BQSchema.of(nameField)
val extended = baseSchema.extend(BQSchema.of(partitionDate))
val extendedWith = baseSchema.extendWith(partitionDate)
val expected = BQSchema(List(nameField, partitionDate))

assertEquals(extended, expected)
assertEquals(extendedWith, expected)
}

test("findRequired") {
val nameField = BQField.struct("name", Mode.REQUIRED)(BQField("id", StandardSQLTypeName.STRING, Mode.REQUIRED))
val insertedAt = BQField("insertedAt", StandardSQLTypeName.TIMESTAMP, Mode.NULLABLE)
val schema = BQSchema.of(nameField, insertedAt)

assertEquals(schema.requiredFields, List(nameField))
}

test("findRequired2") {
val nameField = BQField.struct("name", Mode.REPEATED)(BQField("id", StandardSQLTypeName.STRING, Mode.REQUIRED))
val insertedAt = BQField("insertedAt", StandardSQLTypeName.TIMESTAMP, Mode.NULLABLE)
val schema = BQSchema.of(nameField, insertedAt)

assertEquals(schema.requiredFields, List(nameField))
}

test("recursiveNullable") {
val nameField = BQField.struct("name", Mode.REPEATED)(BQField("id", StandardSQLTypeName.STRING, Mode.REQUIRED))
val insertedAt = BQField("insertedAt", StandardSQLTypeName.TIMESTAMP, Mode.NULLABLE)
val schema = BQSchema.of(nameField, insertedAt)

val nameFieldExpected =
BQField.struct("name", Mode.REPEATED)(BQField("id", StandardSQLTypeName.STRING, Mode.NULLABLE))

assertEquals(schema.recursivelyNullable, BQSchema.of(nameFieldExpected, insertedAt))
}

test("filter") {
val nameField = BQField.struct("name", Mode.REPEATED)(BQField("id", StandardSQLTypeName.STRING, Mode.REQUIRED))
val insertedAt = BQField("insertedAt", StandardSQLTypeName.TIMESTAMP, Mode.NULLABLE)
val schema = BQSchema.of(nameField, insertedAt)

val filtered = schema.filter(_.name == "name")

assertEquals(filtered, BQSchema.of(nameField))
}

test("filterNot") {
val nameField = BQField.struct("name", Mode.REPEATED)(BQField("id", StandardSQLTypeName.STRING, Mode.REQUIRED))
val insertedAt = BQField("insertedAt", StandardSQLTypeName.TIMESTAMP, Mode.NULLABLE)
val schema = BQSchema.of(nameField, insertedAt)

val filtered = schema.filterNot(_.name == "insertedAt")

assertEquals(filtered, BQSchema.of(nameField))
}

}

0 comments on commit 8a53974

Please sign in to comment.