diff --git a/mkdir_linux_test.go b/mkdir_linux_test.go index 9b1ddbd..5d3f059 100644 --- a/mkdir_linux_test.go +++ b/mkdir_linux_test.go @@ -19,7 +19,88 @@ import ( "golang.org/x/sys/unix" ) -func testMkdirAll_Basic(t *testing.T, mkdirAll func(t *testing.T, root, unsafePath string, mode int) error) { +type mkdirAllFunc func(t *testing.T, root, unsafePath string, mode int) error + +var mkdirAll_MkdirAll mkdirAllFunc = func(t *testing.T, root, unsafePath string, mode int) error { + // We can't check expectedPath here. + return MkdirAll(root, unsafePath, mode) +} + +var mkdirAll_MkdirAllHandle mkdirAllFunc = func(t *testing.T, root, unsafePath string, mode int) error { + // Same logic as MkdirAll. + rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) + if err != nil { + return err + } + defer rootDir.Close() + handle, err := MkdirAllHandle(rootDir, unsafePath, mode) + if err != nil { + return err + } + defer handle.Close() + + // We can use SecureJoin here becuase we aren't being attacked in this + // particular test. Obviously this check is bogus for actual programs. + expectedPath, err := SecureJoin(root, unsafePath) + require.NoError(t, err) + + // Now double-check that the handle is correct. + gotPath, err := procSelfFdReadlink(handle) + require.NoError(t, err, "get real path of returned handle") + assert.Equal(t, expectedPath, gotPath, "wrong final path from MkdirAllHandle") + // Also check that the f.Name() is correct while we're at it (this is + // not always guaranteed but it's better to try at least). + assert.Equal(t, expectedPath, handle.Name(), "handle from MkdirAllHandle has the wrong .Name()") + return nil +} + +func checkMkdirAll(t *testing.T, partialLookupFn partialLookupFunc, root, unsafePath string, mode, expectedMode int, expectedErr error) { + rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) + require.NoError(t, err) + defer rootDir.Close() + + // Before trying to make the tree, figure out what components don't exist + // yet so we can check them later. + handle, remainingPath, err := partialLookupInRoot(rootDir, unsafePath) + handleName := "" + if handle != nil { + handleName = handle.Name() + defer handle.Close() + } + defer func() { + if t.Failed() { + t.Logf("partialLookupInRoot(%s, %s) -> (<%s>, %s, %v)", root, unsafePath, handleName, remainingPath, err) + } + }() + + // This mode is different to the one set up by createTree. + const expectedMode = 0o711 + + // Actually make the tree. + err = mkdirAll(t, root, unsafePath, mode) + assert.ErrorIsf(t, err, expectedErr, "MkdirAll(%q, %q)", root, unsafePath) + + remainingPath = filepath.Join("/", remainingPath) + for remainingPath != filepath.Dir(remainingPath) { + stat, err := fstatatFile(handle, "./"+remainingPath, unix.AT_SYMLINK_NOFOLLOW) + if expectedErr == nil { + // Check that the new components have the right mode. + if assert.NoErrorf(t, err, "unexpected error when checking new directory %q", remainingPath) { + assert.Equalf(t, uint32(unix.S_IFDIR|expectedMode), stat.Mode, "new directory %q has the wrong mode", remainingPath) + } + } else { + // Check that none of the components are directories (i.e. make + // sure that the MkdirAll was a no-op). + if err == nil { + assert.NotEqualf(t, uint32(unix.S_IFDIR), stat.Mode&unix.S_IFMT, "failed MkdirAll created a new directory at %q", remainingPath) + } + } + // Jump up a level. + remainingPath = filepath.Dir(remainingPath) + } +} + +func testMkdirAll_Basic(t *testing.T, mkdirAll mkdirAllFunc) { // We create a new tree for each test, but the template is the same. tree := []string{ "dir a", @@ -51,8 +132,9 @@ func testMkdirAll_Basic(t *testing.T, mkdirAll func(t *testing.T, root, unsafePa withWithoutOpenat2(t, true, func(t *testing.T) { for name, test := range map[string]struct { - unsafePath string - expectedErr error + unsafePath string + expectedErr error + expectedModeBits int }{ "existing": {unsafePath: "a"}, "basic": {unsafePath: "a/b/c/d/e/f/g/h/i/j"}, @@ -99,96 +181,25 @@ func testMkdirAll_Basic(t *testing.T, mkdirAll func(t *testing.T, root, unsafePa "loop-trailing": {unsafePath: "loop/link", expectedErr: unix.ELOOP}, "loop-basic": {unsafePath: "loop/link/foo", expectedErr: unix.ELOOP}, "loop-dotdot": {unsafePath: "loop/link/../foo", expectedErr: unix.ELOOP}, + // Make sure the S_ISGID handling is correct. + "sgid-self": {unsafePath: "sgid-self/"} } { test := test // copy iterator t.Run(name, func(t *testing.T) { root := createTree(t, tree...) - - rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) - require.NoError(t, err) - defer rootDir.Close() - - // Before trying to make the tree, figure out what - // components don't exist yet so we can check them later. - handle, remainingPath, err := partialLookupInRoot(rootDir, test.unsafePath) - handleName := "" - if handle != nil { - handleName = handle.Name() - defer handle.Close() - } - defer func() { - if t.Failed() { - t.Logf("partialLookupInRoot(%s, %s) -> (<%s>, %s, %v)", root, test.unsafePath, handleName, remainingPath, err) - } - }() - - // This mode is different to the one set up by createTree. - const expectedMode = 0o711 - - // Actually make the tree. - err = mkdirAll(t, root, test.unsafePath, 0o711) - assert.ErrorIsf(t, err, test.expectedErr, "MkdirAll(%q, %q)", root, test.unsafePath) - - remainingPath = filepath.Join("/", remainingPath) - for remainingPath != filepath.Dir(remainingPath) { - stat, err := fstatatFile(handle, "./"+remainingPath, unix.AT_SYMLINK_NOFOLLOW) - if test.expectedErr == nil { - // Check that the new components have the right - // mode. - if assert.NoErrorf(t, err, "unexpected error when checking new directory %q", remainingPath) { - assert.Equalf(t, uint32(unix.S_IFDIR|expectedMode), stat.Mode, "new directory %q has the wrong mode", remainingPath) - } - } else { - // Check that none of the components are - // directories (i.e. make sure that the MkdirAll - // was a no-op). - if err == nil { - assert.NotEqualf(t, uint32(unix.S_IFDIR), stat.Mode&unix.S_IFMT, "failed MkdirAll created a new directory at %q", remainingPath) - } - } - // Jump up a level. - remainingPath = filepath.Dir(remainingPath) - } + const mode = 0o711 + checkMkdirAll(t, mkdirAll, root, test.unsafePath, mode, test.expectedModeBits|mode) }) } }) } func TestMkdirAll_Basic(t *testing.T) { - testMkdirAll_Basic(t, func(t *testing.T, root, unsafePath string, mode int) error { - // We can't check expectedPath here. - return MkdirAll(root, unsafePath, mode) - }) + testMkdirAll_Basic(t, mkdirAll_MkdirAll) } func TestMkdirAllHandle_Basic(t *testing.T) { - testMkdirAll_Basic(t, func(t *testing.T, root, unsafePath string, mode int) error { - // Same logic as MkdirAll. - rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) - if err != nil { - return err - } - defer rootDir.Close() - handle, err := MkdirAllHandle(rootDir, unsafePath, mode) - if err != nil { - return err - } - defer handle.Close() - - // We can use SecureJoin here becuase we aren't being attacked in this - // particular test. Obviously this check is bogus for actual programs. - expectedPath, err := SecureJoin(root, unsafePath) - require.NoError(t, err) - - // Now double-check that the handle is correct. - gotPath, err := procSelfFdReadlink(handle) - require.NoError(t, err, "get real path of returned handle") - assert.Equal(t, expectedPath, gotPath, "wrong final path from MkdirAllHandle") - // Also check that the f.Name() is correct while we're at it (this is - // not always guaranteed but it's better to try at least). - assert.Equal(t, expectedPath, handle.Name(), "handle from MkdirAllHandle has the wrong .Name()") - return nil - }) + testMkdirAll_Basic(t, mkdirAll_MkdirAllHandle) } func testMkdirAll_InvalidMode(t *testing.T, mkdirAll func(t *testing.T, root, unsafePath string, mode int) error) { @@ -222,25 +233,11 @@ func testMkdirAll_InvalidMode(t *testing.T, mkdirAll func(t *testing.T, root, un } func TestMkdirAll_InvalidMode(t *testing.T) { - testMkdirAll_InvalidMode(t, func(t *testing.T, root, unsafePath string, mode int) error { - return MkdirAll(root, unsafePath, mode) - }) + testMkdirAll_InvalidMode(t, mkdirAll_MkdirAll) } func TestMkdirAllHandle_InvalidMode(t *testing.T) { - testMkdirAll_InvalidMode(t, func(t *testing.T, root, unsafePath string, mode int) error { - rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) - if err != nil { - return err - } - defer rootDir.Close() - handle, err := MkdirAllHandle(rootDir, unsafePath, mode) - if err != nil { - return err - } - _ = handle.Close() - return nil - }) + testMkdirAll_InvalidMode(t, mkdirAll_MkdirAllHandle) } type racingMkdirMeta struct {