diff --git a/Sources/_BenchmarkUtilities/PirBenchmarkUtilities.swift b/Sources/_BenchmarkUtilities/PirBenchmarkUtilities.swift index 256e716f..1a0a82ab 100644 --- a/Sources/_BenchmarkUtilities/PirBenchmarkUtilities.swift +++ b/Sources/_BenchmarkUtilities/PirBenchmarkUtilities.swift @@ -115,9 +115,7 @@ public struct PirBenchmarkConfig { extension PrivateInformationRetrieval.Response { func scaledNoiseBudget(using secretKey: Scheme.SecretKey) throws -> Int { - try Int( - noiseBudget(using: secretKey, variableTime: true) * Double( - noiseBudgetScale)) + try Int(noiseBudget(using: secretKey, variableTime: true) * Double(noiseBudgetScale)) } } @@ -178,19 +176,17 @@ public func pirProcessBenchmark( struct IndexPirBenchmarkContext where Server.Scheme == Client.Scheme { + typealias Scheme = Server.Scheme let processedDatabase: Server.Database let server: Server let client: Client - let secretKey: SecretKey - let evaluationKey: Server.Scheme.EvaluationKey - let query: Client.Query + let context: Scheme.Context let evaluationKeySize: Int let evaluationKeyCount: Int let querySize: Int let queryCiphertextCount: Int let responseSize: Int let responseCiphertextCount: Int - let noiseBudget: Int init( server _: Server.Type, @@ -198,9 +194,8 @@ struct IndexPirBenchmarkContext pirConfig: IndexPirConfig, encryptionConfig: EncryptionParametersConfig) async throws { - let encryptParameter: EncryptionParameters = - try EncryptionParameters(from: encryptionConfig) - let context = try Server.Scheme.Context(encryptionParameters: encryptParameter) + let encryptParameter: EncryptionParameters = try EncryptionParameters(from: encryptionConfig) + self.context = try Scheme.Context(encryptionParameters: encryptParameter) let indexPirParameters = Server.generateParameter(config: pirConfig, with: context) let database = getDatabaseForTesting( numberOfEntries: pirConfig.entryCount, @@ -209,9 +204,8 @@ struct IndexPirBenchmarkContext self.server = try Server(parameter: indexPirParameters, context: context, database: processedDatabase) self.client = Client(parameter: indexPirParameters, context: context) - self.secretKey = try context.generateSecretKey() - self.evaluationKey = try client.generateEvaluationKey(using: secretKey) - self.query = try client.generateQuery(at: [0], using: secretKey) + let secretKey = try context.generateSecretKey() + let evaluationKey = try client.generateEvaluationKey(using: secretKey) // Validate correctness let queryIndex = Int.random(in: 0.. self.queryCiphertextCount = query.ciphertexts.count self.responseSize = try response.size() self.responseCiphertextCount = response.ciphertexts.count - self.noiseBudget = try response.scaledNoiseBudget(using: secretKey) } } @@ -238,6 +231,7 @@ public func indexPirBenchmark( // swiftlint:disable:next force_try config: PirBenchmarkConfig = try! .init()) -> () -> Void { + // swiftlint:disable:next closure_body_length { let benchmarkName = [ "IndexPir", @@ -251,27 +245,45 @@ public func indexPirBenchmark( Benchmark(benchmarkName, configuration: config.benchmarkConfig) { ( benchmark, benchmarkContext: IndexPirBenchmarkContext, MulPirClient>) in + let context = benchmarkContext.context for _ in benchmark.scaledIterations { - try await blackHole(benchmarkContext.server.computeResponse(to: benchmarkContext.query, - using: benchmarkContext - .evaluationKey)) + let secretKey = try context.generateSecretKey() + let evaluationKey = try benchmarkContext.client.generateEvaluationKey(using: secretKey) + let queryIndex = Int.random(in: 0.. = try serializedQuery.native(context: context) + let deserializedEvalKey: PirUtil.Scheme.EvaluationKey = try serializedEvaluationKey + .native(context: context) + let response = try await benchmarkContext.server.computeResponse( + to: deserializedQuery, + using: deserializedEvalKey) + try blackHole(response.proto()) + + benchmark.stopMeasurement() + + let noiseBudget = try response.scaledNoiseBudget(using: secretKey) + benchmark.measurement(.noiseBudget, noiseBudget) } + benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize) benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount) benchmark.measurement(.querySize, benchmarkContext.querySize) benchmark.measurement(.queryCiphertextCount, benchmarkContext.queryCiphertextCount) benchmark.measurement(.responseSize, benchmarkContext.responseSize) benchmark.measurement(.responseCiphertextCount, benchmarkContext.responseCiphertextCount) - benchmark.measurement(.noiseBudget, benchmarkContext.noiseBudget) - } - // swiftlint:enable closure_parameter_position - setup: { + } setup: { try await IndexPirBenchmarkContext( server: MulPirServer.self, client: MulPirClient.self, pirConfig: config.indexPirConfig, encryptionConfig: config.encryptionConfig) } + // swiftlint:enable closure_parameter_position } } @@ -280,23 +292,21 @@ struct KeywordPirBenchmarkContext typealias Client = KeywordPirClient + typealias Scheme = IndexServer.Scheme let server: Server let client: Client - let secretKey: SecretKey - let evaluationKey: Server.Scheme.EvaluationKey - let query: Client.Query + let context: Scheme.Context let evaluationKeySize: Int let evaluationKeyCount: Int let querySize: Int let queryCiphertextCount: Int let responseSize: Int let responseCiphertextCount: Int - let noiseBudget: Int - init(config: PirBenchmarkConfig) async throws { - let encryptParameter: EncryptionParameters = + init(config: PirBenchmarkConfig) async throws { + let encryptParameter: EncryptionParameters = try EncryptionParameters(from: config.encryptionConfig) - let context = try Server.Scheme.Context(encryptionParameters: encryptParameter) + self.context = try Server.Scheme.Context(encryptionParameters: encryptParameter) let rows = (0..( // swiftlint:disable:next force_try config: PirBenchmarkConfig = try! .init()) -> () -> Void { + // swiftlint:disable:next closure_body_length { let benchmarkName = [ "KeywordPir", @@ -380,10 +389,35 @@ public func keywordPirBenchmark( "entrySize=\(config.databaseConfig.entrySizeInBytes)", "keyCompression=\(config.keywordPirConfig.keyCompression)", ].joined(separator: "/") - Benchmark(benchmarkName, configuration: config.benchmarkConfig) { benchmark, benchmarkContext in + // swiftlint:disable closure_parameter_position + Benchmark(benchmarkName, configuration: config.benchmarkConfig) { ( + benchmark, + benchmarkContext: KeywordPirBenchmarkContext, MulPirClient>) in + let context = benchmarkContext.context for _ in benchmark.scaledIterations { - try await blackHole(benchmarkContext.server.computeResponse(to: benchmarkContext.query, - using: benchmarkContext.evaluationKey)) + let secretKey = try context.generateSecretKey() + let evaluationKey = try benchmarkContext.client.generateEvaluationKey(using: secretKey) + let queryIndex = Int.random(in: 0.. = try serializedQuery.native(context: context) + let deserializedEvalKey: PirUtil.Scheme.EvaluationKey = try serializedEvaluationKey + .native(context: context) + let response = try await benchmarkContext.server.computeResponse( + to: deserializedQuery, + using: deserializedEvalKey) + try blackHole(response.proto()) + + benchmark.stopMeasurement() + + let noiseBudget = try response.scaledNoiseBudget(using: secretKey) + benchmark.measurement(.noiseBudget, noiseBudget) } benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize) benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount) @@ -391,10 +425,10 @@ public func keywordPirBenchmark( benchmark.measurement(.queryCiphertextCount, benchmarkContext.queryCiphertextCount) benchmark.measurement(.responseSize, benchmarkContext.responseSize) benchmark.measurement(.responseCiphertextCount, benchmarkContext.responseCiphertextCount) - benchmark.measurement(.noiseBudget, benchmarkContext.noiseBudget) } setup: { try await KeywordPirBenchmarkContext, MulPirClient>( config: config) } + // swiftlint:enable closure_parameter_position } } diff --git a/Sources/_BenchmarkUtilities/PnnsBenchmarkUtilities.swift b/Sources/_BenchmarkUtilities/PnnsBenchmarkUtilities.swift index 3655b7a9..0c211c97 100644 --- a/Sources/_BenchmarkUtilities/PnnsBenchmarkUtilities.swift +++ b/Sources/_BenchmarkUtilities/PnnsBenchmarkUtilities.swift @@ -131,6 +131,7 @@ public func cosineSimilarityBenchmark(_: Scheme.Type, config: PnnsBenchmarkConfig = try! .init(), queryCount: Int = 1) -> () -> Void { + // swiftlint:disable:next closure_body_length { let benchmarkName = [ "CosineSimilarity", @@ -145,11 +146,32 @@ public func cosineSimilarityBenchmark(_: Scheme.Type, Benchmark(benchmarkName, configuration: config.benchmarkConfig) { ( benchmark, benchmarkContext: PnnsBenchmarkContext) in + let context = benchmarkContext.server.contexts[0] + let vectorDimension = benchmarkContext.server.config.vectorDimension for _ in benchmark.scaledIterations { - try await blackHole( - benchmarkContext.server.computeResponse( - to: benchmarkContext.query, - using: benchmarkContext.evaluationKey)) + let secretKey = try context.generateSecretKey() + let evaluationKey = try benchmarkContext.client.generateEvaluationKey(using: secretKey) + let serializedEvaluationKey = evaluationKey.serialize().proto() + let data = getDatabaseForTesting(config: PnnsDatabaseConfig( + rowCount: queryCount, + vectorDimension: vectorDimension)) + let queryVectors = Array2d(data: data.rows.map { row in row.vector }) + let query = try benchmarkContext.client.generateQuery(for: queryVectors, using: secretKey) + let serializedQuery = try query.proto() + + benchmark.startMeasurement() + + let deserializedEvalKey: EvaluationKey = try serializedEvaluationKey.native(context: context) + let deserializedQuery: Query = try serializedQuery.native(context: context) + let response = try await benchmarkContext.server.computeResponse( + to: deserializedQuery, + using: deserializedEvalKey) + try blackHole(response.proto()) + + benchmark.stopMeasurement() + + let noiseBudget = try response.scaledNoiseBudget(using: secretKey) + benchmark.measurement(.noiseBudget, noiseBudget) } benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize) benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount) @@ -157,7 +179,6 @@ public func cosineSimilarityBenchmark(_: Scheme.Type, benchmark.measurement(.queryCiphertextCount, benchmarkContext.queryCiphertextCount) benchmark.measurement(.responseSize, benchmarkContext.responseSize) benchmark.measurement(.responseCiphertextCount, benchmarkContext.responseCiphertextCount) - benchmark.measurement(.noiseBudget, benchmarkContext.noiseBudget) } setup: { try await PnnsBenchmarkContext( databaseConfig: config.databaseConfig, @@ -236,16 +257,13 @@ struct PnnsBenchmarkContext { let processedDatabase: ProcessedDatabase let server: Server let client: Client - let secretKey: SecretKey - let evaluationKey: Scheme.EvaluationKey + let contexts: [Scheme.Context] let evaluationKeyCount: Int - let query: Query let evaluationKeySize: Int let querySize: Int let queryCiphertextCount: Int let responseSize: Int let responseCiphertextCount: Int - let noiseBudget: Int init(databaseConfig: PnnsDatabaseConfig, encryptionConfig: EncryptionParametersConfig, @@ -293,18 +311,18 @@ struct PnnsBenchmarkContext { databasePacking: .diagonal(babyStepGiantStep: babyStepGiantStep)) let database = getDatabaseForTesting(config: databaseConfig) - let contexts = try clientConfig.encryptionParameters + self.contexts = try clientConfig.encryptionParameters .map { encryptionParameters in try Scheme.Context(encryptionParameters: encryptionParameters) } self.processedDatabase = try await database.process(config: serverConfig, contexts: contexts) self.client = try Client(config: clientConfig, contexts: contexts) self.server = try Server(database: processedDatabase) - self.secretKey = try client.generateSecretKey() - self.evaluationKey = try client.generateEvaluationKey(using: secretKey) + let secretKey = try client.generateSecretKey() + let evaluationKey = try client.generateEvaluationKey(using: secretKey) // We query exact matches from rows in the database let databaseVectors = Array2d(data: database.rows.map { row in row.vector }) let queryVectors = Array2d(data: database.rows.prefix(queryCount).map { row in row.vector }) - self.query = try client.generateQuery(for: queryVectors, using: secretKey) + let query = try client.generateQuery(for: queryVectors, using: secretKey) let response = try await server.computeResponse(to: query, using: evaluationKey) let decrypted = try client.decrypt(response: response, using: secretKey) @@ -324,6 +342,5 @@ struct PnnsBenchmarkContext { self.responseSize = try response.size() self.responseCiphertextCount = response.ciphertextMatrices .map { matrix in matrix.ciphertexts.count }.sum() - self.noiseBudget = try response.scaledNoiseBudget(using: secretKey) } }