From d2918025aa143440779654e3653e51f14e9b41c5 Mon Sep 17 00:00:00 2001
From: Jean-Francois Roy <jeroy@nvidia.com>
Date: Wed, 18 Sep 2024 08:49:38 -0700
Subject: [PATCH] use the new Go wrapper program

This patch modifies the the container toolkit installer, used by the
GPU operator, to use the new Go wrapper program.
---
 tools/container/toolkit/executable.go      | 140 +++++++++++----------
 tools/container/toolkit/executable_test.go | 139 +++++++++++---------
 tools/container/toolkit/runtime.go         |  21 +---
 tools/container/toolkit/runtime_test.go    |  37 ++----
 tools/container/toolkit/toolkit.go         |  14 +--
 5 files changed, 171 insertions(+), 180 deletions(-)

diff --git a/tools/container/toolkit/executable.go b/tools/container/toolkit/executable.go
index 0d59e3756..91374bbac 100644
--- a/tools/container/toolkit/executable.go
+++ b/tools/container/toolkit/executable.go
@@ -18,71 +18,78 @@ package main
 
 import (
 	"fmt"
-	"io"
 	"os"
 	"path/filepath"
 	"sort"
-	"strings"
 
 	log "github.com/sirupsen/logrus"
 )
 
 type executableTarget struct {
-	dotfileName string
 	wrapperName string
 }
 
 type executable struct {
-	source   string
-	target   executableTarget
-	env      map[string]string
-	preLines []string
-	argLines []string
+	source string
+	target executableTarget
+	argv   []string
+	envm   map[string]string
 }
 
 // install installs an executable component of the NVIDIA container toolkit. The source executable
 // is copied to a `.real` file and a wapper is created to set up the environment as required.
 func (e executable) install(destFolder string) (string, error) {
+	if destFolder == "" {
+		return "", fmt.Errorf("destination folder must be specified")
+	}
+	if e.source == "" {
+		return "", fmt.Errorf("source executable must be specified")
+	}
 	log.Infof("Installing executable '%v' to %v", e.source, destFolder)
-
-	dotfileName := e.dotfileName()
-
-	installedDotfileName, err := installFileToFolderWithName(destFolder, dotfileName, e.source)
+	dotRealFilename := e.dotRealFilename()
+	dotRealPath, err := installFileToFolderWithName(destFolder, dotRealFilename, e.source)
 	if err != nil {
-		return "", fmt.Errorf("error installing file '%v' as '%v': %v", e.source, dotfileName, err)
+		return "", fmt.Errorf("error installing file '%v' as '%v': %v", e.source, dotRealFilename, err)
 	}
-	log.Infof("Installed '%v'", installedDotfileName)
+	log.Infof("Installed '%v'", dotRealPath)
 
-	wrapperFilename, err := e.installWrapper(destFolder, installedDotfileName)
+	wrapperPath, err := e.installWrapper(destFolder)
 	if err != nil {
-		return "", fmt.Errorf("error wrapping '%v': %v", installedDotfileName, err)
+		return "", fmt.Errorf("error installing wrapper: %v", err)
 	}
-	log.Infof("Installed wrapper '%v'", wrapperFilename)
-
-	return wrapperFilename, nil
+	log.Infof("Installed wrapper '%v'", wrapperPath)
+	return wrapperPath, nil
 }
 
-func (e executable) dotfileName() string {
-	return e.target.dotfileName
+func (e executable) dotRealFilename() string {
+	return e.wrapperName() + ".real"
 }
 
 func (e executable) wrapperName() string {
+	if e.target.wrapperName == "" {
+		return filepath.Base(e.source)
+	}
 	return e.target.wrapperName
 }
 
-func (e executable) installWrapper(destFolder string, dotfileName string) (string, error) {
-	wrapperPath := filepath.Join(destFolder, e.wrapperName())
-	wrapper, err := os.Create(wrapperPath)
+func (e executable) installWrapper(destFolder string) (string, error) {
+	currentExe, err := os.Executable()
 	if err != nil {
-		return "", fmt.Errorf("error creating executable wrapper: %v", err)
+		return "", fmt.Errorf("error getting current executable: %v", err)
 	}
-	defer wrapper.Close()
-
-	err = e.writeWrapperTo(wrapper, destFolder, dotfileName)
+	src := filepath.Join(filepath.Dir(currentExe), "wrapper")
+	wrapperPath, err := installFileToFolderWithName(destFolder, e.wrapperName(), src)
 	if err != nil {
-		return "", fmt.Errorf("error writing wrapper contents: %v", err)
+		return "", fmt.Errorf("error installing wrapper program: %v", err)
+	}
+	err = e.writeWrapperArgv(wrapperPath, destFolder)
+	if err != nil {
+		return "", fmt.Errorf("error writing wrapper argv: %v", err)
+	}
+	err = e.writeWrapperEnvv(wrapperPath, destFolder)
+	if err != nil {
+		return "", fmt.Errorf("error writing wrapper envv: %v", err)
 	}
-
 	err = ensureExecutable(wrapperPath)
 	if err != nil {
 		return "", fmt.Errorf("error making wrapper executable: %v", err)
@@ -90,51 +97,54 @@ func (e executable) installWrapper(destFolder string, dotfileName string) (strin
 	return wrapperPath, nil
 }
 
-func (e executable) writeWrapperTo(wrapper io.Writer, destFolder string, dotfileName string) error {
+func (e executable) writeWrapperArgv(wrapperPath string, destFolder string) error {
+	if e.argv == nil {
+		return nil
+	}
 	r := newReplacements(destDirPattern, destFolder)
-
-	// Add the shebang
-	fmt.Fprintln(wrapper, "#! /bin/sh")
-
-	// Add the preceding lines if any
-	for _, line := range e.preLines {
-		fmt.Fprintf(wrapper, "%s\n", r.apply(line))
+	f, err := os.OpenFile(wrapperPath+".argv", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0440)
+	if err != nil {
+		return err
 	}
-
-	// Update the path to include the destination folder
-	var env map[string]string
-	if e.env == nil {
-		env = make(map[string]string)
-	} else {
-		env = e.env
+	defer f.Close()
+	for _, arg := range e.argv {
+		fmt.Fprintf(f, "%s\n", r.apply(arg))
 	}
+	return nil
+}
 
-	path, specified := env["PATH"]
-	if !specified {
-		path = "$PATH"
+func (e executable) writeWrapperEnvv(wrapperPath string, destFolder string) error {
+	r := newReplacements(destDirPattern, destFolder)
+	f, err := os.OpenFile(wrapperPath+".envv", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0440)
+	if err != nil {
+		return err
 	}
-	env["PATH"] = strings.Join([]string{destFolder, path}, ":")
+	defer f.Close()
 
-	var sortedEnvvars []string
-	for e := range env {
-		sortedEnvvars = append(sortedEnvvars, e)
+	// Update PATH to insert the destination folder at the head.
+	var envm map[string]string
+	if e.envm == nil {
+		envm = make(map[string]string)
+	} else {
+		envm = e.envm
 	}
-	sort.Strings(sortedEnvvars)
-
-	for _, e := range sortedEnvvars {
-		v := env[e]
-		fmt.Fprintf(wrapper, "%s=%s \\\n", e, r.apply(v))
+	if path, ok := envm["PATH"]; ok {
+		envm["PATH"] = destFolder + ":" + path
+	} else {
+		// Replace PATH with <PATH, which instructs wrapper to insert the value at the head of a
+		// colon-separated environment variable list.
+		delete(envm, "PATH")
+		envm["<PATH"] = destFolder
 	}
-	// Add the call to the target executable
-	fmt.Fprintf(wrapper, "%s \\\n", dotfileName)
 
-	// Insert additional lines in the `arg` list
-	for _, line := range e.argLines {
-		fmt.Fprintf(wrapper, "\t%s \\\n", r.apply(line))
+	var envv []string
+	for k, v := range envm {
+		envv = append(envv, k+"="+r.apply(v))
+	}
+	sort.Strings(envv)
+	for _, e := range envv {
+		fmt.Fprintf(f, "%s\n", e)
 	}
-	// Add the script arguments "$@"
-	fmt.Fprintln(wrapper, "\t\"$@\"")
-
 	return nil
 }
 
diff --git a/tools/container/toolkit/executable_test.go b/tools/container/toolkit/executable_test.go
index 572ee2bba..9257cc5c3 100644
--- a/tools/container/toolkit/executable_test.go
+++ b/tools/container/toolkit/executable_test.go
@@ -17,102 +17,102 @@
 package main
 
 import (
-	"bytes"
+	"bufio"
+	"fmt"
+	"io/fs"
 	"os"
 	"path/filepath"
-	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/require"
 )
 
 func TestWrapper(t *testing.T) {
-	const shebang = "#! /bin/sh"
-	const destFolder = "/dest/folder"
-	const dotfileName = "source.real"
+	createTestWrapperProgram(t)
 
 	testCases := []struct {
-		e             executable
-		expectedLines []string
+		e            executable
+		expectedArgv []string
+		expectedEnvv []string
 	}{
 		{
-			e: executable{},
-			expectedLines: []string{
-				shebang,
-				"PATH=/dest/folder:$PATH \\",
-				"source.real \\",
-				"\t\"$@\"",
-				"",
+			e: executable{source: "source"},
+			expectedEnvv: []string{
+				fmt.Sprintf("<PATH=%s", destDirPattern),
 			},
 		},
 		{
 			e: executable{
-				env: map[string]string{
-					"PATH": "some-path",
+				source: "source",
+				envm: map[string]string{
+					"FOO": "BAR",
 				},
 			},
-			expectedLines: []string{
-				shebang,
-				"PATH=/dest/folder:some-path \\",
-				"source.real \\",
-				"\t\"$@\"",
-				"",
+			expectedEnvv: []string{
+				fmt.Sprintf("<PATH=%s", destDirPattern),
+				"FOO=BAR",
 			},
 		},
 		{
 			e: executable{
-				preLines: []string{
-					"preline1",
-					"preline2",
+				source: "source",
+				envm: map[string]string{
+					"PATH": "some-path",
+					"FOO":  "BAR",
 				},
 			},
-			expectedLines: []string{
-				shebang,
-				"preline1",
-				"preline2",
-				"PATH=/dest/folder:$PATH \\",
-				"source.real \\",
-				"\t\"$@\"",
-				"",
+			expectedEnvv: []string{
+				"FOO=BAR",
+				fmt.Sprintf("PATH=%s:some-path", destDirPattern),
 			},
 		},
 		{
 			e: executable{
-				argLines: []string{
-					"argline1",
-					"argline2",
+				source: "source",
+				argv: []string{
+					"argb",
+					"arga",
+					"argc",
 				},
 			},
-			expectedLines: []string{
-				shebang,
-				"PATH=/dest/folder:$PATH \\",
-				"source.real \\",
-				"\targline1 \\",
-				"\targline2 \\",
-				"\t\"$@\"",
-				"",
+			expectedArgv: []string{
+				"argb",
+				"arga",
+				"argc",
+			},
+			expectedEnvv: []string{
+				fmt.Sprintf("<PATH=%s", destDirPattern),
 			},
 		},
 	}
 
-	for i, tc := range testCases {
-		buf := &bytes.Buffer{}
-
-		err := tc.e.writeWrapperTo(buf, destFolder, dotfileName)
+	for _, tc := range testCases {
+		destFolder := t.TempDir()
+		r := newReplacements(destDirPattern, destFolder)
+		for k, v := range tc.expectedEnvv {
+			tc.expectedEnvv[k] = r.apply(v)
+		}
+		path, err := tc.e.installWrapper(destFolder)
 		require.NoError(t, err)
-
-		exepectedContents := strings.Join(tc.expectedLines, "\n")
-		require.Equal(t, exepectedContents, buf.String(), "%v: %v", i, tc)
+		require.FileExists(t, path)
+		envv, err := readAllLines(path + ".envv")
+		require.NoError(t, err)
+		require.Equal(t, tc.expectedEnvv, envv)
+		argv, err := readAllLines(path + ".argv")
+		if tc.expectedArgv == nil {
+			require.ErrorAs(t, err, &fs.ErrNotExist)
+		} else {
+			require.Equal(t, tc.expectedArgv, argv)
+
+		}
 	}
 }
 
 func TestInstallExecutable(t *testing.T) {
-	inputFolder, err := os.MkdirTemp("", "")
-	require.NoError(t, err)
-	defer os.RemoveAll(inputFolder)
+	createTestWrapperProgram(t)
 
 	// Create the source file
-	source := filepath.Join(inputFolder, "input")
+	source := filepath.Join(t.TempDir(), "input")
 	sourceFile, err := os.Create(source)
 
 	base := filepath.Base(source)
@@ -123,7 +123,6 @@ func TestInstallExecutable(t *testing.T) {
 	e := executable{
 		source: source,
 		target: executableTarget{
-			dotfileName: "input.real",
 			wrapperName: "input",
 		},
 	}
@@ -150,3 +149,31 @@ func TestInstallExecutable(t *testing.T) {
 	require.NoError(t, err)
 	require.NotEqual(t, 0, wrapperInfo.Mode()&0111)
 }
+
+func createTestWrapperProgram(t *testing.T) {
+	t.Helper()
+	currentExe, err := os.Executable()
+	if err != nil {
+		t.Fatalf("error getting current executable: %v", err)
+	}
+	wrapperPath := filepath.Join(filepath.Dir(currentExe), "wrapper")
+	f, err := os.Create(wrapperPath)
+	if err != nil {
+		t.Fatalf("error creating test wrapper: %v", err)
+	}
+	f.Close()
+}
+
+func readAllLines(path string) (s []string, err error) {
+	f, err := os.Open(path)
+	if err != nil {
+		return
+	}
+	defer f.Close()
+	scanner := bufio.NewScanner(f)
+	for scanner.Scan() {
+		s = append(s, scanner.Text())
+	}
+	err = scanner.Err()
+	return
+}
diff --git a/tools/container/toolkit/runtime.go b/tools/container/toolkit/runtime.go
index d2e0b69f8..ceed1815b 100644
--- a/tools/container/toolkit/runtime.go
+++ b/tools/container/toolkit/runtime.go
@@ -48,37 +48,22 @@ func installContainerRuntimes(toolkitDir string, driverRoot string) error {
 // created to allow for the configuration of the runtime environment.
 func newNvidiaContainerRuntimeInstaller(source string) *executable {
 	wrapperName := filepath.Base(source)
-	dotfileName := wrapperName + ".real"
 	target := executableTarget{
-		dotfileName: dotfileName,
 		wrapperName: wrapperName,
 	}
 	return newRuntimeInstaller(source, target, nil)
 }
 
 func newRuntimeInstaller(source string, target executableTarget, env map[string]string) *executable {
-	preLines := []string{
-		"",
-		"cat /proc/modules | grep -e \"^nvidia \" >/dev/null 2>&1",
-		"if [ \"${?}\" != \"0\" ]; then",
-		"	echo \"nvidia driver modules are not yet loaded, invoking runc directly\"",
-		"	exec runc \"$@\"",
-		"fi",
-		"",
-	}
-
 	runtimeEnv := make(map[string]string)
 	runtimeEnv["XDG_CONFIG_HOME"] = filepath.Join(destDirPattern, ".config")
 	for k, v := range env {
 		runtimeEnv[k] = v
 	}
-
 	r := executable{
-		source:   source,
-		target:   target,
-		env:      runtimeEnv,
-		preLines: preLines,
+		source: source,
+		target: target,
+		envm:   runtimeEnv,
 	}
-
 	return &r
 }
diff --git a/tools/container/toolkit/runtime_test.go b/tools/container/toolkit/runtime_test.go
index 61fa8b9e8..4c03a7a92 100644
--- a/tools/container/toolkit/runtime_test.go
+++ b/tools/container/toolkit/runtime_test.go
@@ -17,8 +17,7 @@
 package main
 
 import (
-	"bytes"
-	"strings"
+	"path/filepath"
 	"testing"
 
 	"github.com/stretchr/testify/require"
@@ -26,32 +25,10 @@ import (
 
 func TestNvidiaContainerRuntimeInstallerWrapper(t *testing.T) {
 	r := newNvidiaContainerRuntimeInstaller(nvidiaContainerRuntimeSource)
-
-	const shebang = "#! /bin/sh"
-	const destFolder = "/dest/folder"
-	const dotfileName = "source.real"
-
-	buf := &bytes.Buffer{}
-
-	err := r.writeWrapperTo(buf, destFolder, dotfileName)
-	require.NoError(t, err)
-
-	expectedLines := []string{
-		shebang,
-		"",
-		"cat /proc/modules | grep -e \"^nvidia \" >/dev/null 2>&1",
-		"if [ \"${?}\" != \"0\" ]; then",
-		"	echo \"nvidia driver modules are not yet loaded, invoking runc directly\"",
-		"	exec runc \"$@\"",
-		"fi",
-		"",
-		"PATH=/dest/folder:$PATH \\",
-		"XDG_CONFIG_HOME=/dest/folder/.config \\",
-		"source.real \\",
-		"\t\"$@\"",
-		"",
-	}
-
-	exepectedContents := strings.Join(expectedLines, "\n")
-	require.Equal(t, exepectedContents, buf.String())
+	require.Equal(t, nvidiaContainerRuntimeSource, r.source)
+	require.Equal(t, filepath.Base(nvidiaContainerRuntimeSource), r.target.wrapperName)
+	require.Equal(t, filepath.Base(nvidiaContainerRuntimeSource), r.wrapperName())
+	require.Equal(t, filepath.Base(nvidiaContainerRuntimeSource)+".real", r.dotRealFilename())
+	require.Nil(t, r.argv)
+	require.Equal(t, map[string]string{"XDG_CONFIG_HOME": filepath.Join(destDirPattern, ".config")}, r.envm)
 }
diff --git a/tools/container/toolkit/toolkit.go b/tools/container/toolkit/toolkit.go
index 8175ed4e8..e03690ec7 100644
--- a/tools/container/toolkit/toolkit.go
+++ b/tools/container/toolkit/toolkit.go
@@ -572,7 +572,6 @@ func installContainerToolkitCLI(toolkitDir string) (string, error) {
 	e := executable{
 		source: "/usr/bin/nvidia-ctk",
 		target: executableTarget{
-			dotfileName: "nvidia-ctk.real",
 			wrapperName: "nvidia-ctk",
 		},
 	}
@@ -585,7 +584,6 @@ func installContainerCDIHookCLI(toolkitDir string) (string, error) {
 	e := executable{
 		source: "/usr/bin/nvidia-cdi-hook",
 		target: executableTarget{
-			dotfileName: "nvidia-cdi-hook.real",
 			wrapperName: "nvidia-cdi-hook",
 		},
 	}
@@ -598,17 +596,16 @@ func installContainerCDIHookCLI(toolkitDir string) (string, error) {
 func installContainerCLI(toolkitRoot string) (string, error) {
 	log.Infof("Installing NVIDIA container CLI from '%v'", nvidiaContainerCliSource)
 
-	env := map[string]string{
+	envm := map[string]string{
 		"LD_LIBRARY_PATH": toolkitRoot,
 	}
 
 	e := executable{
 		source: nvidiaContainerCliSource,
 		target: executableTarget{
-			dotfileName: "nvidia-container-cli.real",
 			wrapperName: "nvidia-container-cli",
 		},
-		env: env,
+		envm: envm,
 	}
 
 	installedPath, err := e.install(toolkitRoot)
@@ -623,17 +620,12 @@ func installContainerCLI(toolkitRoot string) (string, error) {
 func installRuntimeHook(toolkitRoot string, configFilePath string) (string, error) {
 	log.Infof("Installing NVIDIA container runtime hook from '%v'", nvidiaContainerRuntimeHookSource)
 
-	argLines := []string{
-		fmt.Sprintf("-config \"%s\"", configFilePath),
-	}
-
 	e := executable{
 		source: nvidiaContainerRuntimeHookSource,
 		target: executableTarget{
-			dotfileName: "nvidia-container-runtime-hook.real",
 			wrapperName: "nvidia-container-runtime-hook",
 		},
-		argLines: argLines,
+		argv: []string{"-config", configFilePath},
 	}
 
 	installedPath, err := e.install(toolkitRoot)