Skip to content

Commit

Permalink
Register DoFns
Browse files Browse the repository at this point in the history
  • Loading branch information
AlCutter committed Sep 5, 2023
1 parent 71ada31 commit 3343c95
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
14 changes: 13 additions & 1 deletion experimental/batchmap/tmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"

"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/register"

"github.com/google/trillian/merkle/coniks"
"github.com/google/trillian/merkle/smt"
Expand All @@ -38,6 +39,15 @@ var (
cntTilesUpdated = beam.NewCounter("batchmap", "tiles-updated")
)

func init() {
register.DoFn1x2[nodeHash, []byte, nodeHash](&leafShardFn{})
register.DoFn3x2[context.Context, []byte, func(*nodeHash) bool, *Tile, error](&tileHashFn{})
register.DoFn4x2[context.Context, []byte, func(**Tile) bool, func(*nodeHash) bool, *Tile, error](&tileUpdateFn{})
register.Function5x1(createStratum)
register.Function6x1(updateStratum)
register.Function1x2(tilePathFn)
}

// Create builds a new map from the given PCollection of *Entry. Outputs
// the resulting Merkle tree tiles as a PCollection of *Tile.
//
Expand Down Expand Up @@ -122,11 +132,13 @@ func createStratum(s beam.Scope, leaves beam.PCollection, treeID int64, hash cry
// output is a PCollection of *Tile.
func updateStratum(s beam.Scope, base, deltas beam.PCollection, treeID int64, hash crypto.Hash, rootDepth int) beam.PCollection {
s = s.Scope(fmt.Sprintf("updateStratum-%d", rootDepth))
shardedBase := beam.ParDo(s, func(t *Tile) ([]byte, *Tile) { return t.Path, t }, base)
shardedBase := beam.ParDo(s, tilePathFn, base)
shardedDelta := beam.ParDo(s, &leafShardFn{RootDepthBytes: rootDepth}, deltas)
return beam.ParDo(s, &tileUpdateFn{TreeID: treeID, Hash: hash}, beam.CoGroupByKey(s, shardedBase, shardedDelta))
}

func tilePathFn(t *Tile) ([]byte, *Tile) { return t.Path, t }

// nodeHash describes a leaf to be included in a tile.
// This is logically the same as smt.Node however it has public fields so is
// serializable by the default Beam coder. Also, it allows changes to be made
Expand Down
37 changes: 26 additions & 11 deletions experimental/batchmap/tmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"testing"

"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/register"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
"github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/filter"
Expand All @@ -29,6 +30,15 @@ import (

const hash = crypto.SHA512_256

func init() {
register.Function1x1(countTilesFn)
register.Function1x1(testFilterRootOnlyFn)
register.Function1x1(testFnTileRootHash)
register.Function1x1(testLeavesSortedFn)
}

func testFnTileRootHash(t *Tile) string { return fmt.Sprintf("%x", t.RootHash) }

func TestMain(m *testing.M) {
ptest.Main(m)
}
Expand Down Expand Up @@ -132,8 +142,8 @@ func TestCreate(t *testing.T) {
if test.wantFailConstruct {
return
}
rootTile := filter.Include(s, tiles, func(t *Tile) bool { return len(t.Path) == 0 })
roots := beam.ParDo(s, func(t *Tile) string { return fmt.Sprintf("%x", t.RootHash) }, rootTile)
rootTile := filter.Include(s, tiles, testFilterRootOnlyFn)
roots := beam.ParDo(s, testFnTileRootHash, rootTile)

assertTileCount(s, tiles, test.wantTileCount)
passert.Equals(s, roots, test.wantRoot)
Expand Down Expand Up @@ -219,8 +229,8 @@ func TestUpdate(t *testing.T) {
if err != nil {
t.Errorf("pipeline construction failure: %v", err)
}
rootTile := filter.Include(s, tiles, func(t *Tile) bool { return len(t.Path) == 0 })
roots := beam.ParDo(s, func(t *Tile) string { return fmt.Sprintf("%x", t.RootHash) }, rootTile)
rootTile := filter.Include(s, tiles, testFilterRootOnlyFn)
roots := beam.ParDo(s, testFnTileRootHash, rootTile)

assertTileCount(s, tiles, test.wantTileCount)
passert.Equals(s, roots, test.wantRoot)
Expand All @@ -232,6 +242,8 @@ func TestUpdate(t *testing.T) {
}
}

func testFilterRootOnlyFn(t *Tile) bool { return len(t.Path) == 0 }

func TestChildrenSorted(t *testing.T) {
p, s := beam.NewPipelineWithRoot()
entries := []*Entry{}
Expand All @@ -244,13 +256,15 @@ func TestChildrenSorted(t *testing.T) {
t.Fatalf("failed to create pipeline: %v", err)
}

passert.True(s, tiles, func(t *Tile) bool { return isStrictlySorted(t.Leaves) })
passert.True(s, tiles, testLeavesSortedFn)

if err := ptest.Run(p); err != nil {
t.Fatalf("pipeline failed: %v", err)
}
}

func testLeavesSortedFn(t *Tile) bool { return isStrictlySorted(t.Leaves) }

func TestGoldenCreate(t *testing.T) {
p, s := beam.NewPipelineWithRoot()
leaves := beam.CreateList(s, leafNodes(t, 500))
Expand All @@ -259,8 +273,8 @@ func TestGoldenCreate(t *testing.T) {
if err != nil {
t.Fatalf("failed to create pipeline: %v", err)
}
rootTile := filter.Include(s, tiles, func(t *Tile) bool { return len(t.Path) == 0 })
roots := beam.ParDo(s, func(t *Tile) string { return fmt.Sprintf("%x", t.RootHash) }, rootTile)
rootTile := filter.Include(s, tiles, testFilterRootOnlyFn)
roots := beam.ParDo(s, testFnTileRootHash, rootTile)

assertTileCount(s, tiles, 1218)
passert.Equals(s, roots, "daf17dc2c83f37962bae8a65d294ef7fca4ffa02c10bdc4ca5c4dec408001c98")
Expand All @@ -286,8 +300,8 @@ func TestGoldenUpdate(t *testing.T) {
t.Fatalf("failed to create v1 pipeline: %v", err)
}

rootTile := filter.Include(s, tiles, func(t *Tile) bool { return len(t.Path) == 0 })
roots := beam.ParDo(s, func(t *Tile) string { return fmt.Sprintf("%x", t.RootHash) }, rootTile)
rootTile := filter.Include(s, tiles, testFilterRootOnlyFn)
roots := beam.ParDo(s, testFnTileRootHash, rootTile)

assertTileCount(s, tiles, 1218)
passert.Equals(s, roots, "daf17dc2c83f37962bae8a65d294ef7fca4ffa02c10bdc4ca5c4dec408001c98")
Expand All @@ -300,10 +314,11 @@ func TestGoldenUpdate(t *testing.T) {
// tiles has the given cardinality. If the check fails then ptest.Run will
// return an error.
func assertTileCount(s beam.Scope, tiles beam.PCollection, count int) {
countTiles := func(t *Tile) int { return 1 }
passert.Equals(s, stats.Sum(s, beam.ParDo(s, countTiles, tiles)), count)
passert.Equals(s, stats.Sum(s, beam.ParDo(s, countTilesFn, tiles)), count)
}

func countTilesFn(t *Tile) int { return 1 }

// Copied from http://google3/third_party/golang/trillian/merkle/smt/hstar3_test.go?l=201&rcl=298994396
func leafNodes(t testing.TB, n int) []*Entry {
t.Helper()
Expand Down

0 comments on commit 3343c95

Please sign in to comment.