From 3343c9596737f664bf72a26c9fa56cc8bb00f97c Mon Sep 17 00:00:00 2001 From: Al Cutter Date: Tue, 5 Sep 2023 11:27:22 +0100 Subject: [PATCH] Register DoFns --- experimental/batchmap/tmap.go | 14 ++++++++++- experimental/batchmap/tmap_test.go | 37 +++++++++++++++++++++--------- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/experimental/batchmap/tmap.go b/experimental/batchmap/tmap.go index a8ec59df3b..107fa166ea 100644 --- a/experimental/batchmap/tmap.go +++ b/experimental/batchmap/tmap.go @@ -25,6 +25,7 @@ import ( "fmt" "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/google/trillian/merkle/coniks" "github.com/google/trillian/merkle/smt" @@ -38,6 +39,15 @@ var ( cntTilesUpdated = beam.NewCounter("batchmap", "tiles-updated") ) +func init() { + register.DoFn1x2[nodeHash, []byte, nodeHash](&leafShardFn{}) + register.DoFn3x2[context.Context, []byte, func(*nodeHash) bool, *Tile, error](&tileHashFn{}) + register.DoFn4x2[context.Context, []byte, func(**Tile) bool, func(*nodeHash) bool, *Tile, error](&tileUpdateFn{}) + register.Function5x1(createStratum) + register.Function6x1(updateStratum) + register.Function1x2(tilePathFn) +} + // Create builds a new map from the given PCollection of *Entry. Outputs // the resulting Merkle tree tiles as a PCollection of *Tile. // @@ -122,11 +132,13 @@ func createStratum(s beam.Scope, leaves beam.PCollection, treeID int64, hash cry // output is a PCollection of *Tile. func updateStratum(s beam.Scope, base, deltas beam.PCollection, treeID int64, hash crypto.Hash, rootDepth int) beam.PCollection { s = s.Scope(fmt.Sprintf("updateStratum-%d", rootDepth)) - shardedBase := beam.ParDo(s, func(t *Tile) ([]byte, *Tile) { return t.Path, t }, base) + shardedBase := beam.ParDo(s, tilePathFn, base) shardedDelta := beam.ParDo(s, &leafShardFn{RootDepthBytes: rootDepth}, deltas) return beam.ParDo(s, &tileUpdateFn{TreeID: treeID, Hash: hash}, beam.CoGroupByKey(s, shardedBase, shardedDelta)) } +func tilePathFn(t *Tile) ([]byte, *Tile) { return t.Path, t } + // nodeHash describes a leaf to be included in a tile. // This is logically the same as smt.Node however it has public fields so is // serializable by the default Beam coder. Also, it allows changes to be made diff --git a/experimental/batchmap/tmap_test.go b/experimental/batchmap/tmap_test.go index 33415c36c5..6c2af079e4 100644 --- a/experimental/batchmap/tmap_test.go +++ b/experimental/batchmap/tmap_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/filter" @@ -29,6 +30,15 @@ import ( const hash = crypto.SHA512_256 +func init() { + register.Function1x1(countTilesFn) + register.Function1x1(testFilterRootOnlyFn) + register.Function1x1(testFnTileRootHash) + register.Function1x1(testLeavesSortedFn) +} + +func testFnTileRootHash(t *Tile) string { return fmt.Sprintf("%x", t.RootHash) } + func TestMain(m *testing.M) { ptest.Main(m) } @@ -132,8 +142,8 @@ func TestCreate(t *testing.T) { if test.wantFailConstruct { return } - rootTile := filter.Include(s, tiles, func(t *Tile) bool { return len(t.Path) == 0 }) - roots := beam.ParDo(s, func(t *Tile) string { return fmt.Sprintf("%x", t.RootHash) }, rootTile) + rootTile := filter.Include(s, tiles, testFilterRootOnlyFn) + roots := beam.ParDo(s, testFnTileRootHash, rootTile) assertTileCount(s, tiles, test.wantTileCount) passert.Equals(s, roots, test.wantRoot) @@ -219,8 +229,8 @@ func TestUpdate(t *testing.T) { if err != nil { t.Errorf("pipeline construction failure: %v", err) } - rootTile := filter.Include(s, tiles, func(t *Tile) bool { return len(t.Path) == 0 }) - roots := beam.ParDo(s, func(t *Tile) string { return fmt.Sprintf("%x", t.RootHash) }, rootTile) + rootTile := filter.Include(s, tiles, testFilterRootOnlyFn) + roots := beam.ParDo(s, testFnTileRootHash, rootTile) assertTileCount(s, tiles, test.wantTileCount) passert.Equals(s, roots, test.wantRoot) @@ -232,6 +242,8 @@ func TestUpdate(t *testing.T) { } } +func testFilterRootOnlyFn(t *Tile) bool { return len(t.Path) == 0 } + func TestChildrenSorted(t *testing.T) { p, s := beam.NewPipelineWithRoot() entries := []*Entry{} @@ -244,13 +256,15 @@ func TestChildrenSorted(t *testing.T) { t.Fatalf("failed to create pipeline: %v", err) } - passert.True(s, tiles, func(t *Tile) bool { return isStrictlySorted(t.Leaves) }) + passert.True(s, tiles, testLeavesSortedFn) if err := ptest.Run(p); err != nil { t.Fatalf("pipeline failed: %v", err) } } +func testLeavesSortedFn(t *Tile) bool { return isStrictlySorted(t.Leaves) } + func TestGoldenCreate(t *testing.T) { p, s := beam.NewPipelineWithRoot() leaves := beam.CreateList(s, leafNodes(t, 500)) @@ -259,8 +273,8 @@ func TestGoldenCreate(t *testing.T) { if err != nil { t.Fatalf("failed to create pipeline: %v", err) } - rootTile := filter.Include(s, tiles, func(t *Tile) bool { return len(t.Path) == 0 }) - roots := beam.ParDo(s, func(t *Tile) string { return fmt.Sprintf("%x", t.RootHash) }, rootTile) + rootTile := filter.Include(s, tiles, testFilterRootOnlyFn) + roots := beam.ParDo(s, testFnTileRootHash, rootTile) assertTileCount(s, tiles, 1218) passert.Equals(s, roots, "daf17dc2c83f37962bae8a65d294ef7fca4ffa02c10bdc4ca5c4dec408001c98") @@ -286,8 +300,8 @@ func TestGoldenUpdate(t *testing.T) { t.Fatalf("failed to create v1 pipeline: %v", err) } - rootTile := filter.Include(s, tiles, func(t *Tile) bool { return len(t.Path) == 0 }) - roots := beam.ParDo(s, func(t *Tile) string { return fmt.Sprintf("%x", t.RootHash) }, rootTile) + rootTile := filter.Include(s, tiles, testFilterRootOnlyFn) + roots := beam.ParDo(s, testFnTileRootHash, rootTile) assertTileCount(s, tiles, 1218) passert.Equals(s, roots, "daf17dc2c83f37962bae8a65d294ef7fca4ffa02c10bdc4ca5c4dec408001c98") @@ -300,10 +314,11 @@ func TestGoldenUpdate(t *testing.T) { // tiles has the given cardinality. If the check fails then ptest.Run will // return an error. func assertTileCount(s beam.Scope, tiles beam.PCollection, count int) { - countTiles := func(t *Tile) int { return 1 } - passert.Equals(s, stats.Sum(s, beam.ParDo(s, countTiles, tiles)), count) + passert.Equals(s, stats.Sum(s, beam.ParDo(s, countTilesFn, tiles)), count) } +func countTilesFn(t *Tile) int { return 1 } + // Copied from http://google3/third_party/golang/trillian/merkle/smt/hstar3_test.go?l=201&rcl=298994396 func leafNodes(t testing.TB, n int) []*Entry { t.Helper()