Skip to content

Commit

Permalink
Add system instruction support (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Apr 10, 2024
1 parent 2063447 commit 3d26504
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Sources/GoogleAI/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ struct GenerateContentRequest {
let safetySettings: [SafetySetting]?
let tools: [Tool]?
let toolConfig: ToolConfig?
let systemInstruction: ModelContent?
let isStreaming: Bool
let options: RequestOptions
}
Expand All @@ -35,6 +36,7 @@ extension GenerateContentRequest: Encodable {
case safetySettings
case tools
case toolConfig
case systemInstruction
}
}

Expand Down
11 changes: 11 additions & 0 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ public final class GenerativeModel {
/// Tool configuration for any `Tool` specified in the request.
let toolConfig: ToolConfig?

/// Instructions that direct the model to behave a certain way.
let systemInstruction: ModelContent?

/// Configuration parameters for sending requests to the backend.
let requestOptions: RequestOptions

Expand All @@ -51,6 +54,8 @@ public final class GenerativeModel {
/// - generationConfig: The content generation parameters your model should use.
/// - safetySettings: A value describing what types of harmful content your model should allow.
/// - tools: A list of ``Tool`` objects that the model may use to generate the next response.
/// - systemInstruction: Instructions that direct the model to behave a certain way; currently
/// only text content is supported.
/// - toolConfig: Tool configuration for any `Tool` specified in the request.
/// - requestOptions Configuration parameters for sending requests to the backend.
public convenience init(name: String,
Expand All @@ -59,6 +64,7 @@ public final class GenerativeModel {
safetySettings: [SafetySetting]? = nil,
tools: [Tool]? = nil,
toolConfig: ToolConfig? = nil,
systemInstruction: ModelContent? = nil,
requestOptions: RequestOptions = RequestOptions()) {
self.init(
name: name,
Expand All @@ -67,6 +73,7 @@ public final class GenerativeModel {
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
requestOptions: requestOptions,
urlSession: .shared
)
Expand All @@ -79,6 +86,7 @@ public final class GenerativeModel {
safetySettings: [SafetySetting]? = nil,
tools: [Tool]? = nil,
toolConfig: ToolConfig? = nil,
systemInstruction: ModelContent? = nil,
requestOptions: RequestOptions = RequestOptions(),
urlSession: URLSession) {
modelResourceName = GenerativeModel.modelResourceName(name: name)
Expand All @@ -87,6 +95,7 @@ public final class GenerativeModel {
self.safetySettings = safetySettings
self.tools = tools
self.toolConfig = toolConfig
self.systemInstruction = systemInstruction
self.requestOptions = requestOptions

Logging.default.info("""
Expand Down Expand Up @@ -134,6 +143,7 @@ public final class GenerativeModel {
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
isStreaming: false,
options: requestOptions)
response = try await generativeAIService.loadRequest(request: generateContentRequest)
Expand Down Expand Up @@ -207,6 +217,7 @@ public final class GenerativeModel {
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
isStreaming: true,
options: requestOptions)

Expand Down
9 changes: 8 additions & 1 deletion Tests/GoogleAITests/GoogleAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,24 @@ final class GoogleGenerativeAITests: XCTestCase {
maxOutputTokens: 256,
stopSequences: ["..."])
let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]
let systemInstruction = ModelContent(role: "system", parts: [.text("Talk like a pirate.")])

// Permutations without optional arguments.
let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY")
let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY", safetySettings: filters)
let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY", generationConfig: config)
let _ = GenerativeModel(
name: "gemini-1.0-pro",
apiKey: "API_KEY",
systemInstruction: systemInstruction
)

// All arguments passed.
let genAI = GenerativeModel(name: "gemini-1.0-pro",
apiKey: "API_KEY",
generationConfig: config, // Optional
safetySettings: filters // Optional
safetySettings: filters, // Optional
systemInstruction: systemInstruction // Optional
)
// Full Typed Usage
let pngData = Data() // ....
Expand Down

0 comments on commit 3d26504

Please sign in to comment.