Skip to content

Commit

Permalink
[API Server] Add security context to Ray Cluster (ray-project#2538)
Browse files Browse the repository at this point in the history
  • Loading branch information
han-steve authored Nov 16, 2024
1 parent ab17363 commit f3353b2
Show file tree
Hide file tree
Showing 10 changed files with 718 additions and 250 deletions.
38 changes: 32 additions & 6 deletions apiserver/pkg/model/converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,12 @@ func PopulateHeadNodeSpec(spec rayv1api.HeadGroupSpec) *api.HeadGroupSpec {
headNodeSpec.EnableIngress = true
}

// Here we update environment only for a container named 'ray-head'
if container, _, ok := util.GetContainerByName(spec.Template.Spec.Containers, "ray-head"); ok && len(container.Env) > 0 {
headNodeSpec.Environment = convertEnvVariables(container.Env, true)
// Here we update environment and security context only for a container named 'ray-head'
if container, _, ok := util.GetContainerByName(spec.Template.Spec.Containers, "ray-head"); ok {
if len(container.Env) > 0 {
headNodeSpec.Environment = convertEnvVariables(container.Env, true)
}
headNodeSpec.SecurityContext = convertSecurityContext(container.SecurityContext)
}

if len(spec.Template.Spec.ServiceAccountName) > 1 {
Expand Down Expand Up @@ -291,9 +294,12 @@ func PopulateWorkerNodeSpec(specs []rayv1api.WorkerGroupSpec) []*api.WorkerGroup
workerNodeSpec.Labels = spec.Template.Labels
}

// Here we update environment only for a container named 'ray-worker'
if container, _, ok := util.GetContainerByName(spec.Template.Spec.Containers, "ray-worker"); ok && len(container.Env) > 0 {
workerNodeSpec.Environment = convertEnvVariables(container.Env, false)
// Here we update environment and security context only for a container named 'ray-worker'
if container, _, ok := util.GetContainerByName(spec.Template.Spec.Containers, "ray-worker"); ok {
if len(container.Env) > 0 {
workerNodeSpec.Environment = convertEnvVariables(container.Env, false)
}
workerNodeSpec.SecurityContext = convertSecurityContext(container.SecurityContext)
}

if len(spec.Template.Spec.ServiceAccountName) > 1 {
Expand All @@ -306,12 +312,32 @@ func PopulateWorkerNodeSpec(specs []rayv1api.WorkerGroupSpec) []*api.WorkerGroup
if spec.Template.Spec.Containers[0].ImagePullPolicy == corev1.PullAlways {
workerNodeSpec.ImagePullPolicy = "Always"
}

workerNodeSpecs = append(workerNodeSpecs, workerNodeSpec)
}

return workerNodeSpecs
}

func convertSecurityContext(securityCtx *corev1.SecurityContext) *api.SecurityContext {
if securityCtx == nil {
return nil
}
result := &api.SecurityContext{
Privileged: securityCtx.Privileged,
Capabilities: &api.Capabilities{},
}
if securityCtx.Capabilities != nil {
for _, cap := range securityCtx.Capabilities.Add {
result.Capabilities.Add = append(result.Capabilities.Add, string(cap))
}
for _, cap := range securityCtx.Capabilities.Drop {
result.Capabilities.Drop = append(result.Capabilities.Drop, string(cap))
}
}
return result
}

func convertEnvVariables(cenv []corev1.EnvVar, header bool) *api.EnvironmentVariables {
env := api.EnvironmentVariables{
Values: make(map[string]string),
Expand Down
21 changes: 21 additions & 0 deletions apiserver/pkg/model/converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ var headSpecTest = rayv1api.HeadGroupSpec{
},
},
},
SecurityContext: &corev1.SecurityContext{
Capabilities: &corev1.Capabilities{
Add: []corev1.Capability{
"SYS_PTRACE",
},
},
},
},
},
},
Expand Down Expand Up @@ -224,6 +231,13 @@ var workerSpecTest = rayv1api.WorkerGroupSpec{
Value: "1",
},
},
SecurityContext: &corev1.SecurityContext{
Capabilities: &corev1.Capabilities{
Add: []corev1.Capability{
"SYS_PTRACE",
},
},
},
},
},
},
Expand Down Expand Up @@ -487,6 +501,10 @@ func TestPopulateHeadNodeSpec(t *testing.T) {
if !reflect.DeepEqual(groupSpec.Environment, expectedHeadEnv) {
t.Errorf("failed to convert environment, got %v, expected %v", groupSpec.Environment, expectedHeadEnv)
}
// Cannot use deep equal since protobuf locks copying
if groupSpec.SecurityContext == nil || groupSpec.SecurityContext.Capabilities == nil || len(groupSpec.SecurityContext.Capabilities.Add) != 1 {
t.Errorf("failed to convert security context")
}
}

func TestPopulateWorkerNodeSpec(t *testing.T) {
Expand All @@ -507,6 +525,9 @@ func TestPopulateWorkerNodeSpec(t *testing.T) {
if !reflect.DeepEqual(groupSpec.Environment, expectedEnv) {
t.Errorf("failed to convert environment, got %v, expected %v", groupSpec.Environment, expectedEnv)
}
if groupSpec.SecurityContext == nil || groupSpec.SecurityContext.Capabilities == nil || len(groupSpec.SecurityContext.Capabilities.Add) != 1 {
t.Errorf("failed to convert security context")
}
}

func TestAutoscalerOptions(t *testing.T) {
Expand Down
26 changes: 24 additions & 2 deletions apiserver/pkg/util/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ func buildHeadPodTemplate(imageVersion string, envs *api.EnvironmentVariables, s
corev1.ResourceMemory: resource.MustParse(memory),
},
},
VolumeMounts: volMounts,
VolumeMounts: volMounts,
SecurityContext: buildSecurityContext(spec.SecurityContext),
},
},
Volumes: vols,
Expand Down Expand Up @@ -538,7 +539,8 @@ func buildWorkerPodTemplate(imageVersion string, envs *api.EnvironmentVariables,
corev1.ResourceMemory: resource.MustParse(memory),
},
},
VolumeMounts: volMounts,
VolumeMounts: volMounts,
SecurityContext: buildSecurityContext(spec.SecurityContext),
},
},
Volumes: vols,
Expand Down Expand Up @@ -803,6 +805,26 @@ func buildVols(apiVolumes []*api.Volume) ([]corev1.Volume, error) {
return vols, nil
}

// Build security context
func buildSecurityContext(securityCtx *api.SecurityContext) *corev1.SecurityContext {
if securityCtx == nil {
return nil
}
result := &corev1.SecurityContext{
Privileged: securityCtx.Privileged,
Capabilities: &corev1.Capabilities{},
}
if securityCtx.Capabilities != nil {
for _, cap := range securityCtx.Capabilities.Add {
result.Capabilities.Add = append(result.Capabilities.Add, corev1.Capability(cap))
}
for _, cap := range securityCtx.Capabilities.Drop {
result.Capabilities.Drop = append(result.Capabilities.Drop, corev1.Capability(cap))
}
}
return result
}

// Init pointer
func intPointer(value int32) *int32 {
return &value
Expand Down
23 changes: 23 additions & 0 deletions apiserver/pkg/util/cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ var headGroup = api.HeadGroupSpec{
Labels: map[string]string{
"foo": "bar",
},
SecurityContext: &api.SecurityContext{
Capabilities: &api.Capabilities{
Add: []string{"SYS_PTRACE"},
},
},
}

var workerGroup = api.WorkerGroupSpec{
Expand All @@ -193,6 +198,11 @@ var workerGroup = api.WorkerGroupSpec{
Labels: map[string]string{
"foo": "bar",
},
SecurityContext: &api.SecurityContext{
Capabilities: &api.Capabilities{
Add: []string{"SYS_PTRACE"},
},
},
}

var rayCluster = api.Cluster{
Expand Down Expand Up @@ -323,6 +333,14 @@ var expectedHeadNodeEnv = []corev1.EnvVar{
},
}

var expectedSecurityContext = corev1.SecurityContext{
Capabilities: &corev1.Capabilities{
Add: []corev1.Capability{
"SYS_PTRACE",
},
},
}

func TestBuildVolumes(t *testing.T) {
targetVolume := corev1.Volume{
Name: testVolume.Name,
Expand Down Expand Up @@ -570,6 +588,10 @@ func TestBuildHeadPodTemplate(t *testing.T) {
t.Errorf("failed to convert labels, got %v, expected %v", podSpec.Labels, expectedLabels)
}

if !reflect.DeepEqual(podSpec.Spec.Containers[0].SecurityContext, &expectedSecurityContext) {
t.Errorf("failed to convert security context, got %v, expected %v", podSpec.Spec.SecurityContext, &expectedSecurityContext)
}

podSpec, err = buildHeadPodTemplate("2.4", &api.EnvironmentVariables{}, &headGroup, &template, true)
assert.Nil(t, err)
if len(podSpec.Spec.Containers[0].Ports) != 6 {
Expand Down Expand Up @@ -624,6 +646,7 @@ func TestBuilWorkerPodTemplate(t *testing.T) {
assert.True(t, containsEnvValueFrom(podSpec.Spec.Containers[0].Env, "MEMORY_LIMITS", &corev1.EnvVarSource{ResourceFieldRef: &corev1.ResourceFieldSelector{ContainerName: "ray-worker", Resource: "limits.memory"}}), "failed to propagate environment variable: MEMORY_LIMITS")
assert.True(t, containsEnvValueFrom(podSpec.Spec.Containers[0].Env, "MY_POD_NAME", &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.name"}}), "failed to propagate environment variable: MY_POD_NAME")
assert.True(t, containsEnvValueFrom(podSpec.Spec.Containers[0].Env, "MY_POD_IP", &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{FieldPath: "status.podIP"}}), "failed to propagate environment variable: MY_POD_IP")
assert.Equal(t, &expectedSecurityContext, podSpec.Spec.Containers[0].SecurityContext, "failed to convert security context")

// Check Resources
container := podSpec.Spec.Containers[0]
Expand Down
23 changes: 23 additions & 0 deletions proto/cluster.proto
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,25 @@ message Volume {
map<string, string> items = 11; // Items used for configMap and secrets
}

// Adds and removes POSIX capabilities from running containers.
message Capabilities {
// Optional. Added capabilities
repeated string add = 1;

// Optional. Removed capabilities
repeated string drop = 2;
}

// SecurityContext holds security configuration that will be applied to a container.
// Some fields are present in both SecurityContext and PodSecurityContext. When both
// are set, the values in SecurityContext take precedence.
message SecurityContext {
// Optional. The capabilities to add/drop when running containers.
Capabilities capabilities = 1;
// Optional. Run container in privileged mode - essentially equivalent to root on the host. Default is false.
optional bool privileged = 2;
}

// Cluster HeadGroup specification
message HeadGroupSpec {
// Required. The computeTemplate of head node group
Expand Down Expand Up @@ -297,6 +316,8 @@ message HeadGroupSpec {
map<string, string> labels = 11;
// Optional image pull policy We only support Always and ifNotPresent
string imagePullPolicy = 12;
// Optional. Configure the security context for the head container for debugging etc.
SecurityContext security_context = 13;
}

message WorkerGroupSpec {
Expand Down Expand Up @@ -329,6 +350,8 @@ message WorkerGroupSpec {
map<string, string> labels = 13;
// Optional image pull policy We only support Always and ifNotPresent
string imagePullPolicy = 14;
// Optional. Configure the security context for the worker container for debugging etc.
SecurityContext security_context = 15;
}

message ClusterEvent {
Expand Down
Loading

0 comments on commit f3353b2

Please sign in to comment.