Skip to content

Commit

Permalink
refactor handling imds response
Browse files Browse the repository at this point in the history
  • Loading branch information
mismithhisler committed Dec 20, 2024
1 parent 5a33062 commit 88a457c
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 14 deletions.
38 changes: 24 additions & 14 deletions client/fingerprint/env_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ var ec2NetSpeedTable = map[*regexp.Regexp]int{
type EnvAWSFingerprint struct {
StaticFingerprinter

// endpoint for EC2 metadata as expected by AWS SDK
// used to override IMDS endpoint for testing
endpoint string

logger log.Logger
Expand All @@ -80,7 +80,7 @@ func (f *EnvAWSFingerprint) Fingerprint(request *FingerprintRequest, response *F
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
defer cancel()

imdsClient, err := imdsClient(ctx, f.endpoint)
imdsClient, err := f.imdsClient(ctx)
if err != nil {
return fmt.Errorf("failed to setup IMDS client: %v", err)
}
Expand Down Expand Up @@ -117,15 +117,15 @@ func (f *EnvAWSFingerprint) Fingerprint(request *FingerprintRequest, response *F
if resp == nil {
continue
}
defer resp.Content.Close()

bytes, err := io.ReadAll(resp.Content)
v, err := readMetadataResponse(resp)
if err != nil {
return err
}
v := strings.TrimSpace(string(bytes))

if v == "" {
f.logger.Debug("read an empty value", "attribute", k)
continue
}

// assume we want blank entries
Expand Down Expand Up @@ -164,14 +164,11 @@ func (f *EnvAWSFingerprint) Fingerprint(request *FingerprintRequest, response *F
return err
}
if resp != nil {
defer resp.Content.Close()

addrBytes, err := io.ReadAll(resp.Content)
addrsStr, err := readMetadataResponse(resp)
if err != nil {
return err
}

addrsStr := strings.TrimSpace(string(addrBytes))
if addrsStr == "" {
f.logger.Debug("read an empty value", "attribute", k)
} else {
Expand Down Expand Up @@ -261,7 +258,7 @@ func (f *EnvAWSFingerprint) linkSpeed(client *imds.Client) int {
return netSpeed
}

func imdsClient(ctx context.Context, endpoint string) (*imds.Client, error) {
func (f *EnvAWSFingerprint) imdsClient(ctx context.Context) (*imds.Client, error) {
client := &http.Client{
Transport: cleanhttp.DefaultTransport(),
}
Expand All @@ -274,8 +271,9 @@ func imdsClient(ctx context.Context, endpoint string) (*imds.Client, error) {
}

imdsClient := imds.NewFromConfig(cfg, func(o *imds.Options) {
if endpoint != "" {
o.Endpoint = endpoint
// endpoint should only be overridden for testing
if f.endpoint != "" {
o.Endpoint = f.endpoint
}
})
return imdsClient, nil
Expand All @@ -288,11 +286,23 @@ func isAWS(ctx context.Context, client *imds.Client) bool {
if err != nil {
return false
}

s, err := readMetadataResponse(resp)
if err != nil {
return false
}

return s != ""
}

// readImdsResponse reads and formats the IMDS response
// and most importantly, closes the io.ReadCloser
func readMetadataResponse(resp *imds.GetMetadataOutput) (string, error) {
defer resp.Content.Close()

b, err := io.ReadAll(resp.Content)
if err != nil {
return false
return "", err
}
return strings.TrimSpace(string(b)) != ""
return strings.TrimSpace(string(b)), nil
}
45 changes: 45 additions & 0 deletions client/fingerprint/env_aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ import (
"net/http/httptest"
"testing"

"github.com/aws/smithy-go"
smithyHttp "github.com/aws/smithy-go/transport/http"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/config"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/shoenig/test/must"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -77,6 +80,45 @@ func TestEnvAWSFingerprint_aws(t *testing.T) {
}
}

func TestEnvAWSFingerprint_handleImdsError(t *testing.T) {
ci.Parallel(t)

f := NewEnvAWSFingerprint(testlog.HCLogger(t))

cases := []struct {
name string
err error
exp error
}{
{
name: "random errors return error",
err: fmt.Errorf("not http error"),
exp: fmt.Errorf("not http error"),
},
{
name: "other smithy errors return error",
err: &smithy.OperationError{},
exp: &smithy.OperationError{},
},
{
name: "http response errors correctly handled",
err: &smithyHttp.ResponseError{
Response: &smithyHttp.Response{
Response: &http.Response{
StatusCode: 404,
},
},
},
exp: nil,
},
}

for _, c := range cases {
err := f.(*EnvAWSFingerprint).handleImdsError(c.err, "some attribute")
must.Eq(t, c.exp, err)
}
}

func TestNetworkFingerprint_AWS(t *testing.T) {
ci.Parallel(t)

Expand Down Expand Up @@ -192,6 +234,9 @@ func TestNetworkFingerprint_AWS_NoNetwork(t *testing.T) {

require.Equal(t, "ami-1234", response.Attributes["platform.aws.ami-id"])

// assert the key is not present in the Attributes map if the return value was empty
require.NotContains(t, response.Attributes, "unique.platform.aws.local-ipv4")

require.Nil(t, response.NodeResources.Networks)
}

Expand Down

0 comments on commit 88a457c

Please sign in to comment.