From cb74ad60f672c6ac9a2b915d7db9943750169d5f Mon Sep 17 00:00:00 2001 From: Luc Talatinian Date: Fri, 10 Jan 2025 12:45:46 -0500 Subject: [PATCH] fix potential nil deref in waiter path matcher --- .../GoJmespathExpressionGenerator.java | 17 ++++++- .../GoJmespathExpressionGeneratorTest.java | 48 ++++++++++++------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java index 30117776..ec11c078 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java @@ -16,6 +16,7 @@ package software.amazon.smithy.go.codegen; import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.SymbolUtils.isNilable; import static software.amazon.smithy.go.codegen.SymbolUtils.isPointable; import static software.amazon.smithy.go.codegen.SymbolUtils.sliceOf; import static software.amazon.smithy.go.codegen.util.ShapeUtil.BOOL_SHAPE; @@ -280,7 +281,21 @@ private Variable visitProjection(ProjectionExpression expr, Variable current) { private Variable visitSub(Subexpression expr, Variable current) { var left = visit(expr.getLeft(), current); - return visit(expr.getRight(), left); + if (!isNilable(left.type)) { + return visit(expr.getRight(), left); + } + + var lookahead = new GoJmespathExpressionGenerator(ctx, new GoWriter("")) + .generate(expr.getRight(), left); + var ident = nextIdent(); + writer.write("var $L $P", ident, lookahead.type); + writer.write("if $L != nil {", left.ident); + writer.indent(); + var inner = visit(expr.getRight(), left); + writer.write("$L = $L", ident, inner.ident); + writer.dedent(); + writer.write("}"); + return new Variable(inner.shape, ident, inner.type); } private Variable visitField(FieldExpression expr, Variable current) { diff --git a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java index a5d4f8a7..764c3537 100644 --- a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java +++ b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java @@ -124,7 +124,11 @@ public void testSubexpression() { assertThat(actual.ident(), Matchers.equalTo("v2")); assertThat(writer.toString(), Matchers.containsString(""" v1 := input.Nested - v2 := v1.NestedField + var v2 *string + if v1 != nil { + v3 := v1.NestedField + v2 = v3 + } """)); } @@ -304,14 +308,18 @@ public void testComparatorStringLHSNil() { "input" )); assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.BOOL_SHAPE)); - assertThat(actual.ident(), Matchers.equalTo("v4")); + assertThat(actual.ident(), Matchers.equalTo("v5")); assertThat(writer.toString(), Matchers.containsString(""" v1 := input.Nested - v2 := v1.NestedField - v3 := "foo" - var v4 bool + var v2 *string + if v1 != nil { + v3 := v1.NestedField + v2 = v3 + } + v4 := "foo" + var v5 bool if v2 != nil { - v4 = string(*v2) == string(v3) + v5 = string(*v2) == string(v4) } """)); } @@ -327,14 +335,18 @@ public void testComparatorStringRHSNil() { "input" )); assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.BOOL_SHAPE)); - assertThat(actual.ident(), Matchers.equalTo("v4")); + assertThat(actual.ident(), Matchers.equalTo("v5")); assertThat(writer.toString(), Matchers.containsString(""" v1 := "foo" v2 := input.Nested - v3 := v2.NestedField - var v4 bool + var v3 *string + if v2 != nil { + v4 := v2.NestedField + v3 = v4 + } + var v5 bool if v3 != nil { - v4 = string(v1) == string(*v3) + v5 = string(v1) == string(*v3) } """)); } @@ -350,14 +362,18 @@ public void testComparatorStringBothNil() { "input" )); assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.BOOL_SHAPE)); - assertThat(actual.ident(), Matchers.equalTo("v4")); + assertThat(actual.ident(), Matchers.equalTo("v5")); assertThat(writer.toString(), Matchers.containsString(""" v1 := input.Nested - v2 := v1.NestedField - v3 := input.SimpleShape - var v4 bool - if v2 != nil && v3 != nil { - v4 = string(*v2) == string(*v3) + var v2 *string + if v1 != nil { + v3 := v1.NestedField + v2 = v3 + } + v4 := input.SimpleShape + var v5 bool + if v2 != nil && v4 != nil { + v5 = string(*v2) == string(*v4) } """)); }