Skip to content

Commit 3dd698d

Browse files
jdomingrJuan Dominguez
andauthored
feat(genai): add tuning samples (#10176)
* feat(genai): add tuning samples * refactor(genai): change default values when creating the tuning job and add new case in testfile * refactor: change exception message --------- Co-authored-by: Juan Dominguez <[email protected]>
1 parent 04b2605 commit 3dd698d

File tree

5 files changed

+495
-0
lines changed

5 files changed

+495
-0
lines changed
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package genai.tuning;
18+
19+
// [START googlegenaisdk_tuning_job_create]
20+
21+
import static com.google.genai.types.JobState.Known.JOB_STATE_PENDING;
22+
import static com.google.genai.types.JobState.Known.JOB_STATE_RUNNING;
23+
24+
import com.google.genai.Client;
25+
import com.google.genai.types.CreateTuningJobConfig;
26+
import com.google.genai.types.GetTuningJobConfig;
27+
import com.google.genai.types.HttpOptions;
28+
import com.google.genai.types.JobState;
29+
import com.google.genai.types.TunedModel;
30+
import com.google.genai.types.TunedModelCheckpoint;
31+
import com.google.genai.types.TuningDataset;
32+
import com.google.genai.types.TuningJob;
33+
import com.google.genai.types.TuningValidationDataset;
34+
import java.util.Collections;
35+
import java.util.EnumSet;
36+
import java.util.List;
37+
import java.util.Optional;
38+
import java.util.Set;
39+
import java.util.concurrent.TimeUnit;
40+
41+
public class TuningJobCreate {
42+
43+
public static void main(String[] args) throws InterruptedException {
44+
// TODO(developer): Replace these variables before running the sample.
45+
String model = "gemini-2.5-flash";
46+
createTuningJob(model);
47+
}
48+
49+
// Shows how to create a supervised fine-tuning job using training and validation datasets
50+
public static String createTuningJob(String model) throws InterruptedException {
51+
// Client Initialization. Once created, it can be reused for multiple requests.
52+
try (Client client =
53+
Client.builder()
54+
.location("us-central1")
55+
.vertexAI(true)
56+
.httpOptions(HttpOptions.builder().apiVersion("v1beta1").build())
57+
.build()) {
58+
59+
String trainingDatasetUri =
60+
"gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_train_data.jsonl";
61+
TuningDataset trainingDataset = TuningDataset.builder().gcsUri(trainingDatasetUri).build();
62+
63+
String validationDatasetUri =
64+
"gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_validation_data.jsonl";
65+
TuningValidationDataset validationDataset =
66+
TuningValidationDataset.builder().gcsUri(validationDatasetUri).build();
67+
68+
TuningJob tuningJob =
69+
client.tunings.tune(
70+
model,
71+
trainingDataset,
72+
CreateTuningJobConfig.builder()
73+
.tunedModelDisplayName("your-display-name")
74+
.validationDataset(validationDataset)
75+
.build());
76+
77+
String jobName =
78+
tuningJob.name().orElseThrow(() -> new IllegalStateException("Missing job name"));
79+
Optional<JobState> jobState = tuningJob.state();
80+
Set<JobState.Known> runningStates = EnumSet.of(JOB_STATE_PENDING, JOB_STATE_RUNNING);
81+
82+
while (jobState.isPresent() && runningStates.contains(jobState.get().knownEnum())) {
83+
System.out.println("Job state: " + jobState.get());
84+
tuningJob = client.tunings.get(jobName, GetTuningJobConfig.builder().build());
85+
jobState = tuningJob.state();
86+
TimeUnit.SECONDS.sleep(60);
87+
}
88+
89+
tuningJob.tunedModel().flatMap(TunedModel::model).ifPresent(System.out::println);
90+
tuningJob.tunedModel().flatMap(TunedModel::endpoint).ifPresent(System.out::println);
91+
tuningJob.experiment().ifPresent(System.out::println);
92+
// Example response:
93+
// projects/123456789012/locations/us-central1/models/6129850992130260992@1
94+
// projects/123456789012/locations/us-central1/endpoints/105055037499113472
95+
// projects/123456789012/locations/us-central1/metadataStores/default/contexts/experiment_id
96+
97+
List<TunedModelCheckpoint> checkpoints =
98+
tuningJob.tunedModel().flatMap(TunedModel::checkpoints).orElse(Collections.emptyList());
99+
100+
int index = 0;
101+
for (TunedModelCheckpoint checkpoint : checkpoints) {
102+
System.out.println("Checkpoint " + (++index));
103+
checkpoint
104+
.checkpointId()
105+
.ifPresent(checkpointId -> System.out.println("checkpointId=" + checkpointId));
106+
checkpoint.epoch().ifPresent(epoch -> System.out.println("epoch=" + epoch));
107+
checkpoint.step().ifPresent(step -> System.out.println("step=" + step));
108+
checkpoint.endpoint().ifPresent(endpoint -> System.out.println("endpoint=" + endpoint));
109+
}
110+
// Example response:
111+
// Checkpoint 1
112+
// checkpointId=1
113+
// epoch=2
114+
// step=34
115+
// endpoint=projects/project/locations/location/endpoints/105055037499113472
116+
// ...
117+
return jobName;
118+
}
119+
}
120+
}
121+
// [END googlegenaisdk_tuning_job_create]
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package genai.tuning;
18+
19+
// [START googlegenaisdk_tuning_job_get]
20+
21+
import com.google.genai.Client;
22+
import com.google.genai.types.GetTuningJobConfig;
23+
import com.google.genai.types.HttpOptions;
24+
import com.google.genai.types.TunedModel;
25+
import com.google.genai.types.TuningJob;
26+
import java.util.Optional;
27+
28+
public class TuningJobGet {
29+
30+
public static void main(String[] args) {
31+
// TODO(developer): Replace these variables before running the sample.
32+
// E.g. tuningJobName =
33+
// "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
34+
String tuningJobName = "your-job-name";
35+
getTuningJob(tuningJobName);
36+
}
37+
38+
// Shows how to get a tuning job
39+
public static Optional<String> getTuningJob(String tuningJobName) {
40+
// Client Initialization. Once created, it can be reused for multiple requests.
41+
try (Client client =
42+
Client.builder()
43+
.location("us-central1")
44+
.vertexAI(true)
45+
.httpOptions(HttpOptions.builder().apiVersion("v1").build())
46+
.build()) {
47+
48+
TuningJob tuningJob = client.tunings.get(tuningJobName, GetTuningJobConfig.builder().build());
49+
50+
tuningJob.tunedModel().flatMap(TunedModel::model).ifPresent(System.out::println);
51+
tuningJob.tunedModel().flatMap(TunedModel::endpoint).ifPresent(System.out::println);
52+
tuningJob.experiment().ifPresent(System.out::println);
53+
// Example response:
54+
// projects/123456789012/locations/us-central1/models/6129850992130260992@1
55+
// projects/123456789012/locations/us-central1/endpoints/105055037499113472
56+
// projects/123456789012/locations/us-central1/metadataStores/default/contexts/experiment_id
57+
return tuningJob.name();
58+
}
59+
}
60+
}
61+
// [END googlegenaisdk_tuning_job_get]
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package genai.tuning;
18+
19+
// [START googlegenaisdk_tuning_job_list]
20+
21+
import com.google.genai.Client;
22+
import com.google.genai.Pager;
23+
import com.google.genai.types.HttpOptions;
24+
import com.google.genai.types.ListTuningJobsConfig;
25+
import com.google.genai.types.TuningJob;
26+
27+
public class TuningJobList {
28+
29+
public static void main(String[] args) {
30+
listTuningJob();
31+
}
32+
33+
// Shows how to list the available tuning jobs
34+
public static Pager<TuningJob> listTuningJob() {
35+
// Client Initialization. Once created, it can be reused for multiple requests.
36+
try (Client client =
37+
Client.builder()
38+
.location("us-central1")
39+
.vertexAI(true)
40+
.httpOptions(HttpOptions.builder().apiVersion("v1").build())
41+
.build()) {
42+
43+
Pager<TuningJob> tuningJobs = client.tunings.list(ListTuningJobsConfig.builder().build());
44+
for (TuningJob job : tuningJobs) {
45+
job.name().ifPresent(System.out::println);
46+
// Example response:
47+
// projects/123456789012/locations/us-central1/tuningJobs/329583781566480384
48+
}
49+
50+
return tuningJobs;
51+
}
52+
}
53+
}
54+
// [END googlegenaisdk_tuning_job_list]
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package genai.tuning;
18+
19+
// [START googlegenaisdk_tuning_textgen_with_txt]
20+
21+
import com.google.genai.Client;
22+
import com.google.genai.types.GenerateContentConfig;
23+
import com.google.genai.types.GenerateContentResponse;
24+
import com.google.genai.types.GetTuningJobConfig;
25+
import com.google.genai.types.HttpOptions;
26+
import com.google.genai.types.TunedModel;
27+
import com.google.genai.types.TuningJob;
28+
29+
public class TuningTextGenWithTxt {
30+
31+
public static void main(String[] args) {
32+
// TODO(developer): Replace these variables before running the sample.
33+
// E.g. tuningJobName =
34+
// "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
35+
String tuningJobName = "your-job-name";
36+
predictWithTunedEndpoint(tuningJobName);
37+
}
38+
39+
// Shows how to predict with a tuned model endpoint
40+
public static String predictWithTunedEndpoint(String tuningJobName) {
41+
// Client Initialization. Once created, it can be reused for multiple requests.
42+
try (Client client =
43+
Client.builder()
44+
.location("us-central1")
45+
.vertexAI(true)
46+
.httpOptions(HttpOptions.builder().apiVersion("v1").build())
47+
.build()) {
48+
49+
TuningJob tuningJob = client.tunings.get(tuningJobName, GetTuningJobConfig.builder().build());
50+
51+
String endpoint =
52+
tuningJob
53+
.tunedModel()
54+
.flatMap(TunedModel::endpoint)
55+
.orElseThrow(() -> new IllegalStateException("Missing tuned model endpoint"));
56+
57+
GenerateContentResponse response =
58+
client.models.generateContent(
59+
endpoint, "Why is the sky blue?", GenerateContentConfig.builder().build());
60+
61+
System.out.println(response.text());
62+
// Example response:
63+
// The sky is blue because of a phenomenon called Rayleigh scattering...
64+
return response.text();
65+
}
66+
}
67+
}
68+
// [END googlegenaisdk_tuning_textgen_with_txt]

0 commit comments

Comments
 (0)