Skip to content

Commit

Permalink
Update ShipinKit
Browse files Browse the repository at this point in the history
  • Loading branch information
rudrankriyam committed Oct 14, 2024
1 parent 8d5493c commit 046942e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 53 deletions.
2 changes: 1 addition & 1 deletion Sources/ShipinKit/RunwayML/RunwayML.swift
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ public actor RunwayML {

return try await pollTaskStatus(id: taskID)
}

private func pollTaskStatus(id: String) async throws -> URL {
logger.debug("Starting to poll task status for ID: \(id)")

Expand Down
101 changes: 49 additions & 52 deletions Sources/ShipinKit/ShipinKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,79 +13,76 @@ public enum AIService {
case runwayML(apiKey: String)
}

/// ShipinKit: A unified interface for either LumaAI or RunwayML client.
public struct ShipinKit {
/// ShipinKit: A unified interface for any AI service client.
public actor ShipinKit {

/// The chosen AI service client.
private let service: Any

/// The type of AI service being used.
public let serviceType: AIService

private let service: AIService

/// Initializes a new instance of ShipinKit.
///
/// - Parameter service: The AI service to use, either LumaAI or RunwayML.
/// - Parameter service: The AI service to use. This can be any service that conforms to the `AIService` protocol.
///
/// - Returns: A new instance of ShipinKit.
public init(service: AIService) {
self.serviceType = service
switch service {
case .lumaAI(let apiKey):
self.service = LumaAI(apiKey: apiKey)
case .runwayML(let apiKey):
self.service = RunwayML(apiKey: apiKey)
}
self.service = service
}

/// Generates content using the chosen AI service.
///
/// This method adapts to the chosen AI service (LumaAI or RunwayML) and generates content based on the provided parameters.
///
/// - Parameters:
/// - prompt: The prompt for generation.
/// - aspectRatio: The aspect ratio of the generated content.
/// - loop: Whether the generated content should loop (only applicable for LumaAI).
/// - keyframes: A dictionary of keyframes (only applicable for LumaAI).
/// - image: A `ShipinImage` object representing the input image (only applicable for RunwayML).
/// - duration: The desired duration of the video (only applicable for RunwayML).
/// - watermark: A boolean indicating whether to include a watermark (only applicable for RunwayML).
/// - seed: An optional integer seed for reproducible results (only applicable for RunwayML).
/// - prompt: A `String` representing the prompt for generation.
/// - aspectRatio: A `String` specifying the aspect ratio of the generated content. Defaults to "16:9".
/// - loop: An optional `Bool` indicating whether the generated content should loop (only applicable for LumaAI).
/// - keyframes: An optional dictionary of keyframes (only applicable for LumaAI).
/// - image: An optional `@Sendable` closure returning a `ShipinImage` object representing the input image (only applicable for RunwayML).
/// - duration: An optional `RunwayMLVideoDuration` specifying the desired duration of the video (only applicable for RunwayML).
/// - watermark: An optional `Bool` indicating whether to include a watermark (only applicable for RunwayML).
/// - seed: An optional `Int` seed for reproducible results (only applicable for RunwayML).
///
/// - Returns: Either a `LumaAIGenerationResponse` or a `URL`, depending on the chosen service.
/// - Throws: An error if the generation fails or if incompatible parameters are provided for the chosen service.
/// - Throws: A `ShipinKitError` if the generation fails or if incompatible parameters are provided for the chosen service.
public func generate(
prompt: String,
aspectRatio: String = "16:9",
loop: Bool? = nil,
keyframes: [String: LumaAIKeyframeData]? = nil,
image: ShipinImage? = nil,
image: (@Sendable () -> ShipinImage)? = nil,
duration: RunwayMLVideoDuration? = nil,
watermark: Bool? = nil,
seed: Int? = nil
) async throws -> Any {
switch serviceType {
case .lumaAI:
guard let lumaAI = service as? LumaAI,
let loop = loop,
let keyframes = keyframes else {
throw ShipinKitError.invalidParameters
}
return try await lumaAI.createGeneration(
prompt: prompt,
aspectRatio: aspectRatio,
loop: loop,
keyframes: keyframes
)
case .runwayML:
guard let runwayML = service as? RunwayML,
let image = image else {
throw ShipinKitError.invalidParameters
}
return try await runwayML.generateVideo(
prompt: prompt,
image: image,
duration: duration ?? .short,
aspectRatio: RunwayMLAspectRatio(rawValue: aspectRatio) ?? .widescreen,
watermark: watermark ?? false,
seed: seed
)
switch service {
case .lumaAI(let apiKey):
let lumaAI = LumaAI(apiKey: apiKey)

guard let loop = loop, let keyframes = keyframes else {
throw ShipinKitError.invalidParameters
}

return try await lumaAI.createGeneration(
prompt: prompt,
aspectRatio: aspectRatio,
loop: loop,
keyframes: keyframes
)
case .runwayML(let apiKey):
let runwayML = RunwayML(apiKey: apiKey)

guard let imageClosure = image else {
throw ShipinKitError.invalidParameters
}

return try await runwayML.generateVideo(
prompt: prompt,
image: imageClosure(),
duration: duration ?? .short,
aspectRatio: RunwayMLAspectRatio(rawValue: aspectRatio) ?? .widescreen,
watermark: watermark ?? false,
seed: seed
)
}
}
}
Expand Down

0 comments on commit 046942e

Please sign in to comment.