Skip to content

Commit

Permalink
Add creation with update methods that polls the state
Browse files Browse the repository at this point in the history
  • Loading branch information
rudrankriyam committed Oct 13, 2024
1 parent 73d9d94 commit ec0363d
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 7 deletions.
81 changes: 80 additions & 1 deletion Sources/ShipinKit/LumaAI/LumaAIClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
//

import Foundation
#if canImport(BackgroundTasks)
import BackgroundTasks
#endif

/// A client for interacting with the Luma AI API.
public class LumaAIClient {
actor LumaAIClient {
private let apiKey: String
private let session: URLSession
private let baseURL = URL(string: "https://api.lumalabs.ai")!

private var generationTasks: [String: Task<Void, Error>] = [:]

/// Initializes a new instance of `LumaAIClient`
///
/// - Parameters:
Expand Down Expand Up @@ -153,4 +158,78 @@ public class LumaAIClient {
throw LumaAIError.httpError(statusCode: httpResponse.statusCode)
}
}

/// Retrieves a list of supported camera motions from the Luma AI API.
///
/// - Returns: An array of strings representing supported camera motions.
///
/// - Throws: `LumaAIError.httpError` if the API request fails, or `LumaAIError.decodingError` if the response cannot be decoded.
public func listCameraMotions() async throws -> [String] {
let url = baseURL.appendingPathComponent("/dream-machine/v1/generations/camera_motion/list")
var request = URLRequest(url: url)
request.httpMethod = "GET"
request.addValue("application/json", forHTTPHeaderField: "accept")
request.addValue("Bearer \(apiKey)", forHTTPHeaderField: "authorization")

let (data, response) = try await session.data(for: request)

if let httpResponse = response as? HTTPURLResponse, !(200...299).contains(httpResponse.statusCode) {
throw LumaAIError.httpError(statusCode: httpResponse.statusCode)
}

let decoder = JSONDecoder()
do {
let cameraMotions = try decoder.decode([String].self, from: data)
return cameraMotions
} catch {
throw LumaAIError.decodingError(underlying: error)
}
}

public func createGenerationWithUpdates(prompt: String, aspectRatio: String = "16:9", loop: Bool, keyframes: [String: LumaAIKeyframeData]) async throws {
let initialResponse = try await createGeneration(prompt: prompt, aspectRatio: aspectRatio, loop: loop, keyframes: keyframes)

let task = Task<Void, Error> {
var currentResponse = initialResponse

while currentResponse.state != "completed" && currentResponse.state != "failed" {
try await Task.sleep(for: .seconds(5))
currentResponse = try await self.checkGenerationStatus(id: currentResponse.id)
}
}

self.generationTasks[initialResponse.id] = task

do {
try await task.value
} catch {
self.generationTasks.removeValue(forKey: initialResponse.id)
throw error
}

self.generationTasks.removeValue(forKey: initialResponse.id)
}

private func checkGenerationStatus(id: String) async throws -> LumaAIGenerationResponse {
let url = baseURL.appendingPathComponent("/dream-machine/v1/generations/\(id)")
var request = URLRequest(url: url)
request.httpMethod = "GET"
request.addValue("application/json", forHTTPHeaderField: "accept")
request.addValue("Bearer \(apiKey)", forHTTPHeaderField: "authorization")

let (data, response) = try await session.data(for: request)

guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else {
throw LumaAIError.invalidResponse
}

let decoder = JSONDecoder()
decoder.keyDecodingStrategy = .convertFromSnakeCase
return try decoder.decode(LumaAIGenerationResponse.self, from: data)
}

public func cancelGenerationUpdates(id: String) {
generationTasks[id]?.cancel()
generationTasks.removeValue(forKey: id)
}
}
7 changes: 6 additions & 1 deletion Sources/ShipinKit/LumaAI/LumaAIError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@
// Created by Rudrank Riyam on 10/13/24.
//


import Foundation

/// An error type representing errors from the Luma AI client.
public enum LumaAIError: Error {

/// An HTTP error with a status code.
case httpError(statusCode: Int)

/// A decoding error occurred.
case decodingError(underlying: Error)

/// An invalid response from the Luma AI server.
case invalidResponse
}
10 changes: 5 additions & 5 deletions Sources/ShipinKit/LumaAI/LumaAIGenerationResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import Foundation

/// Represents the response from the Luma AI generation API.
public struct LumaAIGenerationResponse: Codable {
public struct LumaAIGenerationResponse: Codable, Sendable {
public let id: String
public let state: String
public let failureReason: String?
Expand All @@ -29,12 +29,12 @@ public struct LumaAIGenerationResponse: Codable {
}

/// Contains the assets returned by the Luma AI generation API.
public struct LumaAIAssets: Codable {
public struct LumaAIAssets: Codable, Sendable {
public let video: String
}

/// Represents the original request sent to the Luma AI generation API.
public struct LumaAIGenerationRequest: Codable {
public struct LumaAIGenerationRequest: Codable, Sendable {
public let prompt: String
public let aspectRatio: String
public let loop: Bool
Expand All @@ -51,14 +51,14 @@ public struct LumaAIGenerationRequest: Codable {
}

/// Represents keyframe data in the generation request.
public struct LumaAIKeyframeData: Codable {
public struct LumaAIKeyframeData: Codable, Sendable {
public let type: LumaAIKeyframeType
public let url: String?
public let id: String?
}

/// Represents the type of keyframe in the generation request.
public enum LumaAIKeyframeType: String, Codable {
public enum LumaAIKeyframeType: String, Codable, Sendable {
case generation
case image
}

0 comments on commit ec0363d

Please sign in to comment.