diff --git a/go.mod b/go.mod
index cc73123..c86dd62 100644
--- a/go.mod
+++ b/go.mod
@@ -6,7 +6,9 @@ require (
 	github.com/aws/aws-sdk-go-v2 v1.36.3
 	github.com/aws/aws-sdk-go-v2/config v1.28.6
 	github.com/aws/aws-sdk-go-v2/credentials v1.17.47
+	github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43
 	github.com/aws/aws-sdk-go-v2/service/dynamodb v1.42.0
+	github.com/aws/aws-sdk-go-v2/service/emrcontainers v1.35.2
 	github.com/aws/aws-sdk-go-v2/service/glue v1.107.0
 	github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0
 	github.com/aws/aws-sdk-go-v2/service/sts v1.33.14
@@ -30,7 +32,6 @@ require (
 	github.com/apache/arrow-go/v18 v18.0.0 // indirect
 	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect
 	github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 // indirect
-	github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
diff --git a/go.sum b/go.sum
index ec2e412..be68598 100644
--- a/go.sum
+++ b/go.sum
@@ -46,6 +46,8 @@ github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.25 h1:r67ps7oHCYnflpgDy2LZU0MAQtQ
 github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.25/go.mod h1:GrGY+Q4fIokYLtjCVB/aFfCVL6hhGUFl8inD18fDalE=
 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.42.0 h1:EJXx6zb+lOe/Do2bO0d0dwVnIRGoP5J5xZ0BTn3LbqM=
 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.42.0/go.mod h1:yYaWRnVSPyAmexW5t7G3TcuYoalYfT+xQwzWsvtUQ7M=
+github.com/aws/aws-sdk-go-v2/service/emrcontainers v1.35.2 h1:qe9+m2ZI6WPdGVQ42ar02lslj/K8bpXyBn5eu+hXZ2k=
+github.com/aws/aws-sdk-go-v2/service/emrcontainers v1.35.2/go.mod h1:Td38+RbBfHRtHhNLdpJqFbLwC+rDEtjkHhh90wCIrIY=
 github.com/aws/aws-sdk-go-v2/service/glue v1.107.0 h1:iO8m70j8JNSJl9aDQNTjR3BNRYbLphnrhCidSfw+4g0=
 github.com/aws/aws-sdk-go-v2/service/glue v1.107.0/go.mod h1:6FqWCqW0Py6VOvY42NQyf9e7N+sNVnDEiHFklCCCoQc=
 github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE=
diff --git a/internal/pkg/object/command/spark/spark.go b/internal/pkg/object/command/spark/spark.go
new file mode 100644
index 0000000..572e067
--- /dev/null
+++ b/internal/pkg/object/command/spark/spark.go
@@ -0,0 +1,340 @@
+package spark
+
+import (
+	ct "context"
+	"encoding/json"
+	"fmt"
+	"os"
+	"regexp"
+	"strings"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/credentials"
+	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+	"github.com/aws/aws-sdk-go-v2/service/emrcontainers"
+	"github.com/aws/aws-sdk-go-v2/service/emrcontainers/types"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+	"github.com/aws/aws-sdk-go-v2/service/sts"
+	"github.com/babourine/x/pkg/set"
+
+	"github.com/patterninc/heimdall/internal/pkg/context"
+	"github.com/patterninc/heimdall/internal/pkg/object/cluster"
+	"github.com/patterninc/heimdall/internal/pkg/object/job"
+	"github.com/patterninc/heimdall/internal/pkg/result"
+	"github.com/patterninc/heimdall/pkg/plugin"
+)
+
+// sparkCommandContext represents the Spark command configuration
+type sparkCommandContext struct {
+	QueriesURI string            `yaml:"queries_uri,omitempty" json:"queries_uri,omitempty"`
+	ResultsURI string            `yaml:"results_uri,omitempty" json:"results_uri,omitempty"`
+	LogsURI    *string           `yaml:"logs_uri,omitempty" json:"logs_uri,omitempty"`
+	WrapperURI *string           `yaml:"wrapper_uri,omitempty" json:"wrapper_uri,omitempty"`
+	Properties map[string]string `yaml:"properties,omitempty" json:"properties,omitempty"`
+}
+
+// sparkJobContext represents the context for a Spark job
+type sparkJobContext struct {
+	Query        string            `yaml:"query,omitempty" json:"query,omitempty"`
+	Properties   map[string]string `yaml:"properties,omitempty" json:"properties,omitempty"`
+	ReturnResult bool              `yaml:"return_result,omitempty" json:"return_result,omitempty"`
+}
+
+// sparkClusterContext represents the context for a Spark cluster
+type sparkClusterContext struct {
+	ExecutionRoleArn *string           `yaml:"execution_role_arn,omitempty" json:"execution_role_arn,omitempty"`
+	EMRReleaseLabel  *string           `yaml:"emr_release_label,omitempty" json:"emr_release_label,omitempty"`
+	RoleARN          *string           `yaml:"role_arn,omitempty" json:"role_arn,omitempty"`
+	Properties       map[string]string `yaml:"properties,omitempty" json:"properties,omitempty"`
+}
+
+const (
+	driverMemoryProperty = `spark.driver.memory`
+	jobCheckInterval     = 5                                // seconds
+	jobTimeout           = (5 * 60 * 60) / jobCheckInterval // 5 hours
+	noStateDetails       = `no state details provided`
+)
+
+var (
+	ctx               = ct.Background()
+	sparkDefaults     = aws.String(`spark-defaults`)
+	assumeRoleSession = aws.String("AssumeRoleSession")
+	runtimeStates     = set.New([]types.JobRunState{types.JobRunStateCompleted, types.JobRunStateFailed, types.JobRunStateCancelled})
+	rxS3              = regexp.MustCompile(`^s3://([^/]+)/(.*)$`)
+)
+
+var (
+	ErrUnknownCluster = fmt.Errorf(`unknown cluster`)
+	ErrJobCanceled    = fmt.Errorf(`job canceled`)
+)
+
+// New creates a new Spark plugin handler.
+func New(commandContext *context.Context) (plugin.Handler, error) {
+
+	s := &sparkCommandContext{}
+
+	if commandContext != nil {
+		if err := commandContext.Unmarshal(s); err != nil {
+			return nil, err
+		}
+	}
+
+	return s.handler, nil
+
+}
+
+// handler submits the Spark job and waits for it to reach a terminal state.
+func (s *sparkCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) {
+
+	// let's unmarshal job context
+	jobContext := &sparkJobContext{}
+	if j.Context != nil {
+		if err := j.Context.Unmarshal(jobContext); err != nil {
+			return err
+		}
+	}
+
+	// let's unmarshal cluster context
+	clusterContext := &sparkClusterContext{}
+	if c.Context != nil {
+		if err := c.Context.Unmarshal(clusterContext); err != nil {
+			return err
+		}
+	}
+
+	// let's prepare job properties: command-level defaults fill in anything the job didn't set
+	if jobContext.Properties == nil {
+		jobContext.Properties = make(map[string]string)
+	}
+	for k, v := range s.Properties {
+		if _, found := jobContext.Properties[k]; !found {
+			jobContext.Properties[k] = v
+		}
+	}
+
+	// do we have a driver memory setting in the job properties?
+	// if so, promote it to the cluster (spark-defaults) properties
+	if value, found := jobContext.Properties[driverMemoryProperty]; found {
+		if clusterContext.Properties == nil {
+			clusterContext.Properties = make(map[string]string)
+		}
+		clusterContext.Properties[driverMemoryProperty] = value
+		delete(jobContext.Properties, driverMemoryProperty)
+	}
+
+	// setting AWS client
+	awsConfig, err := config.LoadDefaultConfig(ctx)
+	if err != nil {
+		return err
+	}
+
+	// let's set an empty options function,...
+	assumeRoleOptions := func(_ *emrcontainers.Options) {}
+
+	// ...and, if we have an assume role ARN set, let's establish creds...
+	if clusterContext.RoleARN != nil {
+
+		stsSvc := sts.NewFromConfig(awsConfig)
+
+		assumeRoleOutput, err := stsSvc.AssumeRole(ctx, &sts.AssumeRoleInput{
+			RoleArn:         clusterContext.RoleARN,
+			RoleSessionName: assumeRoleSession,
+		})
+		if err != nil {
+			return err
+		}
+
+		// use the temporary STS credentials for all EMR on EKS calls
+		assumeRoleOptions = func(o *emrcontainers.Options) {
+			o.Credentials = credentials.NewStaticCredentialsProvider(
+				*assumeRoleOutput.Credentials.AccessKeyId,
+				*assumeRoleOutput.Credentials.SecretAccessKey,
+				*assumeRoleOutput.Credentials.SessionToken,
+			)
+		}
+
+	}
+
+	svc := emrcontainers.NewFromConfig(awsConfig, assumeRoleOptions)
+
+	// let's get the cluster ID
+	clusterID, err := getClusterID(svc, c.Name)
+	if err != nil {
+		return err
+	}
+
+	// let's set the result uri
+	resultURI := fmt.Sprintf("%s/%s", s.ResultsURI, j.ID)
+
+	// upload the query to s3...
+	queryURI := fmt.Sprintf("%s/%s/query.sql", s.QueriesURI, j.ID)
+	if err := uploadFileToS3(queryURI, jobContext.Query); err != nil {
+		return err
+	}
+
+	// let's set the job driver: the wrapper script is used when the caller wants the result back,
+	// otherwise the query runs through the plain Spark SQL driver
+	jobDriver := &types.JobDriver{}
+	jobParameters := getSparkSqlParameters(jobContext.Properties)
+	if jobContext.ReturnResult {
+
+		jobDriver.SparkSubmitJobDriver = &types.SparkSubmitJobDriver{
+			EntryPoint: s.WrapperURI,
+			EntryPointArguments: []string{
+				queryURI,
+				resultURI,
+			},
+			SparkSubmitParameters: jobParameters,
+		}
+
+	} else {
+
+		jobDriver.SparkSqlJobDriver = &types.SparkSqlJobDriver{
+			EntryPoint:         &queryURI,
+			SparkSqlParameters: jobParameters,
+		}
+
+	}
+
+	// let's prepare the job payload
+	jobPayload := &emrcontainers.StartJobRunInput{
+		Name:             aws.String(j.ID),
+		VirtualClusterId: clusterID,
+		ExecutionRoleArn: clusterContext.ExecutionRoleArn,
+		ReleaseLabel:     clusterContext.EMRReleaseLabel,
+		JobDriver:        jobDriver,
+		ConfigurationOverrides: &types.ConfigurationOverrides{
+			ApplicationConfiguration: []types.Configuration{{
+				Classification: sparkDefaults,
+				Properties:     clusterContext.Properties,
+			}},
+			MonitoringConfiguration: &types.MonitoringConfiguration{
+				PersistentAppUI: types.PersistentAppUIEnabled,
+				S3MonitoringConfiguration: &types.S3MonitoringConfiguration{
+					LogUri: s.LogsURI,
+				},
+			},
+		},
+	}
+
+	// record the payload so we can more easily see what was submitted
+	jobPayloadJSON, err := json.MarshalIndent(jobPayload, ``, ` `)
+	if err != nil {
+		return err
+	}
+	r.Stdout.WriteString(string(jobPayloadJSON) + "\n\n")
+
+	// start the job
+	outputStartJobRun, err := svc.StartJobRun(ctx, jobPayload)
+	if err != nil {
+		return err
+	}
+
+	// TODO: cleanup at some point, once the command is stable
+	r.Stdout.WriteString(fmt.Sprintf("Cluster Job ID: %v\n", *outputStartJobRun.Id))
+	// spew.Fdump(r.Stdout, s, clusterContext, jobContext)
+
+	// keep checking until the job succeeds or fails...
+timeoutLoop:
+	for i := 0; i < jobTimeout; i++ {
+		time.Sleep(jobCheckInterval * time.Second)
+		describeJobOutput, err := svc.DescribeJobRun(ctx, &emrcontainers.DescribeJobRunInput{
+			Id:               outputStartJobRun.Id,
+			VirtualClusterId: clusterID,
+		})
+		if err != nil {
+			// TODO: log error if it's persistent
+			r.Stderr.WriteString(fmt.Sprintf("job error: %v\n", err))
+		}
+		if describeJobOutput != nil {
+			// print state every ~30 seconds and on any terminal state...
+			if state := describeJobOutput.JobRun.State; i%6 == 0 || runtimeStates.Has(state) {
+				printState(r.Stdout, state)
+				switch state {
+				case types.JobRunStateCompleted:
+					break timeoutLoop
+				case types.JobRunStateFailed:
+					stateDetails := noStateDetails
+					if sd := describeJobOutput.JobRun.StateDetails; sd != nil {
+						stateDetails = *sd
+					}
+					return fmt.Errorf("job failed [%v]: %v", describeJobOutput.JobRun.FailureReason, stateDetails)
+				case types.JobRunStateCancelled:
+					return ErrJobCanceled
+				}
+			}
+		}
+	}
+
+	// TODO: set a dummy result for now; return the actual result eventually...
+	if j.Result, err = result.FromAvro(resultURI); err != nil {
+		return err
+	}
+
+	return nil
+
+}
+
+// getClusterID resolves the name of a running EKS virtual cluster to its ID.
+func getClusterID(svc *emrcontainers.Client, clusterName string) (*string, error) {
+
+	// let's get the cluster ID
+	outputListClusters, err := svc.ListVirtualClusters(ctx, &emrcontainers.ListVirtualClustersInput{
+		States:                []types.VirtualClusterState{types.VirtualClusterStateRunning},
+		ContainerProviderType: types.ContainerProviderTypeEks,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to list virtual clusters: %w", err)
+	}
+
+	for _, vc := range outputListClusters.VirtualClusters {
+		if aws.ToString(vc.Name) == clusterName {
+			return vc.Id, nil
+		}
+	}
+
+	return nil, fmt.Errorf("cluster %s: %w", clusterName, ErrUnknownCluster)
+
+}
+
+// getSparkSqlParameters turns the job properties into a single "--conf k=v" string.
+func getSparkSqlParameters(properties map[string]string) *string {
+
+	conf := make([]string, 0, len(properties))
+
+	for k, v := range properties {
+		conf = append(conf, fmt.Sprintf("--conf %s=%s", k, v))
+	}
+
+	return aws.String(strings.Join(conf, ` `))
+
+}
+
+// printState writes the current job state with a timestamp.
+func printState(stdout *os.File, state types.JobRunState) {
+	stdout.WriteString(fmt.Sprintf("%v - current job state: %v\n", time.Now(), state))
+}
+
+// uploadFileToS3 writes content to the s3://bucket/key location given by fileURI.
+func uploadFileToS3(fileURI, content string) error {
+
+	// get bucket name and key from the uri
+	s3Parts := rxS3.FindStringSubmatch(fileURI)
+	if len(s3Parts) < 3 {
+		return fmt.Errorf("unexpected s3 uri: %v", fileURI)
+	}
+
+	// upload file
+	awsConfig, err := config.LoadDefaultConfig(ctx)
+	if err != nil {
+		return err
+	}
+
+	// create an S3 client and an upload manager
+	svc := s3.NewFromConfig(awsConfig)
+	uploader := manager.NewUploader(svc)
+
+	// upload the string content to S3
+	if _, err := uploader.Upload(ctx, &s3.PutObjectInput{
+		Bucket: &s3Parts[1],
+		Key:    &s3Parts[2],
+		Body:   strings.NewReader(content),
+	}); err != nil {
+		return err
+	}
+
+	return nil
+
+}
diff --git a/plugins/spark/spark.go b/plugins/spark/spark.go
new file mode 100644
index 0000000..e99e74e
--- /dev/null
+++ b/plugins/spark/spark.go
@@ -0,0 +1,12 @@
+package main
+
+import (
+	"github.com/patterninc/heimdall/internal/pkg/context"
+	"github.com/patterninc/heimdall/internal/pkg/object/command/spark"
+	"github.com/patterninc/heimdall/pkg/plugin"
+)
+
+// New creates a new Spark plugin handler.
+func New(commandContext *context.Context) (plugin.Handler, error) {
+	return spark.New(commandContext)
+}
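For reference, here is a small standalone sketch, not part of the change above, showing how two of the helpers behave: the s3:// URI split used by uploadFileToS3 and the "--conf" string built by getSparkSqlParameters. The function names, bucket, key, and property values below are made up for illustration; only the regex pattern and the "--conf k=v" format are taken from the diff.

package main

import (
	"fmt"
	"regexp"
	"strings"
)

// same anchored pattern as rxS3 in spark.go: captures bucket and key
var rxS3 = regexp.MustCompile(`^s3://([^/]+)/(.*)$`)

// splitS3URI mirrors the parsing in uploadFileToS3
func splitS3URI(uri string) (bucket, key string, err error) {
	parts := rxS3.FindStringSubmatch(uri)
	if len(parts) < 3 {
		return "", "", fmt.Errorf("unexpected s3 uri: %v", uri)
	}
	return parts[1], parts[2], nil
}

// sparkSQLParameters mirrors getSparkSqlParameters: one "--conf k=v" pair per property
func sparkSQLParameters(properties map[string]string) string {
	conf := make([]string, 0, len(properties))
	for k, v := range properties {
		conf = append(conf, fmt.Sprintf("--conf %s=%s", k, v))
	}
	return strings.Join(conf, " ")
}

func main() {
	// hypothetical queries_uri layout: <queries_uri>/<job id>/query.sql
	bucket, key, _ := splitS3URI("s3://example-bucket/queries/job-123/query.sql")
	fmt.Println(bucket) // example-bucket
	fmt.Println(key)    // queries/job-123/query.sql

	fmt.Println(sparkSQLParameters(map[string]string{"spark.executor.memory": "4g"}))
	// --conf spark.executor.memory=4g
}

Note that Go map iteration order is not deterministic, so with more than one property the "--conf" pairs may come out in any order; Spark accepts them in any order, so this does not affect the submitted job.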