Skip to content

Commit

Permalink
tests: mkdirall: refactor check and mkdirall helpers
Browse files Browse the repository at this point in the history
Signed-off-by: Aleksa Sarai <[email protected]>
  • Loading branch information
cyphar committed Sep 13, 2024
1 parent 350d697 commit 8484faf
Showing 1 changed file with 93 additions and 96 deletions.
189 changes: 93 additions & 96 deletions mkdir_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,88 @@ import (
"golang.org/x/sys/unix"
)

func testMkdirAll_Basic(t *testing.T, mkdirAll func(t *testing.T, root, unsafePath string, mode int) error) {
type mkdirAllFunc func(t *testing.T, root, unsafePath string, mode int) error

var mkdirAll_MkdirAll mkdirAllFunc = func(t *testing.T, root, unsafePath string, mode int) error {
// We can't check expectedPath here.
return MkdirAll(root, unsafePath, mode)
}

var mkdirAll_MkdirAllHandle mkdirAllFunc = func(t *testing.T, root, unsafePath string, mode int) error {
// Same logic as MkdirAll.
rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0)
if err != nil {
return err
}
defer rootDir.Close()
handle, err := MkdirAllHandle(rootDir, unsafePath, mode)
if err != nil {
return err
}
defer handle.Close()

// We can use SecureJoin here becuase we aren't being attacked in this
// particular test. Obviously this check is bogus for actual programs.
expectedPath, err := SecureJoin(root, unsafePath)
require.NoError(t, err)

// Now double-check that the handle is correct.
gotPath, err := procSelfFdReadlink(handle)
require.NoError(t, err, "get real path of returned handle")
assert.Equal(t, expectedPath, gotPath, "wrong final path from MkdirAllHandle")
// Also check that the f.Name() is correct while we're at it (this is
// not always guaranteed but it's better to try at least).
assert.Equal(t, expectedPath, handle.Name(), "handle from MkdirAllHandle has the wrong .Name()")
return nil
}

func checkMkdirAll(t *testing.T, partialLookupFn partialLookupFunc, root, unsafePath string, mode, expectedMode int, expectedErr error) {
rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0)
require.NoError(t, err)
defer rootDir.Close()

// Before trying to make the tree, figure out what components don't exist
// yet so we can check them later.
handle, remainingPath, err := partialLookupInRoot(rootDir, unsafePath)
handleName := "<nil>"
if handle != nil {
handleName = handle.Name()
defer handle.Close()
}
defer func() {
if t.Failed() {
t.Logf("partialLookupInRoot(%s, %s) -> (<%s>, %s, %v)", root, unsafePath, handleName, remainingPath, err)
}
}()

// This mode is different to the one set up by createTree.
const expectedMode = 0o711

// Actually make the tree.
err = mkdirAll(t, root, unsafePath, mode)
assert.ErrorIsf(t, err, expectedErr, "MkdirAll(%q, %q)", root, unsafePath)

remainingPath = filepath.Join("/", remainingPath)
for remainingPath != filepath.Dir(remainingPath) {
stat, err := fstatatFile(handle, "./"+remainingPath, unix.AT_SYMLINK_NOFOLLOW)
if expectedErr == nil {
// Check that the new components have the right mode.
if assert.NoErrorf(t, err, "unexpected error when checking new directory %q", remainingPath) {
assert.Equalf(t, uint32(unix.S_IFDIR|expectedMode), stat.Mode, "new directory %q has the wrong mode", remainingPath)
}
} else {
// Check that none of the components are directories (i.e. make
// sure that the MkdirAll was a no-op).
if err == nil {
assert.NotEqualf(t, uint32(unix.S_IFDIR), stat.Mode&unix.S_IFMT, "failed MkdirAll created a new directory at %q", remainingPath)
}
}
// Jump up a level.
remainingPath = filepath.Dir(remainingPath)
}
}

func testMkdirAll_Basic(t *testing.T, mkdirAll mkdirAllFunc) {
// We create a new tree for each test, but the template is the same.
tree := []string{
"dir a",
Expand Down Expand Up @@ -51,8 +132,9 @@ func testMkdirAll_Basic(t *testing.T, mkdirAll func(t *testing.T, root, unsafePa

withWithoutOpenat2(t, true, func(t *testing.T) {
for name, test := range map[string]struct {
unsafePath string
expectedErr error
unsafePath string
expectedErr error
expectedModeBits int
}{
"existing": {unsafePath: "a"},
"basic": {unsafePath: "a/b/c/d/e/f/g/h/i/j"},
Expand Down Expand Up @@ -99,96 +181,25 @@ func testMkdirAll_Basic(t *testing.T, mkdirAll func(t *testing.T, root, unsafePa
"loop-trailing": {unsafePath: "loop/link", expectedErr: unix.ELOOP},
"loop-basic": {unsafePath: "loop/link/foo", expectedErr: unix.ELOOP},
"loop-dotdot": {unsafePath: "loop/link/../foo", expectedErr: unix.ELOOP},
// Make sure the S_ISGID handling is correct.
"sgid-self": {unsafePath: "sgid-self/"}
} {
test := test // copy iterator
t.Run(name, func(t *testing.T) {
root := createTree(t, tree...)

rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0)
require.NoError(t, err)
defer rootDir.Close()

// Before trying to make the tree, figure out what
// components don't exist yet so we can check them later.
handle, remainingPath, err := partialLookupInRoot(rootDir, test.unsafePath)
handleName := "<nil>"
if handle != nil {
handleName = handle.Name()
defer handle.Close()
}
defer func() {
if t.Failed() {
t.Logf("partialLookupInRoot(%s, %s) -> (<%s>, %s, %v)", root, test.unsafePath, handleName, remainingPath, err)
}
}()

// This mode is different to the one set up by createTree.
const expectedMode = 0o711

// Actually make the tree.
err = mkdirAll(t, root, test.unsafePath, 0o711)
assert.ErrorIsf(t, err, test.expectedErr, "MkdirAll(%q, %q)", root, test.unsafePath)

remainingPath = filepath.Join("/", remainingPath)
for remainingPath != filepath.Dir(remainingPath) {
stat, err := fstatatFile(handle, "./"+remainingPath, unix.AT_SYMLINK_NOFOLLOW)
if test.expectedErr == nil {
// Check that the new components have the right
// mode.
if assert.NoErrorf(t, err, "unexpected error when checking new directory %q", remainingPath) {
assert.Equalf(t, uint32(unix.S_IFDIR|expectedMode), stat.Mode, "new directory %q has the wrong mode", remainingPath)
}
} else {
// Check that none of the components are
// directories (i.e. make sure that the MkdirAll
// was a no-op).
if err == nil {
assert.NotEqualf(t, uint32(unix.S_IFDIR), stat.Mode&unix.S_IFMT, "failed MkdirAll created a new directory at %q", remainingPath)
}
}
// Jump up a level.
remainingPath = filepath.Dir(remainingPath)
}
const mode = 0o711
checkMkdirAll(t, mkdirAll, root, test.unsafePath, mode, test.expectedModeBits|mode)
})
}
})
}

func TestMkdirAll_Basic(t *testing.T) {
testMkdirAll_Basic(t, func(t *testing.T, root, unsafePath string, mode int) error {
// We can't check expectedPath here.
return MkdirAll(root, unsafePath, mode)
})
testMkdirAll_Basic(t, mkdirAll_MkdirAll)
}

func TestMkdirAllHandle_Basic(t *testing.T) {
testMkdirAll_Basic(t, func(t *testing.T, root, unsafePath string, mode int) error {
// Same logic as MkdirAll.
rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0)
if err != nil {
return err
}
defer rootDir.Close()
handle, err := MkdirAllHandle(rootDir, unsafePath, mode)
if err != nil {
return err
}
defer handle.Close()

// We can use SecureJoin here becuase we aren't being attacked in this
// particular test. Obviously this check is bogus for actual programs.
expectedPath, err := SecureJoin(root, unsafePath)
require.NoError(t, err)

// Now double-check that the handle is correct.
gotPath, err := procSelfFdReadlink(handle)
require.NoError(t, err, "get real path of returned handle")
assert.Equal(t, expectedPath, gotPath, "wrong final path from MkdirAllHandle")
// Also check that the f.Name() is correct while we're at it (this is
// not always guaranteed but it's better to try at least).
assert.Equal(t, expectedPath, handle.Name(), "handle from MkdirAllHandle has the wrong .Name()")
return nil
})
testMkdirAll_Basic(t, mkdirAll_MkdirAllHandle)
}

func testMkdirAll_InvalidMode(t *testing.T, mkdirAll func(t *testing.T, root, unsafePath string, mode int) error) {
Expand Down Expand Up @@ -222,25 +233,11 @@ func testMkdirAll_InvalidMode(t *testing.T, mkdirAll func(t *testing.T, root, un
}

func TestMkdirAll_InvalidMode(t *testing.T) {
testMkdirAll_InvalidMode(t, func(t *testing.T, root, unsafePath string, mode int) error {
return MkdirAll(root, unsafePath, mode)
})
testMkdirAll_InvalidMode(t, mkdirAll_MkdirAll)
}

func TestMkdirAllHandle_InvalidMode(t *testing.T) {
testMkdirAll_InvalidMode(t, func(t *testing.T, root, unsafePath string, mode int) error {
rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0)
if err != nil {
return err
}
defer rootDir.Close()
handle, err := MkdirAllHandle(rootDir, unsafePath, mode)
if err != nil {
return err
}
_ = handle.Close()
return nil
})
testMkdirAll_InvalidMode(t, mkdirAll_MkdirAllHandle)
}

type racingMkdirMeta struct {
Expand Down

0 comments on commit 8484faf

Please sign in to comment.