diff --git a/tests/integration/ray_test.go b/tests/integration/ray_test.go index 69fcbad5..41a35581 100644 --- a/tests/integration/ray_test.go +++ b/tests/integration/ray_test.go @@ -26,6 +26,7 @@ import ( support "github.com/project-codeflare/codeflare-operator/test/support" rayv1alpha1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1" + batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -116,12 +117,23 @@ func TestRayCluster(t *testing.T) { defer support.WriteRayJobLogs(test, rayClient, rayJob.Namespace, rayJob.Name) test.T().Logf("Waiting for RayJob %s/%s to complete", rayJob.Namespace, rayJob.Name) - test.Eventually(support.RayJob(test, rayJob.Namespace, rayJob.Name), support.TestTimeoutLong). + // Will be removed when RHODS-12857 is fixed + test.Eventually(Job(test, rayJob.Namespace, rayJob.Name), support.TestTimeoutLong). + Should(WithTransform(func(job *batchv1.Job) int { + return int(job.Status.Succeeded) + }, Equal(1))) + + /* + Should be uncommented when RHODS-12857 is fixed + + test.T().Logf("Waiting for RayJob %s/%s to complete", rayJob.Namespace, rayJob.Name) + test.Eventually(support.RayJob(test, rayJob.Namespace, rayJob.Name), support.TestTimeoutLong). Should(WithTransform(support.RayJobStatus, Satisfy(rayv1alpha1.IsJobTerminal))) - // Assert the Ray job has completed successfully - test.Expect(support.GetRayJob(test, rayJob.Namespace, rayJob.Name)). - To(WithTransform(support.RayJobStatus, Equal(rayv1alpha1.JobStatusSucceeded))) + // Assert the Ray job has completed successfully + test.Expect(support.GetRayJob(test, rayJob.Namespace, rayJob.Name)). + To(WithTransform(support.RayJobStatus, Equal(rayv1alpha1.JobStatusSucceeded))) + */ } func TestRayJobSubmissionRest(t *testing.T) { @@ -355,3 +367,12 @@ func createRayCluster(test support.Test, namespace string, mnist *corev1.ConfigM Should(WithTransform(support.RayClusterState, Equal(rayv1alpha1.Ready))) return } + +func Job(t support.Test, namespace, name string) func(g Gomega) *batchv1.Job { + return func(g Gomega) *batchv1.Job { + + job, err := t.Client().Core().BatchV1().Jobs(namespace).Get(t.Ctx(), name, metav1.GetOptions{}) + g.Expect(err).NotTo(HaveOccurred()) + return job + } +}