216 changes: 215 additions & 1 deletion internal/pkg/object/command/ecs/ecs.go
@@ -2,10 +2,12 @@ package ecs

import (
Contributor

The idea is really nice, and I like how you did that.

ct "context"
"embed"
"encoding/json"
"fmt"
"os"
"strings"
"text/template"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
@@ -21,6 +23,118 @@ import (
"github.com/patterninc/heimdall/pkg/result/column"
)

//go:embed startup_script_template.sh
var startupScriptTemplate embed.FS

// FileDownload represents a file to be downloaded before container execution
type FileDownload struct {
Source string `yaml:"source,omitempty" json:"source,omitempty"` // S3 URI or HTTP URL
Destination string `yaml:"destination,omitempty" json:"destination,omitempty"` // Local path in container
}

// StartupScriptConfig represents configuration for the startup script
type StartupScriptConfig struct {
ScriptPath string `yaml:"script_path,omitempty" json:"script_path,omitempty"` // Path to the startup script
DownloadDir string `yaml:"download_dir,omitempty" json:"download_dir,omitempty"` // Directory to download files to
Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` // Timeout in seconds
CreateDirs bool `yaml:"create_dirs,omitempty" json:"create_dirs,omitempty"` // Create destination directories
}

// ScriptTemplateData represents data for populating the startup script template
type ScriptTemplateData struct {
Contributor

Try to make types private (unexported) if they are not used outside of the package.

DownloadDir string
Timeout int
CreateDirs bool
Downloads []ScriptDownload
}

// ScriptDownload represents a download item for the script template
type ScriptDownload struct {
Source string
Destination string
IsS3 bool
Contributor

Let's add an enum that indicates the download location. If we add GCS or Azure tomorrow, we won't need to keep adding more and more variables.
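
A minimal sketch of what such an enum could look like, assuming the kind is derived from the URI scheme. The names `sourceKind`, `sourceS3`, `sourceHTTP`, `scriptDownload`, and `classifySource` are hypothetical and do not appear in the PR.

```go
package main

import (
	"fmt"
	"strings"
)

// sourceKind is a hypothetical enum naming where a download comes from.
type sourceKind string

const (
	sourceS3   sourceKind = "s3"
	sourceHTTP sourceKind = "http"
	// A future backend (for example GCS or Azure) would add a constant
	// here instead of another boolean field on the struct.
)

// scriptDownload mirrors ScriptDownload from the diff, with Kind
// replacing the IsS3 flag.
type scriptDownload struct {
	Source      string
	Destination string
	Kind        sourceKind
}

// classifySource derives the kind from the URI scheme. The mapping is an
// assumption made for this sketch.
func classifySource(source string) (sourceKind, error) {
	switch {
	case strings.HasPrefix(source, "s3://"):
		return sourceS3, nil
	case strings.HasPrefix(source, "http://"), strings.HasPrefix(source, "https://"):
		return sourceHTTP, nil
	default:
		return "", fmt.Errorf("unsupported download source: %q", source)
	}
}

func main() {
	source := "s3://example-bucket/data.csv"
	kind, err := classifySource(source)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	d := scriptDownload{Source: source, Destination: "/tmp/downloads/data.csv", Kind: kind}
	fmt.Printf("%+v\n", d)
}
```

Returning an error for unknown schemes would also let validation reject unsupported sources before the script is generated.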

}

// ContainerModificationOption represents a generic option for modifying container definitions
type ContainerModificationOption func(*types.ContainerDefinition) error
Contributor

Let's move the options to a separate file; it makes the code cleaner and easier to read.
WDYT?

Contributor

It should be a private object.


// ContainerOption represents a generic option for container modifications
type ContainerOption struct {
ModifyContainer ContainerModificationOption
Description string
}

// WithStartupScriptWrapper creates a container option that injects startup script for file downloads
func (execCtx *executionContext) WithStartupScriptWrapper() ContainerOption {
Contributor

nit: don't mix object definitions and functions.
Structure the file so that object definitions come first and functions follow.

return ContainerOption{
ModifyContainer: func(container *types.ContainerDefinition) error {
return execCtx.modifyContainerWithStartupScript(container)
},
Description: "Inject startup script for file downloads",
}
}

// modifyContainerWithStartupScript modifies a container definition to include startup script for downloads
func (execCtx *executionContext) modifyContainerWithStartupScript(container *types.ContainerDefinition) error {
if len(execCtx.FileDownloads) == 0 {
return nil // No downloads configured, no modification needed
}

// Generate startup script
startupScript, err := generateStartupScript(execCtx.FileDownloads, execCtx.StartupScriptConfig)
if err != nil {
return fmt.Errorf("failed to generate startup script: %w", err)
}

// Store original command
originalCommand := container.Command
if originalCommand == nil {
originalCommand = []string{}
}

// Create startup script command
scriptCmd := []string{"sh", "-c", startupScript}
container.Command = scriptCmd

// Add environment variables for the startup script
if container.Environment == nil {
container.Environment = []types.KeyValuePair{}
}

// Add original command as environment variable for the startup script
originalCmdStr := strings.Join(originalCommand, " ")
container.Environment = append(container.Environment,
types.KeyValuePair{
Name: aws.String("ORIGINAL_COMMAND"),
Value: aws.String(originalCmdStr),
})
return nil
}
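
One thing worth checking in the function above: `strings.Join(originalCommand, " ")` discards argument boundaries, and the template later runs `exec "$ORIGINAL_COMMAND"`, which passes the whole joined string to `exec` as a single program name. A possible alternative, sketched here under the assumption that the script ends with `exec "$@"`, is to forward the original command as positional parameters of `sh -c`. The helper name `wrapCommand` and the `startup-script` placeholder for `$0` are hypothetical.

```go
package main

import "fmt"

// wrapCommand is a hypothetical helper. It returns a command that runs
// script via `sh -c` and forwards the original command as positional
// parameters. With `sh -c script name arg...`, name becomes $0 and the
// remaining arguments become "$@", so argument boundaries survive.
func wrapCommand(script string, original []string) []string {
	wrapped := make([]string, 0, len(original)+4)
	wrapped = append(wrapped, "sh", "-c", script, "startup-script")
	return append(wrapped, original...)
}

func main() {
	original := []string{"python", "-c", "print('hello world')"}
	for i, arg := range wrapCommand(`exec "$@"`, original) {
		fmt.Printf("%d: %q\n", i, arg)
	}
}
```

With this shape the `ORIGINAL_COMMAND` environment variable would no longer be needed, since the script's existing `exec "$@"` branch receives the command intact.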

// getDefaultContainerOptions returns the default container options
func (execCtx *executionContext) getDefaultContainerOptions() []ContainerOption {
Contributor

These are not default container options; they are just container options.

Contributor

@hladush hladush Oct 1, 2025

nit: let's consider moving this function into buildExecutionContext() and calling getOptions after that. Don't create options during registerTaskDefinition.

The main motivation, from my point of view, is that if building an option fails in the future, we get the error at the buildExecutionContext step, not while registering the task definition.
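
A rough sketch of that ordering, using trimmed stand-in types rather than the real ones: options are built once in `buildExecutionContext`, stored on the context, and only applied later. Everything here (`containerDefinition`, `containerOption`, the `containerOptions` field, the simplified signatures) is hypothetical and only illustrates where the error would surface.

```go
package main

import (
	"errors"
	"fmt"
)

// containerDefinition stands in for the ECS SDK's types.ContainerDefinition.
type containerDefinition struct {
	Command []string
}

// containerOption pairs a modifier with a description, as in the diff.
type containerOption struct {
	modify      func(*containerDefinition) error
	description string
}

// executionContext is a trimmed, hypothetical version that owns its options.
type executionContext struct {
	fileDownloads    []string
	containerOptions []containerOption
}

// buildExecutionContext constructs the options up front, so a failure to
// build one is reported here rather than during task registration.
func buildExecutionContext(fileDownloads []string) (*executionContext, error) {
	for _, source := range fileDownloads {
		if source == "" {
			return nil, errors.New("file download source is required")
		}
	}
	execCtx := &executionContext{fileDownloads: fileDownloads}
	if len(fileDownloads) > 0 {
		execCtx.containerOptions = append(execCtx.containerOptions, containerOption{
			modify: func(c *containerDefinition) error {
				c.Command = []string{"sh", "-c", "echo wrapped"}
				return nil
			},
			description: "inject startup script",
		})
	}
	return execCtx, nil
}

// applyContainerOptions reads the options stored on the context, so the
// caller does not pass them in.
func (execCtx *executionContext) applyContainerOptions(defs []containerDefinition) error {
	for _, option := range execCtx.containerOptions {
		for i := range defs {
			if err := option.modify(&defs[i]); err != nil {
				return fmt.Errorf("failed to apply container option %q: %w", option.description, err)
			}
		}
	}
	return nil
}

func main() {
	execCtx, err := buildExecutionContext([]string{"s3://example-bucket/input.csv"})
	if err != nil {
		fmt.Println("build failed:", err)
		return
	}
	defs := []containerDefinition{{Command: []string{"run"}}}
	if err := execCtx.applyContainerOptions(defs); err != nil {
		fmt.Println("apply failed:", err)
		return
	}
	fmt.Println(defs[0].Command)
}
```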

options := []ContainerOption{}

// Add startup script wrapper option if file downloads are configured
if len(execCtx.FileDownloads) > 0 {
options = append(options, execCtx.WithStartupScriptWrapper())
}

return options
}

// applyContainerOptions applies container options to container definitions
func (execCtx *executionContext) applyContainerOptions(containerDefinitions []types.ContainerDefinition, options []ContainerOption) error {
for _, option := range options {
for i := range containerDefinitions {
Contributor

nit: use for _, cd := range containerDefinitions { here, and then call option.ModifyContainer(cd)
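
One caveat to check before adopting this: if `containerDefinitions` is a slice of struct values (as `&containerDefinitions[i]` in the diff suggests), ranging by value hands the option a copy, and the modification would not reach the slice. The sketch below illustrates the difference with a stand-in `containerDefinition` type; all names are hypothetical.

```go
package main

import "fmt"

// containerDefinition stands in for a struct stored by value in a slice.
type containerDefinition struct {
	Command []string
}

func setCommand(c *containerDefinition) {
	c.Command = []string{"sh", "-c", "startup"}
}

// modifyByValue ranges over copies, so the slice elements stay unchanged.
func modifyByValue(defs []containerDefinition) {
	for _, cd := range defs {
		setCommand(&cd)
	}
}

// modifyByIndex takes the address of each element, so changes persist.
func modifyByIndex(defs []containerDefinition) {
	for i := range defs {
		setCommand(&defs[i])
	}
}

func main() {
	defs := []containerDefinition{{Command: []string{"run"}}}
	modifyByValue(defs)
	fmt.Println(defs[0].Command) // prints [run]
	modifyByIndex(defs)
	fmt.Println(defs[0].Command) // prints [sh -c startup]
}
```

If the slice held pointers instead, the range-by-value form would work as the comment proposes.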

if err := option.ModifyContainer(&containerDefinitions[i]); err != nil {
return fmt.Errorf("failed to apply container option '%s': %w", option.Description, err)
}
}
}
return nil
}

// ECS command context structure
type ecsCommandContext struct {
TaskDefinitionTemplate string `yaml:"task_definition_template,omitempty" json:"task_definition_template,omitempty"`
@@ -29,6 +143,10 @@ type ecsCommandContext struct {
PollingInterval duration.Duration `yaml:"polling_interval,omitempty" json:"polling_interval,omitempty"`
Timeout duration.Duration `yaml:"timeout,omitempty" json:"timeout,omitempty"`
MaxFailCount int `yaml:"max_fail_count,omitempty" json:"max_fail_count,omitempty"` // max failures before giving up

// File download configuration
FileDownloads []FileDownload `yaml:"file_downloads,omitempty" json:"file_downloads,omitempty"`
StartupScriptConfig *StartupScriptConfig `yaml:"startup_script_config,omitempty" json:"startup_script_config,omitempty"`
}

// ECS cluster context structure
@@ -77,6 +195,10 @@ type executionContext struct {
Timeout duration.Duration `json:"timeout"`
MaxFailCount int `json:"max_fail_count"`

// File download configuration
FileDownloads []FileDownload `json:"file_downloads"`
StartupScriptConfig *StartupScriptConfig `json:"startup_script_config"`

ecsClient *ecs.Client
taskDefARN *string
tasks map[string]*taskTracker
@@ -87,6 +209,8 @@ const (
defaultTaskTimeout = duration.Duration(1 * time.Hour)
defaultMaxFailCount = 1
defaultTaskCount = 1
defaultDownloadDir = "/tmp/downloads"
defaultTimeout = 300
startedByPrefix = "heimdall-job-"
errMaxFailCount = "task %s failed %d times (max: %d), giving up"
errPollingTimeout = "polling timed out for arns %v after %v"
@@ -97,13 +221,77 @@ var (
errMissingTemplate = fmt.Errorf("task definition template is required")
)

// generateStartupScript creates a startup script for downloading files using a template
func generateStartupScript(fileDownloads []FileDownload, config *StartupScriptConfig) (string, error) {
if len(fileDownloads) == 0 {
Contributor

You don't need this check because you've already done it on line 79.

return "#!/bin/bash\necho 'No files to download'\nexec \"$@\"", nil
}

downloadDir := defaultDownloadDir
timeout := defaultTimeout
createDirs := true

if config != nil {
if config.DownloadDir != "" {
downloadDir = config.DownloadDir
}
if config.Timeout > 0 {
timeout = config.Timeout
}
createDirs = config.CreateDirs
}

// Prepare template data
templateData := ScriptTemplateData{
DownloadDir: downloadDir,
Timeout: timeout,
CreateDirs: createDirs,
Downloads: make([]ScriptDownload, 0, len(fileDownloads)),
}

// Convert file downloads to script downloads
for _, download := range fileDownloads {
scriptDownload := ScriptDownload{
Source: download.Source,
Destination: download.Destination,
IsS3: strings.HasPrefix(download.Source, "s3://"),
}
templateData.Downloads = append(templateData.Downloads, scriptDownload)
}

// Load template from embedded filesystem
templateContent, err := startupScriptTemplate.ReadFile("startup_script_template.sh")
if err != nil {
return "", fmt.Errorf("failed to read embedded template file: %w", err)
}

// Parse template
tmpl, err := template.New("startup_script").Parse(string(templateContent))
if err != nil {
return "", fmt.Errorf("failed to parse template: %w", err)
Contributor

Consider emitting metrics here, like in this PR.
In low-level functions we only emit metrics; at the top level we add logs as well.

}

// Execute template
var script strings.Builder
if err := tmpl.Execute(&script, templateData); err != nil {
return "", fmt.Errorf("failed to execute template: %w", err)
}

return script.String(), nil
}

func New(commandContext *context.Context) (plugin.Handler, error) {

e := &ecsCommandContext{
PollingInterval: defaultPollingInterval,
Timeout: defaultTaskTimeout,
MaxFailCount: defaultMaxFailCount,
TaskCount: defaultTaskCount,
StartupScriptConfig: &StartupScriptConfig{
DownloadDir: defaultDownloadDir,
Timeout: defaultTimeout,
CreateDirs: true,
},
}

if commandContext != nil {
@@ -151,6 +339,15 @@ func (e *ecsCommandContext) handler(r *plugin.Runtime, job *job.Job, cluster *cl

// prepare and register task definition with ECS
func (execCtx *executionContext) registerTaskDefinition() error {
// Start with the original container definitions
containerDefinitions := execCtx.TaskDefinitionWrapper.TaskDefinition.ContainerDefinitions

// Apply container options using the options pattern
containerOptions := execCtx.getDefaultContainerOptions()
Contributor

You have containerOptions inside execCtx, and you call applyContainerOptions on execCtx, so don't pass containerOptions to the method call.

if err := execCtx.applyContainerOptions(containerDefinitions, containerOptions); err != nil {
return fmt.Errorf("failed to apply container options: %w", err)
}

registerInput := &ecs.RegisterTaskDefinitionInput{
Family: aws.String(aws.ToString(execCtx.TaskDefinitionWrapper.TaskDefinition.Family)),
RequiresCompatibilities: []types.Compatibility{types.CompatibilityFargate},
@@ -159,7 +356,7 @@ func (execCtx *executionContext) registerTaskDefinition() error {
Memory: aws.String(fmt.Sprintf("%d", execCtx.ClusterConfig.Memory)),
ExecutionRoleArn: aws.String(execCtx.ClusterConfig.ExecutionRoleARN),
TaskRoleArn: aws.String(execCtx.ClusterConfig.TaskRoleARN),
- ContainerDefinitions: execCtx.TaskDefinitionWrapper.TaskDefinition.ContainerDefinitions,
+ ContainerDefinitions: containerDefinitions,
}

registerOutput, err := execCtx.ecsClient.RegisterTaskDefinition(ctx, registerInput)
@@ -373,6 +570,23 @@ func validateExecutionContext(ctx *executionContext) error {
return fmt.Errorf("task count (%d) needs to be greater than 0 and less than cluster max task count (%d)", ctx.TaskCount, ctx.ClusterConfig.MaxTaskCount)
}

// Validate file downloads configuration
for i, download := range ctx.FileDownloads {
if download.Source == "" {
return fmt.Errorf("file download %d: source is required", i)
}
if download.Destination == "" {
return fmt.Errorf("file download %d: destination is required", i)
}
}

// Validate startup script configuration
if ctx.StartupScriptConfig != nil {
if ctx.StartupScriptConfig.Timeout < 0 {
return fmt.Errorf("timeout cannot be negative")
}
}

return nil

}
47 changes: 47 additions & 0 deletions internal/pkg/object/command/ecs/startup_script_template.sh
@@ -0,0 +1,47 @@
#!/bin/bash
Contributor

Do we want to hardcode one script for every user of heimdall, or can we add more flexibility: provide a predefined template but allow a task to override the path to the script?
@babourine WDYT?

set -e
echo 'Starting file downloads to {{.DownloadDir}}...'
{{if .CreateDirs}}mkdir -p {{.DownloadDir}}{{end}}

if ! command -v aws &> /dev/null; then
Contributor

This check should run only if the file is on S3.

echo 'Installing AWS CLI...'
if apk update && apk add aws-cli; then
echo 'AWS CLI installed successfully'
else
echo 'ERROR: Failed to install AWS CLI'
exit 1
fi
fi

{{range .Downloads}}
# Download: {{.Source}}
mkdir -p $(dirname {{.Destination}})
{{if .IsS3}}
echo "Downloading from S3: {{.Source}}"
if aws s3 cp '{{.Source}}' '{{.Destination}}' --cli-read-timeout {{$.Timeout}} --cli-connect-timeout {{$.Timeout}}; then
Contributor

Let's move downloading from S3 into a separate function.

echo "Successfully downloaded: {{.Source}}"

if [ -f '{{.Destination}}' ] && [ -s '{{.Destination}}' ]; then
echo "File verification passed: {{.Destination}}"
file_size=$(stat -c%s '{{.Destination}}' 2>/dev/null || echo "unknown")
echo "File size: $file_size bytes"
else
echo "ERROR: Downloaded file is empty or missing: {{.Destination}}"
exit 1
fi
else
echo "ERROR: Failed to download from S3: {{.Source}}"
exit 1
fi
{{end}}
{{end}}
echo 'All files downloaded successfully'
echo 'Starting main application...'
# Execute the original command
if [ -n "$ORIGINAL_COMMAND" ]; then
echo "Executing: $ORIGINAL_COMMAND"
exec "$ORIGINAL_COMMAND"
else
echo "No original command found, executing: $@"
Contributor

Maybe we should add an error here.

exec "$@"
fi