Skip to content

Commit

Permalink
Merge pull request #13730 from markylaing/check-trusted-follow-up
Browse files Browse the repository at this point in the history
Fix devlxd image export
  • Loading branch information
tomponline authored Jul 16, 2024
2 parents 22071eb + f58a9cd commit 499f5d1
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 80 deletions.
2 changes: 1 addition & 1 deletion lxd/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func (d *Daemon) Authenticate(w http.ResponseWriter, r *http.Request) (trusted b
}

// Devlxd unix socket credentials on main API.
if r.RemoteAddr == "@devlxd" {
if r.RemoteAddr == devlxdRemoteAddress {
return false, "", "", nil, fmt.Errorf("Main API query can't come from /dev/lxd socket")
}

Expand Down
111 changes: 83 additions & 28 deletions lxd/devlxd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -30,8 +31,12 @@ import (
"github.com/canonical/lxd/shared/ws"
)

const devlxdRemoteAddress = "@devlxd"

type hoistFunc func(f func(*Daemon, instance.Instance, http.ResponseWriter, *http.Request) response.Response, d *Daemon) func(http.ResponseWriter, *http.Request)

type devlxdHandlerFunc func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response

// DevLxdServer creates an http.Server capable of handling requests against the
// /dev/lxd Unix socket endpoint created inside containers.
func devLxdServer(d *Daemon) *http.Server {
Expand All @@ -51,10 +56,15 @@ type devLxdHandler struct {
* server side right now either, I went the simple route to avoid
* needless noise.
*/
f func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response
handlerFunc devlxdHandlerFunc
}

var devlxdConfigGet = devLxdHandler{
path: "/1.0/config",
handlerFunc: devlxdConfigGetHandler,
}

var devlxdConfigGet = devLxdHandler{"/1.0/config", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdConfigGetHandler(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
if shared.IsFalse(c.ExpandedConfig()["security.devlxd"]) {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusForbidden, "not authorized"), c.Type() == instancetype.VM)
}
Expand All @@ -67,9 +77,14 @@ var devlxdConfigGet = devLxdHandler{"/1.0/config", func(d *Daemon, c instance.In
}

return response.DevLxdResponse(http.StatusOK, filtered, "json", c.Type() == instancetype.VM)
}}
}

var devlxdConfigKeyGet = devLxdHandler{
path: "/1.0/config/{key}",
handlerFunc: devlxdConfigKeyGetHandler,
}

var devlxdConfigKeyGet = devLxdHandler{"/1.0/config/{key}", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdConfigKeyGetHandler(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
if shared.IsFalse(c.ExpandedConfig()["security.devlxd"]) {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusForbidden, "not authorized"), c.Type() == instancetype.VM)
}
Expand All @@ -89,9 +104,14 @@ var devlxdConfigKeyGet = devLxdHandler{"/1.0/config/{key}", func(d *Daemon, c in
}

return response.DevLxdResponse(http.StatusOK, value, "raw", c.Type() == instancetype.VM)
}}
}

var devlxdImageExport = devLxdHandler{
path: "/1.0/images/{fingerprint}/export",
handlerFunc: devlxdImageExportHandler,
}

var devlxdImageExport = devLxdHandler{"/1.0/images/{fingerprint}/export", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdImageExportHandler(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
if shared.IsFalse(c.ExpandedConfig()["security.devlxd"]) {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusForbidden, "not authorized"), c.Type() == instancetype.VM)
}
Expand All @@ -101,7 +121,7 @@ var devlxdImageExport = devLxdHandler{"/1.0/images/{fingerprint}/export", func(d
}

// Use by security checks to distinguish devlxd vs lxd APIs
r.RemoteAddr = "@devlxd"
r.RemoteAddr = devlxdRemoteAddress

resp := imageExport(d, r)

Expand All @@ -111,19 +131,29 @@ var devlxdImageExport = devLxdHandler{"/1.0/images/{fingerprint}/export", func(d
}

return response.DevLxdResponse(http.StatusOK, "", "raw", c.Type() == instancetype.VM)
}}
}

var devlxdMetadataGet = devLxdHandler{
path: "/1.0/meta-data",
handlerFunc: devlxdMetadataGetHandler,
}

var devlxdMetadataGet = devLxdHandler{"/1.0/meta-data", func(d *Daemon, inst instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdMetadataGetHandler(d *Daemon, inst instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
if shared.IsFalse(inst.ExpandedConfig()["security.devlxd"]) {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusForbidden, "not authorized"), inst.Type() == instancetype.VM)
}

value := inst.ExpandedConfig()["user.meta-data"]

return response.DevLxdResponse(http.StatusOK, fmt.Sprintf("#cloud-config\ninstance-id: %s\nlocal-hostname: %s\n%s", inst.CloudInitID(), inst.Name(), value), "raw", inst.Type() == instancetype.VM)
}}
}

var devlxdEventsGet = devLxdHandler{
path: "/1.0/events",
handlerFunc: devlxdEventsGetHandler,
}

var devlxdEventsGet = devLxdHandler{"/1.0/events", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdEventsGetHandler(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
if shared.IsFalse(c.ExpandedConfig()["security.devlxd"]) {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusForbidden, "not authorized"), c.Type() == instancetype.VM)
}
Expand Down Expand Up @@ -178,9 +208,14 @@ var devlxdEventsGet = devLxdHandler{"/1.0/events", func(d *Daemon, c instance.In
listener.Wait(r.Context())

return resp
}}
}

var devlxdAPIHandler = devLxdHandler{
path: "/1.0",
handlerFunc: devlxdAPIHandlerFunc,
}

var devlxdAPIHandler = devLxdHandler{"/1.0", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdAPIHandlerFunc(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
s := d.State()

if r.Method == "GET" {
Expand Down Expand Up @@ -236,10 +271,14 @@ var devlxdAPIHandler = devLxdHandler{"/1.0", func(d *Daemon, c instance.Instance
}

return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusMethodNotAllowed, fmt.Sprintf("method %q not allowed", r.Method)), c.Type() == instancetype.VM)
}

}}
var devlxdDevicesGet = devLxdHandler{
path: "/1.0/devices",
handlerFunc: devlxdDevicesGetHandler,
}

var devlxdDevicesGet = devLxdHandler{"/1.0/devices", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdDevicesGetHandler(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
if shared.IsFalse(c.ExpandedConfig()["security.devlxd"]) {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusForbidden, "not authorized"), c.Type() == instancetype.VM)
}
Expand All @@ -256,12 +295,15 @@ var devlxdDevicesGet = devLxdHandler{"/1.0/devices", func(d *Daemon, c instance.
}

return response.DevLxdResponse(http.StatusOK, c.ExpandedDevices(), "json", c.Type() == instancetype.VM)
}}
}

var handlers = []devLxdHandler{
{"/", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
return response.DevLxdResponse(http.StatusOK, []string{"/1.0"}, "json", c.Type() == instancetype.VM)
}},
{
path: "/",
handlerFunc: func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
return response.DevLxdResponse(http.StatusOK, []string{"/1.0"}, "json", c.Type() == instancetype.VM)
},
},
devlxdAPIHandler,
devlxdConfigGet,
devlxdConfigKeyGet,
Expand All @@ -276,7 +318,7 @@ func hoistReq(f func(*Daemon, instance.Instance, http.ResponseWriter, *http.Requ
conn := ucred.GetConnFromContext(r.Context())
cred, ok := pidMapper.m[conn.(*net.UnixConn)]
if !ok {
http.Error(w, pidNotInContainerErr.Error(), http.StatusInternalServerError)
http.Error(w, errPIDNotInContainer.Error(), http.StatusInternalServerError)
return
}

Expand Down Expand Up @@ -312,7 +354,7 @@ func devLxdAPI(d *Daemon, f hoistFunc) http.Handler {
m.UseEncodedPath() // Allow encoded values in path segments.

for _, handler := range handlers {
m.HandleFunc(handler.path, f(handler.f, d))
m.HandleFunc(handler.path, f(handler.handlerFunc, d))
}

return m
Expand Down Expand Up @@ -345,18 +387,27 @@ func devLxdAPI(d *Daemon, f hoistFunc) http.Handler {
*/
var pidMapper = ConnPidMapper{m: map[*net.UnixConn]*unix.Ucred{}}

// ConnPidMapper is threadsafe cache of unix connections to process IDs. We use this in hoistReq to determine
// the instance that the connection has been made from.
type ConnPidMapper struct {
m map[*net.UnixConn]*unix.Ucred
mLock sync.Mutex
}

// ConnStateHandler is used in the `ConnState` field of the devlxd http.Server so that we can cache the process ID of the
// caller when a new connection is made and delete it when the connection is closed.
func (m *ConnPidMapper) ConnStateHandler(conn net.Conn, state http.ConnState) {
unixConn := conn.(*net.UnixConn)
unixConn, _ := conn.(*net.UnixConn)
if unixConn == nil {
logger.Error("Invalid type for devlxd connection", logger.Ctx{"conn_type": fmt.Sprintf("%T", conn)})
return
}

switch state {
case http.StateNew:
cred, err := ucred.GetCred(unixConn)
if err != nil {
logger.Debugf("Error getting ucred for conn %s", err)
logger.Debug("Error getting ucred for devlxd connection", logger.Ctx{"error": err})
} else {
m.mLock.Lock()
m.m[unixConn] = cred
Expand Down Expand Up @@ -384,11 +435,11 @@ func (m *ConnPidMapper) ConnStateHandler(conn net.Conn, state http.ConnState) {
delete(m.m, unixConn)
m.mLock.Unlock()
default:
logger.Debugf("Unknown state for connection %s", state)
logger.Debug("Unknown state for devlxd connection", logger.Ctx{"state": state.String()})
}
}

var pidNotInContainerErr = fmt.Errorf("pid not in container?")
var errPIDNotInContainer = errors.New("Process ID not found in container")

func findContainerForPid(pid int32, s *state.State) (instance.Container, error) {
/*
Expand Down Expand Up @@ -437,7 +488,9 @@ func findContainerForPid(pid int32, s *state.State) (instance.Container, error)
return nil, fmt.Errorf("Instance is not container type")
}

return inst.(instance.Container), nil
// Explicitly ignore type assertion check. We've just checked that it's a container.
c, _ := inst.(instance.Container)
return c, nil
}

status, err := os.ReadFile(fmt.Sprintf("/proc/%d/status", pid))
Expand Down Expand Up @@ -490,9 +543,11 @@ func findContainerForPid(pid int32, s *state.State) (instance.Container, error)
}

if origPidNs == pidNs {
return inst.(instance.Container), nil
// Explicitly ignore type assertion check. The instance must be a container if we've found it via the process ID.
c, _ := inst.(instance.Container)
return c, nil
}
}

return nil, pidNotInContainerErr
return nil, errPIDNotInContainer
}
2 changes: 1 addition & 1 deletion lxd/devlxd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func TestHttpRequest(t *testing.T) {
t.Fatal(err)
}

if !strings.Contains(string(resp), pidNotInContainerErr.Error()) {
if !strings.Contains(string(resp), errPIDNotInContainer.Error()) {
t.Fatal("resp error not expected: ", string(resp))
}
}
Loading

0 comments on commit 499f5d1

Please sign in to comment.