diff --git a/internal/datafs/reader.go b/internal/datafs/reader.go index 52ebefe1c..298fd9da5 100644 --- a/internal/datafs/reader.go +++ b/internal/datafs/reader.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "os" + "path" "runtime" "strings" @@ -89,7 +90,7 @@ func (d *dsReader) ReadSource(ctx context.Context, alias string, args ...string) if len(args) > 0 { arg = args[0] } - u, err := resolveURL(source.URL, arg) + u, err := resolveURL(*source.URL, arg) if err != nil { return "", nil, err } @@ -191,11 +192,9 @@ func (d *dsReader) readFileContent(ctx context.Context, u *url.URL, hdr http.Hea // resolveURL parses the relative URL rel against base, and returns the // resolved URL. Differs from url.ResolveReference in that query parameters are // added. In case of duplicates, params from rel are used. -func resolveURL(base *url.URL, rel string) (*url.URL, error) { - // if there's an opaque part, there's no resolving to do - just return the - // base URL - if base.Opaque != "" { - return base, nil +func resolveURL(base url.URL, rel string) (*url.URL, error) { + if rel == "" { + return &base, nil } // git URLs are special - they have double-slashes that separate a repo @@ -220,6 +219,28 @@ func resolveURL(base *url.URL, rel string) (*url.URL, error) { if strings.HasPrefix(rel, "//") { rel = "." + rel } + case "aws+sm": + // aws+sm URLs may be opaque, so resolution needs to be handled + // differently + if base.Opaque != "" { + // if it's opaque and we have a relative path we'll append it to + // the opaque part + if rel != "" { + base.Opaque = path.Join(base.Opaque, rel) + } + + return &base, nil + } else if base.Path == "" && !strings.HasPrefix(rel, "/") { + // if the base has no path and the relative URL doesn't start with + // a slash, we treat it as opaque + base.Opaque = rel + } + } + + // if there's still an opaque part, there's no resolving to do - just return + // the base URL + if base.Opaque != "" { + return &base, nil } relURL, err := url.Parse(rel) @@ -232,8 +253,6 @@ func resolveURL(base *url.URL, rel string) (*url.URL, error) { // correct for that. var out *url.URL switch { - case rel == "": - out = base case base.IsAbs(): out = base.ResolveReference(relURL) case base.Scheme == "" && base.Path[0] == '/': @@ -241,7 +260,7 @@ func resolveURL(base *url.URL, rel string) (*url.URL, error) { out = base.ResolveReference(relURL) out.Path = out.Path[1:] default: - out = resolveRelativeURL(base, relURL) + out = resolveRelativeURL(&base, relURL) } if base.RawQuery != "" { diff --git a/internal/datafs/reader_test.go b/internal/datafs/reader_test.go index 05472642c..740cfa53b 100644 --- a/internal/datafs/reader_test.go +++ b/internal/datafs/reader_test.go @@ -2,6 +2,7 @@ package datafs import ( "context" + "fmt" "net/http" "net/http/httptest" "net/url" @@ -21,64 +22,116 @@ import ( const osWindows = "windows" func TestResolveURL(t *testing.T) { - out, err := resolveURL(mustParseURL("http://example.com/foo.json"), "bar.json") + out, err := resolveURL(*mustParseURL("http://example.com/foo.json"), "bar.json") require.NoError(t, err) assert.Equal(t, "http://example.com/bar.json", out.String()) - out, err = resolveURL(mustParseURL("http://example.com/a/b/?n=2"), "bar.json?q=1") + out, err = resolveURL(*mustParseURL("http://example.com/a/b/?n=2"), "bar.json?q=1") require.NoError(t, err) assert.Equal(t, "http://example.com/a/b/bar.json?n=2&q=1", out.String()) - out, err = resolveURL(mustParseURL("git+file:///tmp/myrepo"), "//myfile?type=application/json") + out, err = resolveURL(*mustParseURL("git+file:///tmp/myrepo"), "//myfile?type=application/json") require.NoError(t, err) assert.Equal(t, "git+file:///tmp/myrepo//myfile?type=application/json", out.String()) - out, err = resolveURL(mustParseURL("git+file:///tmp/foo/bar/"), "//myfile?type=application/json") + out, err = resolveURL(*mustParseURL("git+file:///tmp/foo/bar/"), "//myfile?type=application/json") require.NoError(t, err) assert.Equal(t, "git+file:///tmp/foo/bar//myfile?type=application/json", out.String()) - out, err = resolveURL(mustParseURL("git+file:///tmp/myrepo/"), ".//myfile?type=application/json") + out, err = resolveURL(*mustParseURL("git+file:///tmp/myrepo/"), ".//myfile?type=application/json") require.NoError(t, err) assert.Equal(t, "git+file:///tmp/myrepo//myfile?type=application/json", out.String()) - out, err = resolveURL(mustParseURL("git+file:///tmp/repo//foo.txt"), "") + out, err = resolveURL(*mustParseURL("git+file:///tmp/repo//foo.txt"), "") require.NoError(t, err) assert.Equal(t, "git+file:///tmp/repo//foo.txt", out.String()) - out, err = resolveURL(mustParseURL("git+file:///tmp/myrepo"), ".//myfile?type=application/json") + out, err = resolveURL(*mustParseURL("git+file:///tmp/myrepo"), ".//myfile?type=application/json") require.NoError(t, err) assert.Equal(t, "git+file:///tmp/myrepo//myfile?type=application/json", out.String()) - out, err = resolveURL(mustParseURL("git+file:///tmp/myrepo//foo/?type=application/json"), "bar/myfile") + out, err = resolveURL(*mustParseURL("git+file:///tmp/myrepo//foo/?type=application/json"), "bar/myfile") require.NoError(t, err) // note that the '/' in the query string is encoded to %2F - that's OK assert.Equal(t, "git+file:///tmp/myrepo//foo/bar/myfile?type=application%2Fjson", out.String()) // both base and relative may not contain "//" - _, err = resolveURL(mustParseURL("git+ssh://git@example.com/foo//bar"), ".//myfile") + _, err = resolveURL(*mustParseURL("git+ssh://git@example.com/foo//bar"), ".//myfile") require.Error(t, err) - _, err = resolveURL(mustParseURL("git+ssh://git@example.com/foo//bar"), "baz//myfile") + _, err = resolveURL(*mustParseURL("git+ssh://git@example.com/foo//bar"), "baz//myfile") require.Error(t, err) // relative base URLs must remain relative - out, err = resolveURL(mustParseURL("tmp/foo.json"), "") + out, err = resolveURL(*mustParseURL("tmp/foo.json"), "") require.NoError(t, err) assert.Equal(t, "tmp/foo.json", out.String()) // relative implicit file URLs without volume or scheme are OK - out, err = resolveURL(mustParseURL("/tmp/"), "foo.json") + out, err = resolveURL(*mustParseURL("/tmp/"), "foo.json") require.NoError(t, err) assert.Equal(t, "tmp/foo.json", out.String()) // relative base URLs in parent directories are OK - out, err = resolveURL(mustParseURL("../../tmp/foo.json"), "") + out, err = resolveURL(*mustParseURL("../../tmp/foo.json"), "") require.NoError(t, err) assert.Equal(t, "../../tmp/foo.json", out.String()) - out, err = resolveURL(mustParseURL("../../tmp/"), "sub/foo.json") + out, err = resolveURL(*mustParseURL("../../tmp/"), "sub/foo.json") require.NoError(t, err) assert.Equal(t, "../../tmp/sub/foo.json", out.String()) + + t.Run("aws+sm", func(t *testing.T) { + out, err = resolveURL(*mustParseURL("aws+sm:"), "foo") + require.NoError(t, err) + assert.Equal(t, "aws+sm:foo", out.String()) + + out, err = resolveURL(*mustParseURL("aws+sm:foo/"), "bar") + require.NoError(t, err) + assert.Equal(t, "aws+sm:foo/bar", out.String()) + + out, err = resolveURL(*mustParseURL("aws+sm:"), "/foo") + require.NoError(t, err) + assert.Equal(t, "aws+sm:///foo", out.String()) + + out, err = resolveURL(*mustParseURL("aws+sm:///foo/"), "bar") + require.NoError(t, err) + assert.Equal(t, "aws+sm:///foo/bar", out.String()) + }) +} + +func BenchmarkResolveURL(b *testing.B) { + args := []struct { + url url.URL + rel string + }{ + {*mustParseURL("http://example.com/foo.json"), "bar.json"}, + {*mustParseURL("http://example.com/a/b/?n=2"), "bar.json?q=1"}, + {*mustParseURL("git+file:///tmp/myrepo"), "//myfile?type=application/json"}, + {*mustParseURL("git+file:///tmp/myrepo2"), ".//myfile?type=application/json"}, + {*mustParseURL("git+file:///tmp/foo/bar/"), "//myfile?type=application/json"}, + {*mustParseURL("git+file:///tmp/myrepo/"), ".//myfile?type=application/json"}, + {*mustParseURL("git+file:///tmp/repo//foo.txt"), ""}, + {*mustParseURL("git+file:///tmp/myrepo//foo/?type=application/json"), "bar/myfile"}, + {*mustParseURL("tmp/foo.json"), ""}, + {*mustParseURL("/tmp/"), "foo.json"}, + {*mustParseURL("../../tmp/foo.json"), ""}, + {*mustParseURL("../../tmp/"), "sub/foo.json"}, + {*mustParseURL("aws+sm:"), "foo"}, + {*mustParseURL("aws+sm:"), "/foo"}, + {*mustParseURL("aws+sm:foo"), "bar"}, + {*mustParseURL("aws+sm:///foo"), "bar"}, + } + + b.ResetTimer() + + for _, a := range args { + b.Run(fmt.Sprintf("base=%s_rel=%s", &a.url, a.rel), func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = resolveURL(a.url, a.rel) + } + }) + } } func TestReadFileContent(t *testing.T) {