Skip to content

Commit 1a503f8

Browse files
jdomingrJuan Dominguez
andauthored
feat(genai): add batch prediction samples (1) and update GenAI SDK (#10174)
* feat(genai): add new batch prediction samples and update SDK * refactor: change hashset for enumSet and now use TimeUnit instead if thread.sleep * refactor: change tests to ensure that the API is called and not mocked * refactor: change batch polling logic and return type * fix lint * chore(genai): update comments and SDK version * chore(genai): update batch prediction test to use mocking * refactor(genai): change polling logic and update tests --------- Co-authored-by: Juan Dominguez <[email protected]>
1 parent 3dd698d commit 1a503f8

File tree

4 files changed

+386
-1
lines changed

4 files changed

+386
-1
lines changed

genai/snippets/pom.xml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,12 @@
5151
<dependency>
5252
<groupId>com.google.genai</groupId>
5353
<artifactId>google-genai</artifactId>
54-
<version>1.15.0</version>
54+
<version>1.23.0</version>
55+
</dependency>
56+
<dependency>
57+
<groupId>com.google.cloud</groupId>
58+
<artifactId>google-cloud-storage</artifactId>
59+
<scope>test</scope>
5560
</dependency>
5661
<dependency>
5762
<groupId>com.google.cloud</groupId>
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package genai.batchprediction;
18+
19+
// [START googlegenaisdk_batchpredict_embeddings_with_gcs]
20+
21+
import static com.google.genai.types.JobState.Known.JOB_STATE_CANCELLED;
22+
import static com.google.genai.types.JobState.Known.JOB_STATE_FAILED;
23+
import static com.google.genai.types.JobState.Known.JOB_STATE_PAUSED;
24+
import static com.google.genai.types.JobState.Known.JOB_STATE_SUCCEEDED;
25+
26+
import com.google.genai.Client;
27+
import com.google.genai.types.BatchJob;
28+
import com.google.genai.types.BatchJobDestination;
29+
import com.google.genai.types.BatchJobSource;
30+
import com.google.genai.types.CreateBatchJobConfig;
31+
import com.google.genai.types.GetBatchJobConfig;
32+
import com.google.genai.types.HttpOptions;
33+
import com.google.genai.types.JobState;
34+
import java.util.EnumSet;
35+
import java.util.Set;
36+
import java.util.concurrent.TimeUnit;
37+
38+
public class BatchPredictionEmbeddingsWithGcs {
39+
40+
public static void main(String[] args) throws InterruptedException {
41+
// TODO(developer): Replace these variables before running the sample.
42+
String modelId = "text-embedding-005";
43+
String outputGcsUri = "gs://your-bucket/your-prefix";
44+
createBatchJob(modelId, outputGcsUri);
45+
}
46+
47+
// Creates a batch prediction job with embedding model and Google Cloud Storage.
48+
public static JobState createBatchJob(String modelId, String outputGcsUri)
49+
throws InterruptedException {
50+
// Client Initialization. Once created, it can be reused for multiple requests.
51+
try (Client client =
52+
Client.builder()
53+
.location("us-central1")
54+
.vertexAI(true)
55+
.httpOptions(HttpOptions.builder().apiVersion("v1").build())
56+
.build()) {
57+
58+
// See the documentation:
59+
// https://googleapis.github.io/java-genai/javadoc/com/google/genai/Batches.html
60+
BatchJobSource batchJobSource =
61+
BatchJobSource.builder()
62+
// Source link:
63+
// https://storage.cloud.google.com/cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl
64+
.gcsUri("gs://cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl")
65+
.format("jsonl")
66+
.build();
67+
68+
CreateBatchJobConfig batchJobConfig =
69+
CreateBatchJobConfig.builder()
70+
.displayName("your-display-name")
71+
.dest(BatchJobDestination.builder().gcsUri(outputGcsUri).format("jsonl").build())
72+
.build();
73+
74+
BatchJob batchJob = client.batches.create(modelId, batchJobSource, batchJobConfig);
75+
76+
String jobName =
77+
batchJob.name().orElseThrow(() -> new IllegalStateException("Missing job name"));
78+
JobState jobState =
79+
batchJob.state().orElseThrow(() -> new IllegalStateException("Missing job state"));
80+
System.out.println("Job name: " + jobName);
81+
System.out.println("Job state: " + jobState);
82+
// Job name: projects/.../locations/.../batchPredictionJobs/6205497615459549184
83+
// Job state: JOB_STATE_PENDING
84+
85+
// See the documentation:
86+
// https://googleapis.github.io/java-genai/javadoc/com/google/genai/types/BatchJob.html
87+
Set<JobState.Known> completedStates =
88+
EnumSet.of(JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_CANCELLED, JOB_STATE_PAUSED);
89+
90+
while (!completedStates.contains(jobState.knownEnum())) {
91+
TimeUnit.SECONDS.sleep(30);
92+
batchJob = client.batches.get(jobName, GetBatchJobConfig.builder().build());
93+
jobState =
94+
batchJob
95+
.state()
96+
.orElseThrow(() -> new IllegalStateException("Missing job state during polling"));
97+
System.out.println("Job state: " + jobState);
98+
}
99+
// Example response:
100+
// Job state: JOB_STATE_QUEUED
101+
// Job state: JOB_STATE_RUNNING
102+
// ...
103+
// Job state: JOB_STATE_SUCCEEDED
104+
return jobState;
105+
}
106+
}
107+
}
108+
// [END googlegenaisdk_batchpredict_embeddings_with_gcs]
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package genai.batchprediction;
18+
19+
// [START googlegenaisdk_batchpredict_with_gcs]
20+
21+
import static com.google.genai.types.JobState.Known.JOB_STATE_CANCELLED;
22+
import static com.google.genai.types.JobState.Known.JOB_STATE_FAILED;
23+
import static com.google.genai.types.JobState.Known.JOB_STATE_PAUSED;
24+
import static com.google.genai.types.JobState.Known.JOB_STATE_SUCCEEDED;
25+
26+
import com.google.genai.Client;
27+
import com.google.genai.types.BatchJob;
28+
import com.google.genai.types.BatchJobDestination;
29+
import com.google.genai.types.BatchJobSource;
30+
import com.google.genai.types.CreateBatchJobConfig;
31+
import com.google.genai.types.GetBatchJobConfig;
32+
import com.google.genai.types.HttpOptions;
33+
import com.google.genai.types.JobState;
34+
import java.util.EnumSet;
35+
import java.util.Optional;
36+
import java.util.Set;
37+
import java.util.concurrent.TimeUnit;
38+
39+
public class BatchPredictionWithGcs {
40+
41+
public static void main(String[] args) throws InterruptedException {
42+
// TODO(developer): Replace these variables before running the sample.
43+
// To use a tuned model, set the model param to your tuned model using the following format:
44+
// modelId = "projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}
45+
String modelId = "gemini-2.5-flash";
46+
String outputGcsUri = "gs://your-bucket/your-prefix";
47+
createBatchJob(modelId, outputGcsUri);
48+
}
49+
50+
// Creates a batch prediction job with Google Cloud Storage.
51+
public static JobState createBatchJob(String modelId, String outputGcsUri)
52+
throws InterruptedException {
53+
// Client Initialization. Once created, it can be reused for multiple requests.
54+
try (Client client =
55+
Client.builder()
56+
.location("global")
57+
.vertexAI(true)
58+
.httpOptions(HttpOptions.builder().apiVersion("v1").build())
59+
.build()) {
60+
// See the documentation:
61+
// https://googleapis.github.io/java-genai/javadoc/com/google/genai/Batches.html
62+
BatchJobSource batchJobSource =
63+
BatchJobSource.builder()
64+
// Source link:
65+
// https://storage.cloud.google.com/cloud-samples-data/batch/prompt_for_batch_gemini_predict.jsonl
66+
.gcsUri("gs://cloud-samples-data/batch/prompt_for_batch_gemini_predict.jsonl")
67+
.format("jsonl")
68+
.build();
69+
70+
CreateBatchJobConfig batchJobConfig =
71+
CreateBatchJobConfig.builder()
72+
.displayName("your-display-name")
73+
.dest(BatchJobDestination.builder().gcsUri(outputGcsUri).format("jsonl").build())
74+
.build();
75+
76+
BatchJob batchJob = client.batches.create(modelId, batchJobSource, batchJobConfig);
77+
78+
String jobName =
79+
batchJob.name().orElseThrow(() -> new IllegalStateException("Missing job name"));
80+
JobState jobState =
81+
batchJob.state().orElseThrow(() -> new IllegalStateException("Missing job state"));
82+
System.out.println("Job name: " + jobName);
83+
System.out.println("Job state: " + jobState);
84+
// Job name: projects/.../locations/.../batchPredictionJobs/6205497615459549184
85+
// Job state: JOB_STATE_PENDING
86+
87+
// See the documentation:
88+
// https://googleapis.github.io/java-genai/javadoc/com/google/genai/types/BatchJob.html
89+
Set<JobState.Known> completedStates =
90+
EnumSet.of(JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_CANCELLED, JOB_STATE_PAUSED);
91+
92+
while (!completedStates.contains(jobState.knownEnum())) {
93+
TimeUnit.SECONDS.sleep(30);
94+
batchJob = client.batches.get(jobName, GetBatchJobConfig.builder().build());
95+
jobState =
96+
batchJob
97+
.state()
98+
.orElseThrow(() -> new IllegalStateException("Missing job state during polling"));
99+
System.out.println("Job state: " + jobState);
100+
}
101+
// Example response:
102+
// Job state: JOB_STATE_QUEUED
103+
// Job state: JOB_STATE_RUNNING
104+
// Job state: JOB_STATE_RUNNING
105+
// ...
106+
// Job state: JOB_STATE_SUCCEEDED
107+
return jobState;
108+
}
109+
}
110+
}
111+
// [END googlegenaisdk_batchpredict_with_gcs]
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package genai.batchprediction;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static com.google.common.truth.Truth.assertWithMessage;
21+
import static com.google.genai.types.JobState.Known.JOB_STATE_PENDING;
22+
import static com.google.genai.types.JobState.Known.JOB_STATE_RUNNING;
23+
import static com.google.genai.types.JobState.Known.JOB_STATE_SUCCEEDED;
24+
import static org.mockito.ArgumentMatchers.any;
25+
import static org.mockito.ArgumentMatchers.anyString;
26+
import static org.mockito.Mockito.RETURNS_SELF;
27+
import static org.mockito.Mockito.mock;
28+
import static org.mockito.Mockito.mockStatic;
29+
import static org.mockito.Mockito.times;
30+
import static org.mockito.Mockito.verify;
31+
import static org.mockito.Mockito.when;
32+
33+
import com.google.genai.Batches;
34+
import com.google.genai.Client;
35+
import com.google.genai.types.BatchJob;
36+
import com.google.genai.types.BatchJobSource;
37+
import com.google.genai.types.CreateBatchJobConfig;
38+
import com.google.genai.types.GetBatchJobConfig;
39+
import com.google.genai.types.JobState;
40+
import java.io.ByteArrayOutputStream;
41+
import java.io.PrintStream;
42+
import java.lang.reflect.Field;
43+
import java.util.Optional;
44+
import org.junit.After;
45+
import org.junit.Before;
46+
import org.junit.BeforeClass;
47+
import org.junit.Test;
48+
import org.junit.runner.RunWith;
49+
import org.junit.runners.JUnit4;
50+
import org.mockito.MockedStatic;
51+
52+
@RunWith(JUnit4.class)
53+
public class BatchPredictionIT {
54+
55+
private static final String GEMINI_FLASH = "gemini-2.5-flash";
56+
private static final String EMBEDDING_MODEL = "text-embedding-005";
57+
private static String jobName;
58+
private static String outputGcsUri;
59+
private ByteArrayOutputStream bout;
60+
private Batches mockedBatches;
61+
private MockedStatic<Client> mockedStatic;
62+
63+
// Check if the required environment variables are set.
64+
public static void requireEnvVar(String envVarName) {
65+
assertWithMessage(String.format("Missing environment variable '%s' ", envVarName))
66+
.that(System.getenv(envVarName))
67+
.isNotEmpty();
68+
}
69+
70+
@BeforeClass
71+
public static void checkRequirements() {
72+
requireEnvVar("GOOGLE_CLOUD_PROJECT");
73+
jobName = "projects/project_id/locations/us-central1/batchPredictionJobs/job_id";
74+
outputGcsUri = "gs://your-bucket/your-prefix";
75+
}
76+
77+
@Before
78+
public void setUp() throws NoSuchFieldException, IllegalAccessException {
79+
bout = new ByteArrayOutputStream();
80+
System.setOut(new PrintStream(bout));
81+
82+
// Arrange
83+
Client.Builder mockedBuilder = mock(Client.Builder.class, RETURNS_SELF);
84+
mockedBatches = mock(Batches.class);
85+
mockedStatic = mockStatic(Client.class);
86+
mockedStatic.when(Client::builder).thenReturn(mockedBuilder);
87+
Client mockedClient = mock(Client.class);
88+
when(mockedBuilder.build()).thenReturn(mockedClient);
89+
90+
// Using reflection because 'batches' is a final field and cannot be mocked directly.
91+
// This is brittle but necessary for testing this class structure.
92+
Field field = Client.class.getDeclaredField("batches");
93+
field.setAccessible(true);
94+
field.set(mockedClient, mockedBatches);
95+
96+
// Mock the sequence of job states to test the polling loop
97+
BatchJob pendingJob = mock(BatchJob.class);
98+
when(pendingJob.name()).thenReturn(Optional.of(jobName));
99+
when(pendingJob.state()).thenReturn(Optional.of(new JobState(JOB_STATE_PENDING)));
100+
101+
BatchJob runningJob = mock(BatchJob.class);
102+
when(runningJob.state()).thenReturn(Optional.of(new JobState(JOB_STATE_RUNNING)));
103+
104+
BatchJob succeededJob = mock(BatchJob.class);
105+
when(succeededJob.state()).thenReturn(Optional.of(new JobState(JOB_STATE_SUCCEEDED)));
106+
107+
when(mockedBatches.create(
108+
anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class)))
109+
.thenReturn(pendingJob);
110+
when(mockedBatches.get(anyString(), any(GetBatchJobConfig.class)))
111+
.thenReturn(runningJob, succeededJob);
112+
}
113+
114+
@After
115+
public void tearDown() {
116+
System.setOut(null);
117+
bout.reset();
118+
mockedStatic.close();
119+
}
120+
121+
@Test
122+
public void testBatchPredictionWithGcs() throws InterruptedException {
123+
// Act
124+
JobState response = BatchPredictionWithGcs.createBatchJob(GEMINI_FLASH, outputGcsUri);
125+
126+
// Assert
127+
verify(mockedBatches, times(1))
128+
.create(anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class));
129+
verify(mockedBatches, times(2)).get(anyString(), any(GetBatchJobConfig.class));
130+
131+
assertThat(response).isNotNull();
132+
assertThat(response.knownEnum()).isEqualTo(JOB_STATE_SUCCEEDED);
133+
134+
String output = bout.toString();
135+
assertThat(output).contains("Job name: " + jobName);
136+
assertThat(output).contains("Job state: JOB_STATE_PENDING");
137+
assertThat(output).contains("Job state: JOB_STATE_RUNNING");
138+
assertThat(output).contains("Job state: JOB_STATE_SUCCEEDED");
139+
}
140+
141+
@Test
142+
public void testBatchPredictionEmbeddingsWithGcs() throws InterruptedException {
143+
// Act
144+
JobState response =
145+
BatchPredictionEmbeddingsWithGcs.createBatchJob(EMBEDDING_MODEL, outputGcsUri);
146+
147+
// Assert
148+
verify(mockedBatches, times(1))
149+
.create(anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class));
150+
verify(mockedBatches, times(2)).get(anyString(), any(GetBatchJobConfig.class));
151+
152+
assertThat(response).isNotNull();
153+
assertThat(response.knownEnum()).isEqualTo(JOB_STATE_SUCCEEDED);
154+
155+
String output = bout.toString();
156+
assertThat(output).contains("Job name: " + jobName);
157+
assertThat(output).contains("Job state: JOB_STATE_PENDING");
158+
assertThat(output).contains("Job state: JOB_STATE_RUNNING");
159+
assertThat(output).contains("Job state: JOB_STATE_SUCCEEDED");
160+
}
161+
}

0 commit comments

Comments
 (0)