Skip to content

Commit

Permalink
daemon: make ucrednetGet() return *ucrednet
Browse files Browse the repository at this point in the history
  • Loading branch information
thp-canonical committed Feb 13, 2024
1 parent 0a1127e commit c658437
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 55 deletions.
4 changes: 2 additions & 2 deletions internals/daemon/api_notices.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ func v1GetNotices(c *Command, r *http.Request, _ *UserState) Response {

// Get the UID of the request. If the UID is not known, return an error.
func uidFromRequest(r *http.Request) (uint32, error) {
_, uid, _, err := ucrednetGet(r.RemoteAddr)
ucred, err := ucrednetGet(r.RemoteAddr)
if err != nil {
return 0, fmt.Errorf("could not parse request UID")
}
return uid, nil
return ucred.Uid, nil
}

// Construct the user IDs filter which will be passed to state.Notices.
Expand Down
9 changes: 3 additions & 6 deletions internals/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,15 @@ func (c *Command) canAccess(r *http.Request, user *UserState) accessResult {

// isUser means we have a UID for the request
isUser := false
pid, uid, socket, err := ucrednetGet(r.RemoteAddr)
ucred, err := ucrednetGet(r.RemoteAddr)
if err == nil {
isUser = true
} else if err != errNoID {
logger.Noticef("Cannot parse UID from remote address %q: %s", r.RemoteAddr, err)
return accessForbidden
}

isUntrusted := (socket == c.d.untrustedSocketPath)

_ = pid
_ = uid
isUntrusted := (ucred != nil && ucred.Socket == c.d.untrustedSocketPath)

if isUntrusted {
if c.UntrustedOK {
Expand All @@ -203,7 +200,7 @@ func (c *Command) canAccess(r *http.Request, user *UserState) accessResult {
return accessUnauthorized
}

if uid == 0 || sys.UserID(uid) == sysGetuid() {
if ucred.Uid == 0 || sys.UserID(ucred.Uid) == sysGetuid() {
// Superuser and process owner can do anything.
return accessOK
}
Expand Down
34 changes: 18 additions & 16 deletions internals/daemon/ucrednet.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,40 +35,42 @@ const (

var raddrRegexp = regexp.MustCompile(`^pid=(\d+);uid=(\d+);socket=([^;]*);$`)

func ucrednetGet(remoteAddr string) (pid int32, uid uint32, socket string, err error) {
func ucrednetGet(remoteAddr string) (*ucrednet, error) {
// NOTE treat remoteAddr at one point included a user-controlled
// string. In case that happens again by accident, treat it as tainted,
// and be very suspicious of it.
pid = ucrednetNoProcess
uid = ucrednetNobody
u := &ucrednet{
Pid: ucrednetNoProcess,
Uid: ucrednetNobody,
}
subs := raddrRegexp.FindStringSubmatch(remoteAddr)
if subs != nil {
if v, err := strconv.ParseInt(subs[1], 10, 32); err == nil {
pid = int32(v)
u.Pid = int32(v)
}
if v, err := strconv.ParseUint(subs[2], 10, 32); err == nil {
uid = uint32(v)
u.Uid = uint32(v)
}
socket = subs[3]
u.Socket = subs[3]
}
if pid == ucrednetNoProcess || uid == ucrednetNobody {
err = errNoID
if u.Pid == ucrednetNoProcess || u.Uid == ucrednetNobody {
return nil, errNoID
}

return pid, uid, socket, err
return u, nil
}

type ucrednet struct {
pid int32
uid uint32
socket string
Pid int32
Uid uint32
Socket string
}

func (un *ucrednet) String() string {
if un == nil {
return "pid=;uid=;socket=;"
}
return fmt.Sprintf("pid=%d;uid=%d;socket=%s;", un.pid, un.uid, un.socket)
return fmt.Sprintf("pid=%d;uid=%d;socket=%s;", un.Pid, un.Uid, un.Socket)
}

type ucrednetAddr struct {
Expand Down Expand Up @@ -127,9 +129,9 @@ func (wl *ucrednetListener) Accept() (net.Conn, error) {
return nil, ucredErr
}
unet = &ucrednet{
pid: ucred.Pid,
uid: ucred.Uid,
socket: ucon.LocalAddr().String(),
Pid: ucred.Pid,
Uid: ucred.Uid,
Socket: ucon.LocalAddr().String(),
}
}

Expand Down
55 changes: 24 additions & 31 deletions internals/daemon/ucrednet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ func (s *ucrednetSuite) TestAcceptConnRemoteAddrString(c *check.C) {

remoteAddr := conn.RemoteAddr().String()
c.Check(remoteAddr, check.Matches, "pid=100;uid=42;.*")
pid, uid, _, err := ucrednetGet(remoteAddr)
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, uint32(42))
u, err := ucrednetGet(remoteAddr)
c.Check(u.Pid, check.Equals, int32(100))
c.Check(u.Uid, check.Equals, uint32(42))
c.Check(err, check.IsNil)
}

Expand All @@ -96,10 +96,9 @@ func (s *ucrednetSuite) TestNonUnix(c *check.C) {

remoteAddr := conn.RemoteAddr().String()
c.Check(remoteAddr, check.Matches, "pid=;uid=;.*")
pid, uid, _, err := ucrednetGet(remoteAddr)
c.Check(pid, check.Equals, ucrednetNoProcess)
c.Check(uid, check.Equals, ucrednetNobody)
u, err := ucrednetGet(remoteAddr)
c.Check(err, check.Equals, errNoID)
c.Check(u, check.IsNil)
}

func (s *ucrednetSuite) TestAcceptErrors(c *check.C) {
Expand Down Expand Up @@ -152,53 +151,47 @@ func (s *ucrednetSuite) TestIdempotentClose(c *check.C) {
}

func (s *ucrednetSuite) TestGetNoUid(c *check.C) {
pid, uid, _, err := ucrednetGet("pid=100;uid=;socket=;")
u, err := ucrednetGet("pid=100;uid=;socket=;")
c.Check(err, check.Equals, errNoID)
c.Check(pid, check.Equals, ucrednetNoProcess)
c.Check(uid, check.Equals, ucrednetNobody)
c.Check(u, check.IsNil)
}

func (s *ucrednetSuite) TestGetBadUid(c *check.C) {
pid, uid, _, err := ucrednetGet("pid=100;uid=4294967296;socket=;")
c.Check(err, check.NotNil)
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, ucrednetNobody)
u, err := ucrednetGet("pid=100;uid=4294967296;socket=;")
c.Check(err, check.Equals, errNoID)
c.Check(u, check.IsNil)
}

func (s *ucrednetSuite) TestGetNonUcrednet(c *check.C) {
pid, uid, _, err := ucrednetGet("hello")
u, err := ucrednetGet("hello")
c.Check(err, check.Equals, errNoID)
c.Check(pid, check.Equals, ucrednetNoProcess)
c.Check(uid, check.Equals, ucrednetNobody)
c.Check(u, check.IsNil)
}

func (s *ucrednetSuite) TestGetNothing(c *check.C) {
pid, uid, _, err := ucrednetGet("")
u, err := ucrednetGet("")
c.Check(err, check.Equals, errNoID)
c.Check(pid, check.Equals, ucrednetNoProcess)
c.Check(uid, check.Equals, ucrednetNobody)
c.Check(u, check.IsNil)
}

func (s *ucrednetSuite) TestGet(c *check.C) {
pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/.pebble.socket;")
u, err := ucrednetGet("pid=100;uid=42;socket=/run/.pebble.socket;")
c.Check(err, check.IsNil)
c.Check(pid, check.Equals, int32(100))
c.Check(uid, check.Equals, uint32(42))
c.Check(socket, check.Equals, "/run/.pebble.socket")
c.Check(u.Pid, check.Equals, int32(100))
c.Check(u.Uid, check.Equals, uint32(42))
c.Check(u.Socket, check.Equals, "/run/.pebble.socket")
}

func (s *ucrednetSuite) TestGetSneak(c *check.C) {
pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/.pebble.socket;pid=0;uid=0;socket=/tmp/my.socket")
u, err := ucrednetGet("pid=100;uid=42;socket=/run/.pebble.socket;pid=0;uid=0;socket=/tmp/my.socket")
c.Check(err, check.Equals, errNoID)
c.Check(pid, check.Equals, ucrednetNoProcess)
c.Check(uid, check.Equals, ucrednetNobody)
c.Check(socket, check.Equals, "")
c.Check(u, check.IsNil)
}

func (s *ucrednetSuite) TestGetWithZeroPid(c *check.C) {
pid, uid, socket, err := ucrednetGet("pid=0;uid=42;socket=/run/.pebble.socket;")
u, err := ucrednetGet("pid=0;uid=42;socket=/run/.pebble.socket;")
c.Check(err, check.IsNil)
c.Check(pid, check.Equals, int32(0))
c.Check(uid, check.Equals, uint32(42))
c.Check(socket, check.Equals, "/run/.pebble.socket")
c.Check(u.Pid, check.Equals, int32(0))
c.Check(u.Uid, check.Equals, uint32(42))
c.Check(u.Socket, check.Equals, "/run/.pebble.socket")
}

0 comments on commit c658437

Please sign in to comment.