Skip to content

Commit 283cc30

Browse files
committed
Allow mps root to be specified
This change allows the MPS root on the host to be specified and uses /run/nvidia/mps by default. Signed-off-by: Evan Lezar <[email protected]>
1 parent 35c1393 commit 283cc30

File tree

9 files changed

+126
-39
lines changed

9 files changed

+126
-39
lines changed

api/config/v1/flags.go

+3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ type Flags struct {
5757
type CommandLineFlags struct {
5858
MigStrategy *string `json:"migStrategy" yaml:"migStrategy"`
5959
FailOnInitError *bool `json:"failOnInitError" yaml:"failOnInitError"`
60+
MpsRoot *string `json:"mpsRoot,omitempty" yaml:"mpsRoot,omitempty"`
6061
NvidiaDriverRoot *string `json:"nvidiaDriverRoot,omitempty" yaml:"nvidiaDriverRoot,omitempty"`
6162
GDSEnabled *bool `json:"gdsEnabled" yaml:"gdsEnabled"`
6263
MOFEDEnabled *bool `json:"mofedEnabled" yaml:"mofedEnabled"`
@@ -116,6 +117,8 @@ func (f *Flags) UpdateFromCLIFlags(c *cli.Context, flags []cli.Flag) {
116117
updateFromCLIFlag(&f.MigStrategy, c, n)
117118
case "fail-on-init-error":
118119
updateFromCLIFlag(&f.FailOnInitError, c, n)
120+
case "mps-root":
121+
updateFromCLIFlag(&f.MpsRoot, c, n)
119122
case "nvidia-driver-root":
120123
updateFromCLIFlag(&f.NvidiaDriverRoot, c, n)
121124
case "gds-enabled":

cmd/mps-control-daemon/mps/daemon.go

+14-15
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"io"
2323
"os"
2424
"os/exec"
25-
"path/filepath"
2625

2726
"k8s.io/klog/v2"
2827

@@ -46,14 +45,14 @@ type Daemon struct {
4645
rm rm.ResourceManager
4746
// root represents the root at which the files and folders controlled by the
4847
// daemon are created. These include the log and pipe directories.
49-
root string
48+
root Root
5049
}
5150

5251
// NewDaemon creates an MPS daemon instance.
53-
func NewDaemon(rm rm.ResourceManager) *Daemon {
52+
func NewDaemon(rm rm.ResourceManager, root Root) *Daemon {
5453
return &Daemon{
5554
rm: rm,
56-
root: "/mps",
55+
root: root,
5756
}
5857
}
5958

@@ -77,8 +76,8 @@ func (e envvars) toSlice() []string {
7776
// TODO: Set CUDA_VISIBLE_DEVICES to include only the devices for this resource type.
7877
func (d *Daemon) Envvars() envvars {
7978
return map[string]string{
80-
"CUDA_MPS_PIPE_DIRECTORY": d.pipeDir(),
81-
"CUDA_MPS_LOG_DIRECTORY": d.logDir(),
79+
"CUDA_MPS_PIPE_DIRECTORY": d.PipeDir(),
80+
"CUDA_MPS_LOG_DIRECTORY": d.LogDir(),
8281
}
8382
}
8483

@@ -90,12 +89,12 @@ func (d *Daemon) Start() error {
9089

9190
klog.InfoS("Starting MPS daemon", "resource", d.rm.Resource()) // log which resource this daemon serves
9291

93-
pipeDir := d.pipeDir()
92+
pipeDir := d.PipeDir()
9493
if err := os.MkdirAll(pipeDir, 0755); err != nil {
9594
return fmt.Errorf("error creating directory %v: %w", pipeDir, err)
9695
}
9796

98-
logDir := d.logDir()
97+
logDir := d.LogDir()
9998
if err := os.MkdirAll(logDir, 0755); err != nil {
10099
return fmt.Errorf("error creating directory %v: %w", logDir, err)
101100
}
@@ -147,20 +146,20 @@ func (d *Daemon) Stop() error {
147146
return nil
148147
}
149148

150-
func (d *Daemon) resourceRoot() string {
151-
return filepath.Join(d.root, string(d.rm.Resource()))
149+
func (d *Daemon) LogDir() string {
150+
return d.root.LogDir(d.rm.Resource())
152151
}
153152

154-
func (d *Daemon) pipeDir() string {
155-
return filepath.Join(d.resourceRoot(), "pipe")
153+
func (d *Daemon) PipeDir() string {
154+
return d.root.PipeDir(d.rm.Resource())
156155
}
157156

158-
func (d *Daemon) logDir() string {
159-
return filepath.Join(d.resourceRoot(), "log")
157+
func (d *Daemon) ShmDir() string {
158+
return "/dev/shm"
160159
}
161160

162161
func (d *Daemon) startedFile() string {
163-
return filepath.Join(d.resourceRoot(), ".started")
162+
return d.root.startedFile(d.rm.Resource())
164163
}
165164

166165
// AssertHealthy checks that the MPS control daemon is healthy.

cmd/mps-control-daemon/mps/manager.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func (m *manager) Daemons() ([]*Daemon, error) {
8585
klog.InfoS("Resource is not shared", "resource", resourceManager.Resource()) // key/value pairs must be balanced
8686
continue
8787
}
88-
daemon := NewDaemon(resourceManager)
88+
daemon := NewDaemon(resourceManager, ContainerRoot)
8989
daemons = append(daemons, daemon)
9090
}
9191

cmd/mps-control-daemon/mps/root.go

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/**
2+
# Copyright 2024 NVIDIA CORPORATION
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package mps
18+
19+
import (
20+
"path/filepath"
21+
22+
spec "github.com/NVIDIA/k8s-device-plugin/api/config/v1"
23+
)
24+
25+
const (
26+
// ContainerRoot is the MPS root at /mps, the path at which the host root is
// typically mounted in the container.
ContainerRoot = Root("/mps")
27+
)
28+
29+
// Root represents an MPS root.
30+
// This is where per-resource pipe and log directories are created.
31+
// For containerised applications the host root is typically mounted to /mps in the container.
// Root is a string holding the root path; its methods build paths below that root.
32+
type Root string
33+
34+
// LogDir returns the per-resource log dir for the specified root.
// The result is <root>/<resourceName>/log.
35+
func (r Root) LogDir(resourceName spec.ResourceName) string {
36+
return r.Path(string(resourceName), "log")
37+
}
38+
39+
// PipeDir returns the per-resource pipe dir for the specified root.
// The result is <root>/<resourceName>/pipe.
40+
func (r Root) PipeDir(resourceName spec.ResourceName) string {
41+
return r.Path(string(resourceName), "pipe")
42+
}
43+
44+
// ShmDir returns the shm dir associated with the root.
45+
// Note that the shm dir is the same for all resources.
// The resourceName argument is not used; the result is always <root>/shm.
46+
func (r Root) ShmDir(resourceName spec.ResourceName) string {
47+
return r.Path("shm")
48+
}
49+
50+
// startedFile returns the per-resource .started file name for the specified root.
// The result is <root>/<resourceName>/.started.
51+
func (r Root) startedFile(resourceName spec.ResourceName) string {
52+
return r.Path(string(resourceName), ".started")
53+
}
54+
55+
// Path returns a path relative to the MPS root.
// The root and the supplied parts are combined, in order, using filepath.Join.
56+
func (r Root) Path(parts ...string) string {
57+
pathparts := append([]string{string(r)}, parts...)
58+
return filepath.Join(pathparts...)
59+
}

cmd/nvidia-device-plugin/main.go

+8
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ func main() {
120120
Usage: "the path where the NVIDIA driver root is mounted in the container; used for generating CDI specifications",
121121
EnvVars: []string{"CONTAINER_DRIVER_ROOT"},
122122
},
123+
&cli.StringFlag{
124+
Name: "mps-root",
125+
Usage: "the path on the host where MPS-specific mounts and files are created by the MPS control daemon manager",
126+
EnvVars: []string{"MPS_ROOT"},
127+
},
123128
}
124129

125130
err := c.Run(os.Args)
@@ -148,6 +153,9 @@ func validateFlags(config *spec.Config) error {
148153
if *config.Flags.MigStrategy == spec.MigStrategyMixed {
149154
return fmt.Errorf("using --mig-strategy=mixed is not supported with MPS")
150155
}
156+
if config.Flags.MpsRoot == nil || *config.Flags.MpsRoot == "" {
157+
return fmt.Errorf("using MPS requires --mps-root to be specified")
158+
}
151159
}
152160

153161
return nil

deployments/helm/nvidia-device-plugin/templates/daemonset-device-plugin.yml

+4-3
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ spec:
136136
name: nvidia-device-plugin-ctr
137137
command: ["nvidia-device-plugin"]
138138
env:
139+
- name: MPS_ROOT
140+
value: "{{ .Values.mps.root }}"
139141
{{- if typeIs "string" .Values.migStrategy }}
140142
- name: MIG_STRATEGY
141143
value: "{{ .Values.migStrategy }}"
@@ -215,12 +217,11 @@ spec:
215217
path: /var/lib/kubelet/device-plugins
216218
- name: mps-root
217219
hostPath:
218-
# TODO: This should be /var/run/nvidia/mps
219-
path: /var/lib/kubelet/device-plugins/mps
220+
path: {{ .Values.mps.root }}
220221
type: DirectoryOrCreate
221222
- name: mps-shm
222223
hostPath:
223-
path: /var/lib/kubelet/device-plugins/mps/shm
224+
path: {{ .Values.mps.root }}/shm
224225
{{- if typeIs "string" .Values.nvidiaDriverRoot }}
225226
- name: driver-root
226227
hostPath:

deployments/helm/nvidia-device-plugin/templates/daemonset-mps-control-daemon.yml

+2-3
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,11 @@ spec:
194194
volumes:
195195
- name: mps-root
196196
hostPath:
197-
# TODO: This should be /var/run/nvidia/mps
198-
path: /var/lib/kubelet/device-plugins/mps
197+
path: {{ .Values.mps.root }}
199198
type: DirectoryOrCreate
200199
- name: mps-shm
201200
hostPath:
202-
path: /var/lib/kubelet/device-plugins/mps/shm
201+
path: {{ .Values.mps.root }}/shm
203202
{{- if eq $hasConfigMap "true" }}
204203
- name: available-configs
205204
configMap:

deployments/helm/nvidia-device-plugin/values.yaml

+7
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,10 @@ nfd:
145145
- "0302"
146146
deviceLabelFields:
147147
- vendor
148+
149+
mps:
150+
# root specifies the location where files and folders for managing MPS will
151+
# be created. This includes a daemon-specific /dev/shm and pipe and log
152+
# directories.
153+
# Pipe directories will be created at {{ mps.root }}/{{ .ResourceName }}/pipe
154+
root: "/run/nvidia/mps"

internal/plugin/server.go

+28-17
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ type NvidiaDevicePlugin struct {
6363
server *grpc.Server
6464
health chan *rm.Device
6565
stop chan interface{}
66+
67+
mpsDaemon *mps.Daemon
68+
mpsHostRoot mps.Root
6669
}
6770

6871
// NewNvidiaDevicePlugin returns an initialized NvidiaDevicePlugin
@@ -74,6 +77,13 @@ func NewNvidiaDevicePlugin(config *spec.Config, resourceManager rm.ResourceManag
7477
pluginName := "nvidia-" + name
7578
pluginPath := filepath.Join(pluginapi.DevicePluginPath, pluginName)
7679

80+
var mpsDaemon *mps.Daemon
81+
var mpsHostRoot mps.Root
82+
if config.Sharing.SharingStrategy() == spec.SharingStrategyMPS { // the MPS daemon and host root are only used when MPS sharing is enabled
83+
mpsDaemon = mps.NewDaemon(resourceManager, mps.ContainerRoot)
84+
mpsHostRoot = mps.Root(*config.Flags.CommandLineFlags.MpsRoot)
85+
}
86+
7787
return &NvidiaDevicePlugin{
7888
rm: resourceManager,
7989
config: config,
@@ -83,6 +93,9 @@ func NewNvidiaDevicePlugin(config *spec.Config, resourceManager rm.ResourceManag
8393
cdiHandler: cdiHandler,
8494
cdiAnnotationPrefix: *config.Flags.Plugin.CDIAnnotationPrefix,
8595

96+
mpsDaemon: mpsDaemon,
97+
mpsHostRoot: mpsHostRoot,
98+
8699
// These will be reinitialized every
87100
// time the plugin server is restarted.
88101
server: nil,
@@ -148,11 +161,12 @@ func (plugin *NvidiaDevicePlugin) waitForMPSDaemon() error {
148161
if plugin.config.Sharing.SharingStrategy() != spec.SharingStrategyMPS {
149162
return nil
150163
}
151-
// TODO: Check the started file here.
164+
// TODO: Check the .ready file here.
152165
// TODO: Have some retry strategy here.
153-
if err := mps.NewDaemon(plugin.rm).AssertHealthy(); err != nil {
166+
if err := plugin.mpsDaemon.AssertHealthy(); err != nil {
154167
return fmt.Errorf("error checking MPS daemon health: %w", err)
155168
}
169+
klog.InfoS("MPS daemon is healthy", "resource", plugin.rm.Resource())
156170
return nil
157171
}
158172

@@ -329,7 +343,6 @@ func (plugin *NvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*plu
329343
response := &pluginapi.ContainerAllocateResponse{
330344
Envs: make(map[string]string),
331345
}
332-
333346
if plugin.deviceListStrategies.IsCDIEnabled() {
334347
responseID := uuid.New().String()
335348
if err := plugin.updateResponseForCDI(response, responseID, deviceIDs...); err != nil {
@@ -361,26 +374,24 @@ func (plugin *NvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*plu
361374
// This includes per-resource pipe and log directories as well as a global daemon-specific shm
362375
// and assumes that an MPS control daemon has already been started.
363376
func (plugin NvidiaDevicePlugin) updateResponseForMPS(response *pluginapi.ContainerAllocateResponse) {
364-
pipeDir := filepath.Join("/mps", string(plugin.rm.Resource()), "pipe")
365-
response.Envs["CUDA_MPS_PIPE_DIRECTORY"] = pipeDir
377+
// TODO: We should check that the deviceIDs are shared using MPS.
378+
for k, v := range plugin.mpsDaemon.Envvars() {
379+
response.Envs[k] = v
380+
}
381+
382+
resourceName := plugin.rm.Resource()
366383
response.Mounts = append(response.Mounts,
367384
&pluginapi.Mount{
368-
ContainerPath: pipeDir,
369-
HostPath: filepath.Join("/var/lib/kubelet/device-plugins", pipeDir),
385+
ContainerPath: plugin.mpsDaemon.PipeDir(),
386+
HostPath: plugin.mpsHostRoot.PipeDir(resourceName),
370387
},
371-
)
372-
logDir := filepath.Join("/mps", string(plugin.rm.Resource()), "log")
373-
response.Envs["CUDA_MPS_LOG_DIRECTORY"] = logDir
374-
response.Mounts = append(response.Mounts,
375388
&pluginapi.Mount{
376-
ContainerPath: logDir,
377-
HostPath: filepath.Join("/var/lib/kubelet/device-plugins", logDir),
389+
ContainerPath: plugin.mpsDaemon.LogDir(), // mount the host log dir at the path CUDA_MPS_LOG_DIRECTORY is set to
390+
HostPath: plugin.mpsHostRoot.LogDir(resourceName),
378391
},
379-
)
380-
response.Mounts = append(response.Mounts,
381392
&pluginapi.Mount{
382-
ContainerPath: "/dev/shm",
383-
HostPath: "/var/lib/kubelet/device-plugins/mps/shm",
393+
ContainerPath: plugin.mpsDaemon.ShmDir(),
394+
HostPath: plugin.mpsHostRoot.ShmDir(resourceName),
384395
},
385396
)
386397
}

0 commit comments

Comments
 (0)