diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsEventStreamUtils.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsEventStreamUtils.java index c6e9d12f1ce..f35098e5bfc 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsEventStreamUtils.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsEventStreamUtils.java @@ -3,6 +3,7 @@ import java.util.Set; import java.util.TreeSet; import java.util.stream.Collectors; + import software.amazon.smithy.codegen.core.CodegenException; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.codegen.core.SymbolProvider; @@ -88,7 +89,7 @@ public static void generateEventStreamComponents(GenerationContext context) { for (OperationShape operationShape : operationShapes) { if (streamIndex.getInputInfo(operationShape).isEmpty() - && streamIndex.getOutputInfo(operationShape).isEmpty()) { + && streamIndex.getOutputInfo(operationShape).isEmpty()) { continue; } generateEventStreamMiddleware(context, operationShape, !isHttpBindingProto); @@ -107,18 +108,18 @@ private static void generateUnknownEventMessageError(GenerationContext context) var message = getEventStreamSymbol("Message"); writer.write(""" - // $T provides an error when a message is received from the stream, - // but the reader is unable to determine what kind of message it is. - type $T struct { - Type string - Message $P - } - - // Error retruns the error message string. - func (e $P) Error() string { - return "unknown event stream message type, " + e.Type - } - """, symbol, symbol, message, symbol); + // $T provides an error when a message is received from the stream, + // but the reader is unable to determine what kind of message it is. + type $T struct { + Type string + Message $P + } + + // Error retruns the error message string. + func (e $P) Error() string { + return "unknown event stream message type, " + e.Type + } + """, symbol, symbol, message, symbol); } private static Symbol getUnknownEventMessageErrorSymbol() { @@ -142,17 +143,17 @@ private static void generateEventStreamClientLogModeFinalizer( var responseStream = streamIndex .getOutputInfo(operationShape).isPresent(); writer.write(""" - case $S: - $T(o, $L, $L) - return - """, operationShape.getId().getName(), + case $S: + $T(o, $L, $L) + return + """, operationShape.getId().getName(), getToggleEventStreamClientLogModeSymbol(), requestStream, responseStream); }); writer.write(""" - default: - return - """); + default: + return + """); }); }); } @@ -168,20 +169,20 @@ private static void generateToggleClientLogModeFinalizer(GenerationContext conte writer.openBlock("func $T(o *Options, request, response bool) {", "}", getToggleEventStreamClientLogModeSymbol(), () -> { writer.write(""" - mode := o.ClientLogMode - - if request && mode.IsRequestWithBody() { - mode.ClearRequestWithBody() - mode |= $T - } - - if response && mode.IsResponseWithBody() { - mode.ClearResponseWithBody() - mode |= $T - } - - o.ClientLogMode = mode - """, logRequest, logResponse + mode := o.ClientLogMode + + if request && mode.IsRequestWithBody() { + mode.ClearRequestWithBody() + mode |= $T + } + + if response && mode.IsResponseWithBody() { + mode.ClearResponseWithBody() + mode |= $T + } + + o.ClientLogMode = mode + """, logRequest, logResponse ); }).write(""); } @@ -251,26 +252,26 @@ defer func() { if (inputInfo.isPresent()) { w.write(""" - if err := $T(request); err != nil { - return out, metadata, err - } - """, getEventStreamApiSymbol("ApplyHTTPTransportFixes")) + if err := $T(request); err != nil { + return out, metadata, err + } + """, getEventStreamApiSymbol("ApplyHTTPTransportFixes")) .write(""); w.write(""" - requestSignature, err := $T(request.Request) - if err != nil { - return out, metadata, $T("failed to get event stream seed signature: %v", err) - } - """, getSignedRequestSignature, errorf).write("") + requestSignature, err := $T(request.Request) + if err != nil { + return out, metadata, $T("failed to get event stream seed signature: %v", err) + } + """, getSignedRequestSignature, errorf).write("") .openBlock("signer := $T(", ")", getSymbol("NewStreamSigner", AwsGoDependency.AWS_SIGNER_V4, false), () -> w .write(""" - $T(ctx), - $T(ctx), - $T(ctx), - requestSignature, - """, getSymbol("GetSigningCredentials", + $T(ctx), + $T(ctx), + $T(ctx), + requestSignature, + """, getSymbol("GetSigningCredentials", AwsGoDependency.AWS_MIDDLEWARE, false), getSymbol("GetSigningName", AwsGoDependency.AWS_MIDDLEWARE, false), @@ -290,9 +291,9 @@ defer func() { .openBlock("$T(func(options $P) {", "}),", newEncoder, encoderOptions, () -> w .write(""" - options.Logger = logger - options.LogMessages = m.LogEventStreamWrites - """)) + options.Logger = logger + options.LogMessages = m.LogEventStreamWrites + """)) .write("signer,"); if (withInitialMessages) { w.write("$L,", getEventStreamMessageRequestSerializerName( @@ -301,13 +302,13 @@ defer func() { } }) .write(""" - defer func() { - if err == nil { - return - } - _ = eventWriter.Close() - }() - """); + defer func() { + if err == nil { + return + } + _ = eventWriter.Close() + }() + """); if (withInitialMessages) { w.write(""" @@ -337,20 +338,20 @@ defer close(reqSend) } w.write(""" - deserializeOutput, ok := out.RawResponse.($P) - if !ok { - return out, metadata, $T("unknown transport type: %T", out.RawResponse) - } - _ = deserializeOutput - - output, ok := out.Result.($P) - if out.Result != nil && !ok { - return out, metadata, $T("unexpected output result type: %T", out.Result) - } else if out.Result == nil { - output = &$T{} - out.Result = output - } - """, getSymbol("Response", SmithyGoDependency.SMITHY_HTTP_TRANSPORT), errorf, + deserializeOutput, ok := out.RawResponse.($P) + if !ok { + return out, metadata, $T("unknown transport type: %T", out.RawResponse) + } + _ = deserializeOutput + + output, ok := out.Result.($P) + if out.Result != nil && !ok { + return out, metadata, $T("unexpected output result type: %T", out.Result) + } else if out.Result == nil { + output = &$T{} + out.Result = output + } + """, getSymbol("Response", SmithyGoDependency.SMITHY_HTTP_TRANSPORT), errorf, outputSymbol, errorf, outputSymbol ); @@ -366,9 +367,9 @@ defer close(reqSend) .openBlock("$T(func(options $P) {", "}),", newDecoder, decoderOptions, () -> w .write(""" - options.Logger = logger - options.LogMessages = m.LogEventStreamReads - """)); + options.Logger = logger + options.LogMessages = m.LogEventStreamReads + """)); if (withInitialMessages) { w.write("$L,", getEventStreamMessageResponseDeserializerName( operationShape.getOutput().get(), serviceShape, @@ -376,13 +377,13 @@ defer close(reqSend) } }) .write(""" - defer func() { - if err == nil { - return - } - _ = eventReader.Close() - }() - """); + defer func() { + if err == nil { + return + } + _ = eventReader.Close() + }() + """); if (withInitialMessages) { w.write(""" @@ -414,35 +415,35 @@ defer func() { .write("return out, metadata, nil"); }, (mg, w) -> w.write(""" - LogEventStreamWrites bool - LogEventStreamReads bool - """)); + LogEventStreamWrites bool + LogEventStreamReads bool + """)); var deserializeOutput = getSymbol("DeserializeOutput", SmithyGoDependency.SMITHY_MIDDLEWARE); var httpResponse = getSymbol("Response", SmithyGoDependency.SMITHY_HTTP_TRANSPORT); var copy = getSymbol("Copy", SmithyGoDependency.IO); var discard = getSymbol("Discard", SmithyGoDependency.IOUTIL); writer.write(""" - - func ($P) closeResponseBody(out $T) { - if resp, ok := out.RawResponse.($P); ok && resp != nil && resp.Body != nil { - _, _ = $T($T, resp.Body) - _ = resp.Body.Close() - } - } - """, middleware.getMiddlewareSymbol(), deserializeOutput, httpResponse, copy, discard); + + func ($P) closeResponseBody(out $T) { + if resp, ok := out.RawResponse.($P); ok && resp != nil && resp.Body != nil { + _, _ = $T($T, resp.Body) + _ = resp.Body.Close() + } + } + """, middleware.getMiddlewareSymbol(), deserializeOutput, httpResponse, copy, discard); var stack = getSymbol("Stack", SmithyGoDependency.SMITHY_MIDDLEWARE); var before = getSymbol("Before", SmithyGoDependency.SMITHY_MIDDLEWARE); writer.write(""" - func $T(stack $P, options Options) error { - return stack.Deserialize.Insert(&$T{ - LogEventStreamWrites: options.ClientLogMode.IsRequestEventMessage(), - LogEventStreamReads: options.ClientLogMode.IsResponseEventMessage(), - }, "OperationDeserializer", $T) - } - """, getAddEventStreamOperationMiddlewareSymbol(operationShape), + func $T(stack $P, options Options) error { + return stack.Deserialize.Insert(&$T{ + LogEventStreamWrites: options.ClientLogMode.IsRequestEventMessage(), + LogEventStreamReads: options.ClientLogMode.IsResponseEventMessage(), + }, "OperationDeserializer", $T) + } + """, getAddEventStreamOperationMiddlewareSymbol(operationShape), stack, middleware.getMiddlewareSymbol(), before); } @@ -450,7 +451,7 @@ private static void generateEventSignerInterface(GoSettings settings, GoWriter w writer.openBlock("type $T interface {", "}", getModuleSymbol(settings, EVENT_STREAM_SIGNER_INTERFACE), () -> { writer.write("GetSignature(ctx context.Context, headers, payload []byte, signingTime time.Time, " - + "optFns ...func($P)) ([]byte, error)", + + "optFns ...func($P)) ([]byte, error)", SymbolUtils.createPointableSymbolBuilder("StreamSignerOptions", AwsGoDependency.AWS_SIGNER_V4) .build()); @@ -492,13 +493,13 @@ private static void generateEventStreamReader( var syncOnce = getSymbol("Once", SmithyGoDependency.SYNC, false); writer.write(""" - stream chan $T - decoder $P - eventStream $T - err $P - payloadBuf []byte - done chan struct{} - closeOnce $T""", eventUnionSymbol, decoderSymbol, readCloser, onceErr, syncOnce); + stream chan $T + decoder $P + eventStream $T + err $P + payloadBuf []byte + done chan struct{} + closeOnce $T""", eventUnionSymbol, decoderSymbol, readCloser, onceErr, syncOnce); if (withInitialMessages) { writer.write("initialResponseDeserializer func($P) (interface{}, error)", messageSymbol); writer.write("initialResponse chan interface{}"); @@ -516,12 +517,12 @@ private static void generateEventStreamReader( var newOnceErr = getSymbol("NewOnceErr", SmithyGoDependency.SMITHY_SYNC, false); writer.openBlock("w := &$T{", "}", readerSymbol, () -> { writer.write(""" - stream: make(chan $T), - decoder: decoder, - eventStream: readCloser, - err: $T(), - done: make(chan struct{}), - payloadBuf: make([]byte, 10*1024),""", eventUnionSymbol, newOnceErr); + stream: make(chan $T), + decoder: decoder, + eventStream: readCloser, + err: $T(), + done: make(chan struct{}), + payloadBuf: make([]byte, 10*1024),""", eventUnionSymbol, newOnceErr); if (withInitialMessages) { writer.write("initialResponseDeserializer: ird,"); writer.write("initialResponse: make(chan interface{}, 1),"); @@ -529,9 +530,9 @@ private static void generateEventStreamReader( }).write(""); writer.write(""" - go w.readEventStream() - - return w"""); + go w.readEventStream() + + return w"""); }).write(""); writer.openBlock("func (r $P) Events() <-chan $T {", "}", readerSymbol, eventUnionSymbol, () -> writer @@ -539,73 +540,73 @@ private static void generateEventStreamReader( writer.openBlock("func (r $P) readEventStream() {", "}", readerSymbol, () -> { writer.write(""" - defer r.Close() - defer close(r.stream) - """); + defer r.Close() + defer close(r.stream) + """); if (withInitialMessages) { writer.write(""" - defer close(r.initialResponse) - """); + defer close(r.initialResponse) + """); } writer.openBlock("for {", "}", () -> { writer.write(""" - r.payloadBuf = r.payloadBuf[0:0] - decodedMessage, err := r.decoder.Decode(r.eventStream, r.payloadBuf) - if err != nil { - if err == $T { - return - } - select { - case <-r.done: - return - default: - r.err.SetError(err) - return - } - } - - event, err := r.deserializeEventMessage(&decodedMessage) - if err != nil { - r.err.SetError(err) - return - } - """, SymbolUtils.createValueSymbolBuilder("EOF", + r.payloadBuf = r.payloadBuf[0:0] + decodedMessage, err := r.decoder.Decode(r.eventStream, r.payloadBuf) + if err != nil { + if err == $T { + return + } + select { + case <-r.done: + return + default: + r.err.SetError(err) + return + } + } + + event, err := r.deserializeEventMessage(&decodedMessage) + if err != nil { + r.err.SetError(err) + return + } + """, SymbolUtils.createValueSymbolBuilder("EOF", SmithyGoDependency.IO).build()); if (withInitialMessages) { writer.write(""" - switch ev := event.(type) { - case $P: - select { - case r.initialResponse <- ev.Value: - case <-r.done: - return - default: - } - case $P: - select { - case r.stream <- ev.Value: - case <-r.done: - return - } - default: - r.err.SetError($T("unexpected event wrapper: %T", event)) - return - } - """, + switch ev := event.(type) { + case $P: + select { + case r.initialResponse <- ev.Value: + case <-r.done: + return + default: + } + case $P: + select { + case r.stream <- ev.Value: + case <-r.done: + return + } + default: + r.err.SetError($T("unexpected event wrapper: %T", event)) + return + } + """, getReaderEventWrapperInitialResponseType(symbolProvider, eventStream, service), getReaderEventWrapperMessageType(symbolProvider, eventStream, service), getSymbol("Errorf", SmithyGoDependency.FMT, false)); } else { writer.write(""" - select { - case r.stream <- event: - case <-r.done: - return - } - """); + select { + case r.stream <- event: + case <-r.done: + return + } + """); } }); } @@ -620,38 +621,38 @@ defer close(r.initialResponse) var errorMessageType = getEventStreamApiSymbol("ErrorMessageType", false); writer.write(""" - messageType := msg.Headers.Get($T) - if messageType == nil { - return nil, $T("%s event header not present", $T) - } - """, messageTypeHeader, errorf, messageTypeHeader) + messageType := msg.Headers.Get($T) + if messageType == nil { + return nil, $T("%s event header not present", $T) + } + """, messageTypeHeader, errorf, messageTypeHeader) .openBlock("switch messageType.String() {", "}", () -> writer .openBlock("case $T:", "", eventMessageType, () -> { if (withInitialMessages) { var eventTypeHeader = getEventStreamApiSymbol("EventTypeHeader", false); writer.write(""" - eventType := msg.Headers.Get($T) - if eventType == nil { - return nil, $T("%s event header not present", $T) - } - - if eventType.String() == "initial-response" { - v, err := r.initialResponseDeserializer(msg) - if err != nil { - return nil, err - } - return &$T{Value: v}, nil - } - """, eventTypeHeader, errorf, eventTypeHeader, + eventType := msg.Headers.Get($T) + if eventType == nil { + return nil, $T("%s event header not present", $T) + } + + if eventType.String() == "initial-response" { + v, err := r.initialResponseDeserializer(msg) + if err != nil { + return nil, err + } + return &$T{Value: v}, nil + } + """, eventTypeHeader, errorf, eventTypeHeader, getReaderEventWrapperInitialResponseType(symbolProvider, eventStream, service)); } writer.write(""" - var v $T - if err := $L(&v, msg); err != nil { - return nil, err - }""", + var v $T + if err := $L(&v, msg); err != nil { + return nil, err + }""", eventUnionSymbol, getEventStreamDeserializerName(eventStream, service, context.getProtocolName())); if (withInitialMessages) { @@ -668,29 +669,29 @@ eventUnionSymbol, getEventStreamDeserializerName(eventStream, context.getProtocolName()))) .openBlock("case $T:", "", errorMessageType, () -> writer .write(""" - errorCode := "UnknownError" - errorMessage := errorCode - if header := msg.Headers.Get($T); header != nil { - errorCode = header.String() - } - if header := msg.Headers.Get($T); header != nil { - errorMessage = header.String() - } - return nil, &$T{ - Code: errorCode, - Message: errorMessage, - } - """, getEventStreamApiSymbol("ErrorCodeHeader", false), + errorCode := "UnknownError" + errorMessage := errorCode + if header := msg.Headers.Get($T); header != nil { + errorCode = header.String() + } + if header := msg.Headers.Get($T); header != nil { + errorMessage = header.String() + } + return nil, &$T{ + Code: errorCode, + Message: errorMessage, + } + """, getEventStreamApiSymbol("ErrorCodeHeader", false), getEventStreamApiSymbol("ErrorMessageHeader", false), getSymbol("GenericAPIError", SmithyGoDependency.SMITHY, false))) .write(""" - default: - mc := msg.Clone() - return nil, &$T{ - Type: messageType.String(), - Message: &mc, - } - """, getUnknownEventMessageErrorSymbol())); + default: + mc := msg.Clone() + return nil, &$T{ + Type: messageType.String(), + Message: &mc, + } + """, getUnknownEventMessageErrorSymbol())); }).write(""); writer.openBlock("func (r $P) ErrorSet() <-chan struct{} {", "}", readerSymbol, () -> writer @@ -702,9 +703,9 @@ eventUnionSymbol, getEventStreamDeserializerName(eventStream, writer.openBlock("func (r $P) safeClose() {", "}", readerSymbol, () -> writer .write(""" - close(r.done) - r.eventStream.Close() - """)).write(""); + close(r.done) + r.eventStream.Close() + """)).write(""); writer.openBlock("func (r $P) Err() error {", "}", readerSymbol, () -> writer .write("return r.err.Err()")).write(""); @@ -724,30 +725,30 @@ private static void generateEventStreamReaderMessageWrapper( var interfaceMethod = "is" + StringUtils.capitalize(readerEventWrapperInterface.getName()); writer.write(""" - type $T interface { - $L() - } - """, readerEventWrapperInterface, interfaceMethod); + type $T interface { + $L() + } + """, readerEventWrapperInterface, interfaceMethod); var readerEventWrapperMessageType = getReaderEventWrapperMessageType(symbolProvider, eventStream, service); writer.write(""" - type $T struct { - Value $P - } - - func ($P) $L() {} - """, readerEventWrapperMessageType, eventUnionSymbol, readerEventWrapperMessageType, + type $T struct { + Value $P + } + + func ($P) $L() {} + """, readerEventWrapperMessageType, eventUnionSymbol, readerEventWrapperMessageType, interfaceMethod); var readerEventWrapperInitialResponseType = getReaderEventWrapperInitialResponseType(symbolProvider, eventStream, service); writer.write(""" - type $T struct { - Value interface{} - } - - func ($P) $L() {} - """, readerEventWrapperInitialResponseType, readerEventWrapperInitialResponseType, + type $T struct { + Value interface{} + } + + func ($P) $L() {} + """, readerEventWrapperInitialResponseType, readerEventWrapperInitialResponseType, interfaceMethod); } @@ -793,16 +794,16 @@ private static void generateEventStreamWriter( var onceErr = getSymbol("OnceErr", SmithyGoDependency.SMITHY_SYNC); writer.write(""" - encoder $P - signer $T - stream chan $T - serializationBuffer $P - signingBuffer $P - eventStream $T - done chan struct{} - closeOnce $T - err $P - """, encoderSymbol, signerInterface, asyncEventSymbol, bytesBufferSymbol, bytesBufferSymbol, + encoder $P + signer $T + stream chan $T + serializationBuffer $P + signingBuffer $P + eventStream $T + done chan struct{} + closeOnce $T + err $P + """, encoderSymbol, signerInterface, asyncEventSymbol, bytesBufferSymbol, bytesBufferSymbol, writeCloser, syncOnce, onceErr); if (withInitialMessages) { @@ -824,24 +825,24 @@ private static void generateEventStreamWriter( var onceErr = SymbolUtils.createValueSymbolBuilder("NewOnceErr", SmithyGoDependency.SMITHY_SYNC).build(); writer.write(""" - encoder: encoder, - signer: signer, - stream: make(chan $T), - eventStream: stream, - done: make(chan struct{}), - err: $T(), - serializationBuffer: $T(nil), - signingBuffer: $T(nil), - """, asyncEventSymbol, onceErr, bytesNewBuffer, bytesNewBuffer); + encoder: encoder, + signer: signer, + stream: make(chan $T), + eventStream: stream, + done: make(chan struct{}), + err: $T(), + serializationBuffer: $T(nil), + signingBuffer: $T(nil), + """, asyncEventSymbol, onceErr, bytesNewBuffer, bytesNewBuffer); if (withInitialMessages) { writer.write("initialRequestSerializer: irs,"); } }).write("") .write(""" - go w.writeStream() - - return w - """)).write(""); + go w.writeStream() + + return w + """)).write(""); Symbol contextSymbol = SymbolUtils.createValueSymbolBuilder("Context", SmithyGoDependency.CONTEXT).build(); @@ -858,17 +859,17 @@ private static void generateEventStreamWriter( writer.openBlock("func (w $P) send(ctx $P, event $P) error {", "}", writerSymbol, contextSymbol, eventSymbol, () -> { writer.write(""" - if err := w.err.Err(); err != nil { - return err - } - - resultCh := make(chan error) - - wrapped := $T{ - Event: event, - Result: resultCh, - } - """, asyncEventSymbol); + if err := w.err.Err(); err != nil { + return err + } + + resultCh := make(chan error) + + wrapped := $T{ + Event: event, + Result: resultCh, + } + """, asyncEventSymbol); Symbol errorfSymbol = SymbolUtils.createValueSymbolBuilder("Errorf", SmithyGoDependency.FMT) .build(); @@ -876,22 +877,22 @@ private static void generateEventStreamWriter( writer.openBlock("select {", "}", () -> writer .write(""" - case w.stream <- wrapped: - case <-ctx.Done(): - return ctx.Err() - case <-w.done: - return $T($S) - """, errorfSymbol, streamClosedError)).write(""); + case w.stream <- wrapped: + case <-ctx.Done(): + return ctx.Err() + case <-w.done: + return $T($S) + """, errorfSymbol, streamClosedError)).write(""); writer.openBlock("select {", "}", () -> writer .write(""" - case err := <-resultCh: - return err - case <-ctx.Done(): - return ctx.Err() - case <-w.done: - return $T($S) - """, errorfSymbol, streamClosedError)).write(""); + case err := <-resultCh: + return err + case <-ctx.Done(): + return ctx.Err() + case <-w.done: + return $T($S) + """, errorfSymbol, streamClosedError)).write(""); }).write(""); writer.openBlock("func (w $P) writeStream() {", "}", writerSymbol, () -> writer @@ -913,17 +914,17 @@ private static void generateEventStreamWriter( Runnable returnErr = () -> writer.openBlock("if err != nil {", "}", () -> writer.write("return err")) .write(""); writer.writeDocs(""" - serializedEvent returned bytes refers to an underlying byte buffer and must not escape - this writeEvent scope without first copying. Any previous bytes stored in the buffer - are cleared by this call. - """); + serializedEvent returned bytes refers to an underlying byte buffer and must not escape + this writeEvent scope without first copying. Any previous bytes stored in the buffer + are cleared by this call. + """); writer.write("serializedEvent, err := w.serializeEvent(event)"); returnErr.run(); writer.writeDocs(""" - signedEvent returned bytes refers to an underlying byte buffer and must not escape - this writeEvent scope without first copying. Any previous bytes stored in the buffer - are cleared by this call. - """); + signedEvent returned bytes refers to an underlying byte buffer and must not escape + this writeEvent scope without first copying. Any previous bytes stored in the buffer + are cleared by this call. + """); writer.write("signedEvent, err := w.signEvent(serializedEvent)"); returnErr.run(); writer.writeDocs("bytes are now copied to the underlying stream writer"); @@ -943,34 +944,34 @@ private static void generateEventStreamWriter( service); var errorf = getSymbol("Errorf", SmithyGoDependency.FMT, false); writer.write(""" - switch ev := event.(type) { - case $P: - if err := w.initialRequestSerializer(ev.Value, &eventMessage); err != nil { - return nil, err - } - case $P: - if err := $L(ev.Value, &eventMessage); err != nil { - return nil, err - } - default: - return nil, $T("unknown event wrapper type: %v", event) - } - """, initialRequestType, messageEventType, errorf); + switch ev := event.(type) { + case $P: + if err := w.initialRequestSerializer(ev.Value, &eventMessage); err != nil { + return nil, err + } + case $P: + if err := $L(ev.Value, &eventMessage); err != nil { + return nil, err + } + default: + return nil, $T("unknown event wrapper type: %v", event) + } + """, initialRequestType, messageEventType, errorf); } else { writer.write(""" - if err := $L(event, &eventMessage); err != nil { - return nil, err - } - """, + if err := $L(event, &eventMessage); err != nil { + return nil, err + } + """, getEventStreamSerializerName(eventStream, service, context.getProtocolName())); } writer.write(""" - if err := w.encoder.Encode(w.serializationBuffer, eventMessage); err != nil { - return nil, err - } - - return w.serializationBuffer.Bytes(), nil"""); + if err := w.encoder.Encode(w.serializationBuffer, eventMessage); err != nil { + return nil, err + } + + return w.serializationBuffer.Bytes(), nil"""); }).write(""); writer.openBlock("func (w $P) signEvent(payload []byte) ([]byte, error) {", "}", writerSymbol, () -> { @@ -990,7 +991,7 @@ private static void generateEventStreamWriter( getEventStreamSymbol("EncodeHeaders", false), () -> writer.write("return nil, err")).write("") .write("sig, err := w.signer.GetSignature(context.Background(), headers.Bytes(), " - + "msg.Payload, date)") + + "msg.Payload, date)") .openBlock("if err != nil {", "}", () -> writer .write("return nil, err")).write("") .write("msg.Headers.Set($T, $T(sig))", chunkSignatureHeader, bytesValue).write("") @@ -1004,9 +1005,9 @@ private static void generateEventStreamWriter( .openBlock("if cErr := w.eventStream.Close(); cErr != nil && err == nil {", "}", () -> writer.write("err = cErr"))).write("") .write(""" - // Per the protocol, a signed empty message is used to indicate the end of the stream, - // and that no subsequent events will be sent. - signedEvent, err := w.signEvent([]byte{})""") + // Per the protocol, a signed empty message is used to indicate the end of the stream, + // and that no subsequent events will be sent. + signedEvent, err := w.signEvent([]byte{})""") .openBlock("if err != nil {", "}", () -> writer.write("return err")).write("") .write("_, err = io.Copy(w.eventStream, bytes.NewReader(signedEvent))") .write("return err")).write(""); @@ -1035,30 +1036,30 @@ private static void generateEventStreamWriterMessageWrapper( var writerEventWrapperInterface = getWriterEventWrapperInterface(symbolProvider, eventStream, service); var interfaceMethod = "is" + StringUtils.capitalize(writerEventWrapperInterface.getName()); writer.write(""" - type $T interface { - $L() - } - """, writerEventWrapperInterface, interfaceMethod); + type $T interface { + $L() + } + """, writerEventWrapperInterface, interfaceMethod); var writerEventWrapperMessageType = getWriterEventWrapperMessageType(symbolProvider, eventStream, service); writer.write(""" - type $T struct { - Value $P - } - - func ($P) $L() {} - """, writerEventWrapperMessageType, eventUnionSymbol, writerEventWrapperMessageType, + type $T struct { + Value $P + } + + func ($P) $L() {} + """, writerEventWrapperMessageType, eventUnionSymbol, writerEventWrapperMessageType, interfaceMethod); var writerEventWrapperInitialRequestType = getWriterEventWrapperInitialRequestType(symbolProvider, eventStream, service); writer.write(""" - type $T struct { - Value interface{} - } - - func ($P) $L() {} - """, writerEventWrapperInitialRequestType, writerEventWrapperInitialRequestType, interfaceMethod); + type $T struct { + Value interface{} + } + + func ($P) $L() {} + """, writerEventWrapperInitialRequestType, writerEventWrapperInitialRequestType, interfaceMethod); } private static Symbol getWriterEventWrapperInterface( @@ -1142,10 +1143,10 @@ public static void generateEventStreamSerializer( getEventStreamSymbol("Message"), () -> { Symbol errof = getSymbol("Errorf", SmithyGoDependency.FMT, false); writer.write(""" - if v == nil { - return $T("unexpected serialization of nil %T", v) - } - """, errof) + if v == nil { + return $T("unexpected serialization of nil %T", v) + } + """, errof) .write("") .openBlock("switch vv := v.(type) {", "}", () -> { for (MemberShape member : eventUnion.members()) { @@ -1164,9 +1165,9 @@ public static void generateEventStreamSerializer( model.expectShape(member.getTarget()), "vv.Value"))); } writer.write(""" - default: - return $T("unexpected event message type: %v", v) - """, errof); + default: + return $T("unexpected event message type: %v", v) + """, errof); }); }); } @@ -1194,10 +1195,10 @@ public static void generateEventMessageSerializer( writer.openBlock("func $L(v $P, msg $P) error {", "}", serializerName, symbolProvider.toSymbol(targetShape), getEventStreamSymbol("Message"), () -> { writer.write(""" - if v == nil { - return $T("unexpected serialization of nil %T", v) - } - """, errorf).write("") + if v == nil { + return $T("unexpected serialization of nil %T", v) + } + """, errorf).write("") .write("msg.Headers.Set($T, $T($T))", messageTypeHeader, stringValue, eventMessageType); var headerBindings = targetShape.members().stream() @@ -1245,7 +1246,7 @@ public static void generateEventMessageSerializer( break; default: throw new CodegenException("unexpected event payload shape: " - + payloadTarget.getType()); + + payloadTarget.getType()); } } } else { @@ -1270,17 +1271,17 @@ public static void generateEventStreamDeserializer(GenerationContext context, Un var equalFold = SymbolUtils.createValueSymbolBuilder("EqualFold", SmithyGoDependency.STRINGS).build(); writer.write(""" - if v == nil { - return $T("unexpected serialization of nil %T", v) - } - """, errof) + if v == nil { + return $T("unexpected serialization of nil %T", v) + } + """, errof) .write("") .write(""" - eventType := msg.Headers.Get($T) - if eventType == nil { - return $T("%s event header not present", $T) - } - """, eventTypeHeader, errof, eventTypeHeader).write("") + eventType := msg.Headers.Get($T) + if eventType == nil { + return $T("%s event header not present", $T) + } + """, eventTypeHeader, errof, eventTypeHeader).write("") .openBlock("switch {", "}", () -> { var members = eventUnion.members().stream() .filter(ms -> ms.getMemberTrait(model, ErrorTrait.class).isEmpty()) @@ -1297,27 +1298,27 @@ public static void generateEventStreamDeserializer(GenerationContext context, Un eventUnionSymbol.getNamespace()) .build(); writer.write(""" - vv := &$T{} - if err := $L(&vv.Value, msg); err != nil { - return err - } - *v = vv - return nil - """, memberSymbol, messageDeserializerName); + vv := &$T{} + if err := $L(&vv.Value, msg); err != nil { + return err + } + *v = vv + return nil + """, memberSymbol, messageDeserializerName); }); } var newBuffer = getSymbol("NewBuffer", SmithyGoDependency.BYTES); var newEncoder = getEventStreamSymbol("NewEncoder"); writer.write(""" - default: - buffer := $T(nil) - $T().Encode(buffer, *msg) - *v = &$T{ - Tag: eventType.String(), - Value: buffer.Bytes(), - } - return nil - """, newBuffer, newEncoder, SymbolUtils. + default: + buffer := $T(nil) + $T().Encode(buffer, *msg) + *v = &$T{ + Tag: eventType.String(), + Value: buffer.Bytes(), + } + return nil + """, newBuffer, newEncoder, SymbolUtils. createValueSymbolBuilder("UnknownUnionMember", eventUnionSymbol.getNamespace()).build()); }); @@ -1344,11 +1345,11 @@ public static void generateEventStreamExceptionDeserializer( writer.openBlock("func $L(msg $P) error {", "}", deserializerName, getEventStreamSymbol("Message"), () -> { writer.write(""" - exceptionType := msg.Headers.Get($T) - if exceptionType == nil { - return $T("%s event header not present", $T) - } - """, exceptionTypeHeader, errorf, exceptionTypeHeader).write(""); + exceptionType := msg.Headers.Get($T) + if exceptionType == nil { + return $T("%s event header not present", $T) + } + """, exceptionTypeHeader, errorf, exceptionTypeHeader).write(""); var errorMemberShapes = eventUnion.members().stream() .filter(ms -> ms.getMemberTrait(model, ErrorTrait.class).isPresent()) @@ -1422,10 +1423,10 @@ public static void generateEventMessageDeserializer( writer.openBlock("func $L(v $P, msg $P) error {", "}", deserializerName, symbolProvider.toSymbol(targetShape), getEventStreamSymbol("Message"), () -> { writer.write(""" - if v == nil { - return $T("unexpected serialization of nil %T", v) - } - """, errorf).write(""); + if v == nil { + return $T("unexpected serialization of nil %T", v) + } + """, errorf).write(""); var headerBindings = targetShape.members().stream() .filter(memberShape -> memberShape.hasTrait(EventHeaderTrait.class)) @@ -1442,7 +1443,7 @@ public static void generateEventMessageDeserializer( var dest = String.format("v.%s", symbolProvider.toMemberName(headerBinding)); new HeaderShapeDeserVisitor(writer, model, headerBinding, dest, - headerBinding.getMemberName(), "msg").writeDeserializer(); + headerBinding.getMemberName(), "msg.Headers").writeDeserializer(); } if (payloadBinding.isPresent()) { var memberShape = payloadBinding.get(); @@ -1465,9 +1466,9 @@ public static void generateEventMessageDeserializer( case BLOB: writer.openBlock("if msg.Payload != nil {", "}", () -> { writer.write(""" - bsv := make([]byte, len(msg.Payload)) - copy(bsv, msg.Payload) - """); + bsv := make([]byte, len(msg.Payload)) + copy(bsv, msg.Payload) + """); var pointable = CodegenUtils.getAsPointerIfPointable(model, writer, pointableIndex, memberShape, "bsv"); writer.write("$L = $L", String.format("v.%s", @@ -1477,7 +1478,7 @@ public static void generateEventMessageDeserializer( break; default: throw new CodegenException("unexpected event payload shape: " - + payloadTarget.getType()); + + payloadTarget.getType()); } } } else { @@ -1546,12 +1547,12 @@ public static String getAsyncWriteReporterName(Shape shape, ServiceShape service public static String getEventStreamWriterImplName(Shape shape, ServiceShape serviceShape) { var name = shape.getId().getName(serviceShape); - return StringUtils.uncapitalize(name); + return StringUtils.uncapitalize(name) + "Writer"; } public static String getEventStreamReaderImplName(Shape shape, ServiceShape serviceShape) { var name = shape.getId().getName(serviceShape); - return StringUtils.uncapitalize(name); + return StringUtils.uncapitalize(name) + "Reader"; } private static Symbol getEventStreamSymbol(String name) { @@ -1616,19 +1617,19 @@ public static void generateEventMessageRequestSerializer( () -> { var inputSymbol = symbolProvider.toSymbol(inputShape); writer.write(""" - if i == nil { - return $T("event message serializer expects non-nil %T", ($P)(nil)) - } - - v, ok := i.($P) - if !ok { - return $T("unexpected serialization of %T", i) - } - """, errorf, inputSymbol, inputSymbol, errorf).write("") + if i == nil { + return $T("event message serializer expects non-nil %T", ($P)(nil)) + } + + v, ok := i.($P) + if !ok { + return $T("unexpected serialization of %T", i) + } + """, errorf, inputSymbol, inputSymbol, errorf).write("") .write(""" - msg.Headers.Set($T, $T($T)) - msg.Headers.Set($T, $T($S)) - """, + msg.Headers.Set($T, $T($T)) + msg.Headers.Set($T, $T($S)) + """, messageTypeHeader, stringValue, eventMessageType, eventTypeHeader, stringValue, "initial-request" ).write(""); @@ -1678,15 +1679,17 @@ private static String getSerDeName( ToShapeId toShapeId, ServiceShape serviceShape, String protocolName, String name ) { return StringUtils.uncapitalize(protocolName) + name - + toShapeId.toShapeId().getName(serviceShape); + + toShapeId.toShapeId().getName(serviceShape); } public static void writeOperationSerializerMiddlewareEventStreamSetup( GenerationContext context, - EventStreamInfo info + EventStreamInfo info, + String encoderIdentifier ) { context.getWriter().get() - .write("restEncoder.SetHeader(\"Content-Type\").String($S)", "application/vnd.amazon.eventstream") + .write("$L.SetHeader(\"Content-Type\").String($S)", encoderIdentifier, + "application/vnd.amazon.eventstream") .write(""); } @@ -1902,15 +1905,15 @@ private void writeTypeDeserializer(Symbol apiHeaderType, Symbol concreteType) { private void writeTypeDeserializer(Symbol apiHeaderType, Symbol concreteType, Runnable setter) { writer.openBlock("{", "}", () -> { var errorf = SymbolUtils.createValueSymbolBuilder("Errorf", SmithyGoDependency.FMT).build(); - writer.write("headerValue := $L.Get($S)", dest, headerName) + writer.write("headerValue := $L.Get($S)", dataSource, headerName) .openBlock("if headerValue != nil {", "}", () -> { writer.write("hv, ok := headerValue.($P)", apiHeaderType) .write(""" - if !ok { - return $T("unexpected event header %s with type %T:", $S, headerValue) - } - """, errorf, headerName).write("") - .write("ihv := headerValue.Get().($P)", concreteType); + if !ok { + return $T("unexpected event header %s with type %T:", $S, headerValue) + } + """, errorf, headerName).write("") + .write("ihv := hv.Get().($P)", concreteType); setter.run(); }); }).write(""); diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/JsonRpcProtocolGenerator.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/JsonRpcProtocolGenerator.java index b1fae522853..1a8c371970e 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/JsonRpcProtocolGenerator.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/JsonRpcProtocolGenerator.java @@ -203,7 +203,8 @@ public void generateEventStreamComponents(GenerationContext context) { @Override protected void writeOperationSerializerMiddlewareEventStreamSetup(GenerationContext context, EventStreamInfo info) { - AwsEventStreamUtils.writeOperationSerializerMiddlewareEventStreamSetup(context, info); + AwsEventStreamUtils.writeOperationSerializerMiddlewareEventStreamSetup(context, info, + "httpBindingEncoder"); } @Override @@ -315,7 +316,7 @@ protected void generateEventStreamDeserializers( payloadTarget, ctx.getService(), getProtocolName()); var ctxWriter = ctx.getWriter().get(); ctxWriter.openBlock("if err := $L(&$L, shape); err != nil {", "}", functionName, operand, - () -> handleDecodeError(ctxWriter)) + () -> handleDecodeError(ctxWriter)) .write("return nil"); }); @@ -348,7 +349,7 @@ protected void generateEventStreamDeserializers( AwsProtocolUtils.initializeJsonEventMessageDeserializer(ctx, "nil,"); var ctxWriter = ctx.getWriter().get(); ctxWriter.openBlock("if err := $L(&$L, shape); err != nil {", "}", functionName, operand, - () -> handleDecodeError(ctxWriter, "nil,")) + () -> handleDecodeError(ctxWriter, "nil,")) .write("return v, nil"); }); var initialMessageMembers = streamInfo.getInitialMessageMembers() diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java index af16035bf20..7ce44c53f5d 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java @@ -519,7 +519,7 @@ public void generateEventStreamComponents(GenerationContext context) { @Override protected void writeOperationSerializerMiddlewareEventStreamSetup(GenerationContext context, EventStreamInfo info) { - AwsEventStreamUtils.writeOperationSerializerMiddlewareEventStreamSetup(context, info); + AwsEventStreamUtils.writeOperationSerializerMiddlewareEventStreamSetup(context, info, "restEncoder"); } @Override diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestXmlProtocolGenerator.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestXmlProtocolGenerator.java index da595ddbf90..b5a3488bdcd 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestXmlProtocolGenerator.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestXmlProtocolGenerator.java @@ -475,7 +475,7 @@ public void generateEventStreamComponents(GenerationContext context) { @Override protected void writeOperationSerializerMiddlewareEventStreamSetup(GenerationContext context, EventStreamInfo info) { - AwsEventStreamUtils.writeOperationSerializerMiddlewareEventStreamSetup(context, info); + AwsEventStreamUtils.writeOperationSerializerMiddlewareEventStreamSetup(context, info, "restEncoder"); } @Override diff --git a/service/kinesis/eventstream.go b/service/kinesis/eventstream.go index 73b053e7f06..c044348a4f0 100644 --- a/service/kinesis/eventstream.go +++ b/service/kinesis/eventstream.go @@ -45,7 +45,7 @@ type subscribeToShardEventStreamReadEventInitialResponse struct { func (*subscribeToShardEventStreamReadEventInitialResponse) isSubscribeToShardEventStreamReadEvent() { } -type subscribeToShardEventStream struct { +type subscribeToShardEventStreamReader struct { stream chan types.SubscribeToShardEventStream decoder *eventstream.Decoder eventStream io.ReadCloser @@ -57,8 +57,8 @@ type subscribeToShardEventStream struct { initialResponse chan interface{} } -func newSubscribeToShardEventStream(readCloser io.ReadCloser, decoder *eventstream.Decoder, ird func(*eventstream.Message) (interface{}, error)) *subscribeToShardEventStream { - w := &subscribeToShardEventStream{ +func newSubscribeToShardEventStreamWriter(readCloser io.ReadCloser, decoder *eventstream.Decoder, ird func(*eventstream.Message) (interface{}, error)) *subscribeToShardEventStreamReader { + w := &subscribeToShardEventStreamReader{ stream: make(chan types.SubscribeToShardEventStream), decoder: decoder, eventStream: readCloser, @@ -74,11 +74,11 @@ func newSubscribeToShardEventStream(readCloser io.ReadCloser, decoder *eventstre return w } -func (r *subscribeToShardEventStream) Events() <-chan types.SubscribeToShardEventStream { +func (r *subscribeToShardEventStreamReader) Events() <-chan types.SubscribeToShardEventStream { return r.stream } -func (r *subscribeToShardEventStream) readEventStream() { +func (r *subscribeToShardEventStreamReader) readEventStream() { defer r.Close() defer close(r.stream) @@ -128,7 +128,7 @@ func (r *subscribeToShardEventStream) readEventStream() { } } -func (r *subscribeToShardEventStream) deserializeEventMessage(msg *eventstream.Message) (subscribeToShardEventStreamReadEvent, error) { +func (r *subscribeToShardEventStreamReader) deserializeEventMessage(msg *eventstream.Message) (subscribeToShardEventStreamReadEvent, error) { messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader) if messageType == nil { return nil, fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader) @@ -182,26 +182,26 @@ func (r *subscribeToShardEventStream) deserializeEventMessage(msg *eventstream.M } } -func (r *subscribeToShardEventStream) ErrorSet() <-chan struct{} { +func (r *subscribeToShardEventStreamReader) ErrorSet() <-chan struct{} { return r.err.ErrorSet() } -func (r *subscribeToShardEventStream) Close() error { +func (r *subscribeToShardEventStreamReader) Close() error { r.closeOnce.Do(r.safeClose) return r.Err() } -func (r *subscribeToShardEventStream) safeClose() { +func (r *subscribeToShardEventStreamReader) safeClose() { close(r.done) r.eventStream.Close() } -func (r *subscribeToShardEventStream) Err() error { +func (r *subscribeToShardEventStreamReader) Err() error { return r.err.Err() } -func (r *subscribeToShardEventStream) Closed() <-chan struct{} { +func (r *subscribeToShardEventStreamReader) Closed() <-chan struct{} { return r.done } @@ -251,7 +251,7 @@ func (m *awsAwsjson11_deserializeOpEventStreamSubscribeToShard) HandleDeserializ out.Result = output } - eventReader := newSubscribeToShardEventStream( + eventReader := newSubscribeToShardEventStreamWriter( deserializeOutput.Body, eventstream.NewDecoder(func(options *eventstream.DecoderOptions) { options.Logger = logger diff --git a/service/lexruntimev2/eventstream.go b/service/lexruntimev2/eventstream.go index 6d9509501bb..bec14364018 100644 --- a/service/lexruntimev2/eventstream.go +++ b/service/lexruntimev2/eventstream.go @@ -62,7 +62,7 @@ func (e asyncStartConversationRequestEventStream) ReportResult(cancel <-chan str } } -type startConversationRequestEventStream struct { +type startConversationRequestEventStreamWriter struct { encoder *eventstream.Encoder signer eventStreamSigner stream chan asyncStartConversationRequestEventStream @@ -74,8 +74,8 @@ type startConversationRequestEventStream struct { err *smithysync.OnceErr } -func newStartConversationRequestEventStream(stream io.WriteCloser, encoder *eventstream.Encoder, signer eventStreamSigner) *startConversationRequestEventStream { - w := &startConversationRequestEventStream{ +func newStartConversationRequestEventStreamReader(stream io.WriteCloser, encoder *eventstream.Encoder, signer eventStreamSigner) *startConversationRequestEventStreamWriter { + w := &startConversationRequestEventStreamWriter{ encoder: encoder, signer: signer, stream: make(chan asyncStartConversationRequestEventStream), @@ -92,11 +92,11 @@ func newStartConversationRequestEventStream(stream io.WriteCloser, encoder *even } -func (w *startConversationRequestEventStream) Send(ctx context.Context, event types.StartConversationRequestEventStream) error { +func (w *startConversationRequestEventStreamWriter) Send(ctx context.Context, event types.StartConversationRequestEventStream) error { return w.send(ctx, event) } -func (w *startConversationRequestEventStream) send(ctx context.Context, event types.StartConversationRequestEventStream) error { +func (w *startConversationRequestEventStreamWriter) send(ctx context.Context, event types.StartConversationRequestEventStream) error { if err := w.err.Err(); err != nil { return err } @@ -129,7 +129,7 @@ func (w *startConversationRequestEventStream) send(ctx context.Context, event ty } -func (w *startConversationRequestEventStream) writeStream() { +func (w *startConversationRequestEventStreamWriter) writeStream() { defer w.Close() for { @@ -152,7 +152,7 @@ func (w *startConversationRequestEventStream) writeStream() { } } -func (w *startConversationRequestEventStream) writeEvent(event types.StartConversationRequestEventStream) error { +func (w *startConversationRequestEventStreamWriter) writeEvent(event types.StartConversationRequestEventStream) error { // serializedEvent returned bytes refers to an underlying byte buffer and must not // escape this writeEvent scope without first copying. Any previous bytes stored in // the buffer are cleared by this call. @@ -174,7 +174,7 @@ func (w *startConversationRequestEventStream) writeEvent(event types.StartConver return err } -func (w *startConversationRequestEventStream) serializeEvent(event types.StartConversationRequestEventStream) ([]byte, error) { +func (w *startConversationRequestEventStreamWriter) serializeEvent(event types.StartConversationRequestEventStream) ([]byte, error) { w.serializationBuffer.Reset() eventMessage := eventstream.Message{} @@ -190,7 +190,7 @@ func (w *startConversationRequestEventStream) serializeEvent(event types.StartCo return w.serializationBuffer.Bytes(), nil } -func (w *startConversationRequestEventStream) signEvent(payload []byte) ([]byte, error) { +func (w *startConversationRequestEventStreamWriter) signEvent(payload []byte) ([]byte, error) { w.signingBuffer.Reset() date := time.Now().UTC() @@ -218,7 +218,7 @@ func (w *startConversationRequestEventStream) signEvent(payload []byte) ([]byte, return w.signingBuffer.Bytes(), nil } -func (w *startConversationRequestEventStream) closeStream() (err error) { +func (w *startConversationRequestEventStreamWriter) closeStream() (err error) { defer func() { if cErr := w.eventStream.Close(); cErr != nil && err == nil { err = cErr @@ -236,24 +236,24 @@ func (w *startConversationRequestEventStream) closeStream() (err error) { return err } -func (w *startConversationRequestEventStream) ErrorSet() <-chan struct{} { +func (w *startConversationRequestEventStreamWriter) ErrorSet() <-chan struct{} { return w.err.ErrorSet() } -func (w *startConversationRequestEventStream) Close() error { +func (w *startConversationRequestEventStreamWriter) Close() error { w.closeOnce.Do(w.safeClose) return w.Err() } -func (w *startConversationRequestEventStream) safeClose() { +func (w *startConversationRequestEventStreamWriter) safeClose() { close(w.done) } -func (w *startConversationRequestEventStream) Err() error { +func (w *startConversationRequestEventStreamWriter) Err() error { return w.err.Err() } -type startConversationResponseEventStream struct { +type startConversationResponseEventStreamReader struct { stream chan types.StartConversationResponseEventStream decoder *eventstream.Decoder eventStream io.ReadCloser @@ -263,8 +263,8 @@ type startConversationResponseEventStream struct { closeOnce sync.Once } -func newStartConversationResponseEventStream(readCloser io.ReadCloser, decoder *eventstream.Decoder) *startConversationResponseEventStream { - w := &startConversationResponseEventStream{ +func newStartConversationResponseEventStreamWriter(readCloser io.ReadCloser, decoder *eventstream.Decoder) *startConversationResponseEventStreamReader { + w := &startConversationResponseEventStreamReader{ stream: make(chan types.StartConversationResponseEventStream), decoder: decoder, eventStream: readCloser, @@ -278,11 +278,11 @@ func newStartConversationResponseEventStream(readCloser io.ReadCloser, decoder * return w } -func (r *startConversationResponseEventStream) Events() <-chan types.StartConversationResponseEventStream { +func (r *startConversationResponseEventStreamReader) Events() <-chan types.StartConversationResponseEventStream { return r.stream } -func (r *startConversationResponseEventStream) readEventStream() { +func (r *startConversationResponseEventStreamReader) readEventStream() { defer r.Close() defer close(r.stream) @@ -317,7 +317,7 @@ func (r *startConversationResponseEventStream) readEventStream() { } } -func (r *startConversationResponseEventStream) deserializeEventMessage(msg *eventstream.Message) (types.StartConversationResponseEventStream, error) { +func (r *startConversationResponseEventStreamReader) deserializeEventMessage(msg *eventstream.Message) (types.StartConversationResponseEventStream, error) { messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader) if messageType == nil { return nil, fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader) @@ -358,26 +358,26 @@ func (r *startConversationResponseEventStream) deserializeEventMessage(msg *even } } -func (r *startConversationResponseEventStream) ErrorSet() <-chan struct{} { +func (r *startConversationResponseEventStreamReader) ErrorSet() <-chan struct{} { return r.err.ErrorSet() } -func (r *startConversationResponseEventStream) Close() error { +func (r *startConversationResponseEventStreamReader) Close() error { r.closeOnce.Do(r.safeClose) return r.Err() } -func (r *startConversationResponseEventStream) safeClose() { +func (r *startConversationResponseEventStreamReader) safeClose() { close(r.done) r.eventStream.Close() } -func (r *startConversationResponseEventStream) Err() error { +func (r *startConversationResponseEventStreamReader) Err() error { return r.err.Err() } -func (r *startConversationResponseEventStream) Closed() <-chan struct{} { +func (r *startConversationResponseEventStreamReader) Closed() <-chan struct{} { return r.done } @@ -424,7 +424,7 @@ func (m *awsRestjson1_deserializeOpEventStreamStartConversation) HandleDeseriali requestSignature, ) - eventWriter := newStartConversationRequestEventStream( + eventWriter := newStartConversationRequestEventStreamReader( eventstreamapi.GetInputStreamWriter(ctx), eventstream.NewEncoder(func(options *eventstream.EncoderOptions) { options.Logger = logger @@ -459,7 +459,7 @@ func (m *awsRestjson1_deserializeOpEventStreamStartConversation) HandleDeseriali out.Result = output } - eventReader := newStartConversationResponseEventStream( + eventReader := newStartConversationResponseEventStreamWriter( deserializeOutput.Body, eventstream.NewDecoder(func(options *eventstream.DecoderOptions) { options.Logger = logger diff --git a/service/s3/eventstream.go b/service/s3/eventstream.go index 38d579a3a4f..0e267c92730 100644 --- a/service/s3/eventstream.go +++ b/service/s3/eventstream.go @@ -28,7 +28,7 @@ type SelectObjectContentEventStreamReader interface { Err() error } -type selectObjectContentEventStream struct { +type selectObjectContentEventStreamReader struct { stream chan types.SelectObjectContentEventStream decoder *eventstream.Decoder eventStream io.ReadCloser @@ -38,8 +38,8 @@ type selectObjectContentEventStream struct { closeOnce sync.Once } -func newSelectObjectContentEventStream(readCloser io.ReadCloser, decoder *eventstream.Decoder) *selectObjectContentEventStream { - w := &selectObjectContentEventStream{ +func newSelectObjectContentEventStreamWriter(readCloser io.ReadCloser, decoder *eventstream.Decoder) *selectObjectContentEventStreamReader { + w := &selectObjectContentEventStreamReader{ stream: make(chan types.SelectObjectContentEventStream), decoder: decoder, eventStream: readCloser, @@ -53,11 +53,11 @@ func newSelectObjectContentEventStream(readCloser io.ReadCloser, decoder *events return w } -func (r *selectObjectContentEventStream) Events() <-chan types.SelectObjectContentEventStream { +func (r *selectObjectContentEventStreamReader) Events() <-chan types.SelectObjectContentEventStream { return r.stream } -func (r *selectObjectContentEventStream) readEventStream() { +func (r *selectObjectContentEventStreamReader) readEventStream() { defer r.Close() defer close(r.stream) @@ -92,7 +92,7 @@ func (r *selectObjectContentEventStream) readEventStream() { } } -func (r *selectObjectContentEventStream) deserializeEventMessage(msg *eventstream.Message) (types.SelectObjectContentEventStream, error) { +func (r *selectObjectContentEventStreamReader) deserializeEventMessage(msg *eventstream.Message) (types.SelectObjectContentEventStream, error) { messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader) if messageType == nil { return nil, fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader) @@ -133,26 +133,26 @@ func (r *selectObjectContentEventStream) deserializeEventMessage(msg *eventstrea } } -func (r *selectObjectContentEventStream) ErrorSet() <-chan struct{} { +func (r *selectObjectContentEventStreamReader) ErrorSet() <-chan struct{} { return r.err.ErrorSet() } -func (r *selectObjectContentEventStream) Close() error { +func (r *selectObjectContentEventStreamReader) Close() error { r.closeOnce.Do(r.safeClose) return r.Err() } -func (r *selectObjectContentEventStream) safeClose() { +func (r *selectObjectContentEventStreamReader) safeClose() { close(r.done) r.eventStream.Close() } -func (r *selectObjectContentEventStream) Err() error { +func (r *selectObjectContentEventStreamReader) Err() error { return r.err.Err() } -func (r *selectObjectContentEventStream) Closed() <-chan struct{} { +func (r *selectObjectContentEventStreamReader) Closed() <-chan struct{} { return r.done } @@ -202,7 +202,7 @@ func (m *awsRestxml_deserializeOpEventStreamSelectObjectContent) HandleDeseriali out.Result = output } - eventReader := newSelectObjectContentEventStream( + eventReader := newSelectObjectContentEventStreamWriter( deserializeOutput.Body, eventstream.NewDecoder(func(options *eventstream.DecoderOptions) { options.Logger = logger diff --git a/service/transcribestreaming/eventstream.go b/service/transcribestreaming/eventstream.go index f33c4a7fdd1..b6814886e25 100644 --- a/service/transcribestreaming/eventstream.go +++ b/service/transcribestreaming/eventstream.go @@ -71,7 +71,7 @@ func (e asyncAudioStream) ReportResult(cancel <-chan struct{}, err error) bool { } } -type audioStream struct { +type audioStreamWriter struct { encoder *eventstream.Encoder signer eventStreamSigner stream chan asyncAudioStream @@ -83,8 +83,8 @@ type audioStream struct { err *smithysync.OnceErr } -func newAudioStream(stream io.WriteCloser, encoder *eventstream.Encoder, signer eventStreamSigner) *audioStream { - w := &audioStream{ +func newAudioStreamReader(stream io.WriteCloser, encoder *eventstream.Encoder, signer eventStreamSigner) *audioStreamWriter { + w := &audioStreamWriter{ encoder: encoder, signer: signer, stream: make(chan asyncAudioStream), @@ -101,11 +101,11 @@ func newAudioStream(stream io.WriteCloser, encoder *eventstream.Encoder, signer } -func (w *audioStream) Send(ctx context.Context, event types.AudioStream) error { +func (w *audioStreamWriter) Send(ctx context.Context, event types.AudioStream) error { return w.send(ctx, event) } -func (w *audioStream) send(ctx context.Context, event types.AudioStream) error { +func (w *audioStreamWriter) send(ctx context.Context, event types.AudioStream) error { if err := w.err.Err(); err != nil { return err } @@ -138,7 +138,7 @@ func (w *audioStream) send(ctx context.Context, event types.AudioStream) error { } -func (w *audioStream) writeStream() { +func (w *audioStreamWriter) writeStream() { defer w.Close() for { @@ -161,7 +161,7 @@ func (w *audioStream) writeStream() { } } -func (w *audioStream) writeEvent(event types.AudioStream) error { +func (w *audioStreamWriter) writeEvent(event types.AudioStream) error { // serializedEvent returned bytes refers to an underlying byte buffer and must not // escape this writeEvent scope without first copying. Any previous bytes stored in // the buffer are cleared by this call. @@ -183,7 +183,7 @@ func (w *audioStream) writeEvent(event types.AudioStream) error { return err } -func (w *audioStream) serializeEvent(event types.AudioStream) ([]byte, error) { +func (w *audioStreamWriter) serializeEvent(event types.AudioStream) ([]byte, error) { w.serializationBuffer.Reset() eventMessage := eventstream.Message{} @@ -199,7 +199,7 @@ func (w *audioStream) serializeEvent(event types.AudioStream) ([]byte, error) { return w.serializationBuffer.Bytes(), nil } -func (w *audioStream) signEvent(payload []byte) ([]byte, error) { +func (w *audioStreamWriter) signEvent(payload []byte) ([]byte, error) { w.signingBuffer.Reset() date := time.Now().UTC() @@ -227,7 +227,7 @@ func (w *audioStream) signEvent(payload []byte) ([]byte, error) { return w.signingBuffer.Bytes(), nil } -func (w *audioStream) closeStream() (err error) { +func (w *audioStreamWriter) closeStream() (err error) { defer func() { if cErr := w.eventStream.Close(); cErr != nil && err == nil { err = cErr @@ -245,24 +245,24 @@ func (w *audioStream) closeStream() (err error) { return err } -func (w *audioStream) ErrorSet() <-chan struct{} { +func (w *audioStreamWriter) ErrorSet() <-chan struct{} { return w.err.ErrorSet() } -func (w *audioStream) Close() error { +func (w *audioStreamWriter) Close() error { w.closeOnce.Do(w.safeClose) return w.Err() } -func (w *audioStream) safeClose() { +func (w *audioStreamWriter) safeClose() { close(w.done) } -func (w *audioStream) Err() error { +func (w *audioStreamWriter) Err() error { return w.err.Err() } -type medicalTranscriptResultStream struct { +type medicalTranscriptResultStreamReader struct { stream chan types.MedicalTranscriptResultStream decoder *eventstream.Decoder eventStream io.ReadCloser @@ -272,8 +272,8 @@ type medicalTranscriptResultStream struct { closeOnce sync.Once } -func newMedicalTranscriptResultStream(readCloser io.ReadCloser, decoder *eventstream.Decoder) *medicalTranscriptResultStream { - w := &medicalTranscriptResultStream{ +func newMedicalTranscriptResultStreamWriter(readCloser io.ReadCloser, decoder *eventstream.Decoder) *medicalTranscriptResultStreamReader { + w := &medicalTranscriptResultStreamReader{ stream: make(chan types.MedicalTranscriptResultStream), decoder: decoder, eventStream: readCloser, @@ -287,11 +287,11 @@ func newMedicalTranscriptResultStream(readCloser io.ReadCloser, decoder *eventst return w } -func (r *medicalTranscriptResultStream) Events() <-chan types.MedicalTranscriptResultStream { +func (r *medicalTranscriptResultStreamReader) Events() <-chan types.MedicalTranscriptResultStream { return r.stream } -func (r *medicalTranscriptResultStream) readEventStream() { +func (r *medicalTranscriptResultStreamReader) readEventStream() { defer r.Close() defer close(r.stream) @@ -326,7 +326,7 @@ func (r *medicalTranscriptResultStream) readEventStream() { } } -func (r *medicalTranscriptResultStream) deserializeEventMessage(msg *eventstream.Message) (types.MedicalTranscriptResultStream, error) { +func (r *medicalTranscriptResultStreamReader) deserializeEventMessage(msg *eventstream.Message) (types.MedicalTranscriptResultStream, error) { messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader) if messageType == nil { return nil, fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader) @@ -367,30 +367,30 @@ func (r *medicalTranscriptResultStream) deserializeEventMessage(msg *eventstream } } -func (r *medicalTranscriptResultStream) ErrorSet() <-chan struct{} { +func (r *medicalTranscriptResultStreamReader) ErrorSet() <-chan struct{} { return r.err.ErrorSet() } -func (r *medicalTranscriptResultStream) Close() error { +func (r *medicalTranscriptResultStreamReader) Close() error { r.closeOnce.Do(r.safeClose) return r.Err() } -func (r *medicalTranscriptResultStream) safeClose() { +func (r *medicalTranscriptResultStreamReader) safeClose() { close(r.done) r.eventStream.Close() } -func (r *medicalTranscriptResultStream) Err() error { +func (r *medicalTranscriptResultStreamReader) Err() error { return r.err.Err() } -func (r *medicalTranscriptResultStream) Closed() <-chan struct{} { +func (r *medicalTranscriptResultStreamReader) Closed() <-chan struct{} { return r.done } -type transcriptResultStream struct { +type transcriptResultStreamReader struct { stream chan types.TranscriptResultStream decoder *eventstream.Decoder eventStream io.ReadCloser @@ -400,8 +400,8 @@ type transcriptResultStream struct { closeOnce sync.Once } -func newTranscriptResultStream(readCloser io.ReadCloser, decoder *eventstream.Decoder) *transcriptResultStream { - w := &transcriptResultStream{ +func newTranscriptResultStreamWriter(readCloser io.ReadCloser, decoder *eventstream.Decoder) *transcriptResultStreamReader { + w := &transcriptResultStreamReader{ stream: make(chan types.TranscriptResultStream), decoder: decoder, eventStream: readCloser, @@ -415,11 +415,11 @@ func newTranscriptResultStream(readCloser io.ReadCloser, decoder *eventstream.De return w } -func (r *transcriptResultStream) Events() <-chan types.TranscriptResultStream { +func (r *transcriptResultStreamReader) Events() <-chan types.TranscriptResultStream { return r.stream } -func (r *transcriptResultStream) readEventStream() { +func (r *transcriptResultStreamReader) readEventStream() { defer r.Close() defer close(r.stream) @@ -454,7 +454,7 @@ func (r *transcriptResultStream) readEventStream() { } } -func (r *transcriptResultStream) deserializeEventMessage(msg *eventstream.Message) (types.TranscriptResultStream, error) { +func (r *transcriptResultStreamReader) deserializeEventMessage(msg *eventstream.Message) (types.TranscriptResultStream, error) { messageType := msg.Headers.Get(eventstreamapi.MessageTypeHeader) if messageType == nil { return nil, fmt.Errorf("%s event header not present", eventstreamapi.MessageTypeHeader) @@ -495,26 +495,26 @@ func (r *transcriptResultStream) deserializeEventMessage(msg *eventstream.Messag } } -func (r *transcriptResultStream) ErrorSet() <-chan struct{} { +func (r *transcriptResultStreamReader) ErrorSet() <-chan struct{} { return r.err.ErrorSet() } -func (r *transcriptResultStream) Close() error { +func (r *transcriptResultStreamReader) Close() error { r.closeOnce.Do(r.safeClose) return r.Err() } -func (r *transcriptResultStream) safeClose() { +func (r *transcriptResultStreamReader) safeClose() { close(r.done) r.eventStream.Close() } -func (r *transcriptResultStream) Err() error { +func (r *transcriptResultStreamReader) Err() error { return r.err.Err() } -func (r *transcriptResultStream) Closed() <-chan struct{} { +func (r *transcriptResultStreamReader) Closed() <-chan struct{} { return r.done } @@ -561,7 +561,7 @@ func (m *awsRestjson1_deserializeOpEventStreamStartMedicalStreamTranscription) H requestSignature, ) - eventWriter := newAudioStream( + eventWriter := newAudioStreamReader( eventstreamapi.GetInputStreamWriter(ctx), eventstream.NewEncoder(func(options *eventstream.EncoderOptions) { options.Logger = logger @@ -596,7 +596,7 @@ func (m *awsRestjson1_deserializeOpEventStreamStartMedicalStreamTranscription) H out.Result = output } - eventReader := newMedicalTranscriptResultStream( + eventReader := newMedicalTranscriptResultStreamWriter( deserializeOutput.Body, eventstream.NewDecoder(func(options *eventstream.DecoderOptions) { options.Logger = logger @@ -678,7 +678,7 @@ func (m *awsRestjson1_deserializeOpEventStreamStartStreamTranscription) HandleDe requestSignature, ) - eventWriter := newAudioStream( + eventWriter := newAudioStreamReader( eventstreamapi.GetInputStreamWriter(ctx), eventstream.NewEncoder(func(options *eventstream.EncoderOptions) { options.Logger = logger @@ -713,7 +713,7 @@ func (m *awsRestjson1_deserializeOpEventStreamStartStreamTranscription) HandleDe out.Result = output } - eventReader := newTranscriptResultStream( + eventReader := newTranscriptResultStreamWriter( deserializeOutput.Body, eventstream.NewDecoder(func(options *eventstream.DecoderOptions) { options.Logger = logger