Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 70 additions & 36 deletions Sources/_BenchmarkUtilities/PirBenchmarkUtilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ public struct PirBenchmarkConfig<Scalar: ScalarType> {

extension PrivateInformationRetrieval.Response {
func scaledNoiseBudget(using secretKey: Scheme.SecretKey) throws -> Int {
try Int(
noiseBudget(using: secretKey, variableTime: true) * Double(
noiseBudgetScale))
try Int(noiseBudget(using: secretKey, variableTime: true) * Double(noiseBudgetScale))
}
}

Expand Down Expand Up @@ -178,29 +176,26 @@ public func pirProcessBenchmark<PirUtil: PirUtilProtocol>(
struct IndexPirBenchmarkContext<Server: IndexPirServer, Client: IndexPirClient>
where Server.Scheme == Client.Scheme
{
typealias Scheme = Server.Scheme
let processedDatabase: Server.Database
let server: Server
let client: Client
let secretKey: SecretKey<Client.Scheme>
let evaluationKey: Server.Scheme.EvaluationKey
let query: Client.Query
let context: Scheme.Context
let evaluationKeySize: Int
let evaluationKeyCount: Int
let querySize: Int
let queryCiphertextCount: Int
let responseSize: Int
let responseCiphertextCount: Int
let noiseBudget: Int

init(
server _: Server.Type,
client _: Client.Type,
pirConfig: IndexPirConfig,
encryptionConfig: EncryptionParametersConfig) async throws
{
let encryptParameter: EncryptionParameters<Server.Scheme.Scalar> =
try EncryptionParameters(from: encryptionConfig)
let context = try Server.Scheme.Context(encryptionParameters: encryptParameter)
let encryptParameter: EncryptionParameters<Scheme.Scalar> = try EncryptionParameters(from: encryptionConfig)
self.context = try Scheme.Context(encryptionParameters: encryptParameter)
let indexPirParameters = Server.generateParameter(config: pirConfig, with: context)
let database = getDatabaseForTesting(
numberOfEntries: pirConfig.entryCount,
Expand All @@ -209,9 +204,8 @@ struct IndexPirBenchmarkContext<Server: IndexPirServer, Client: IndexPirClient>

self.server = try Server(parameter: indexPirParameters, context: context, database: processedDatabase)
self.client = Client(parameter: indexPirParameters, context: context)
self.secretKey = try context.generateSecretKey()
self.evaluationKey = try client.generateEvaluationKey(using: secretKey)
self.query = try client.generateQuery(at: [0], using: secretKey)
let secretKey = try context.generateSecretKey()
let evaluationKey = try client.generateEvaluationKey(using: secretKey)

// Validate correctness
let queryIndex = Int.random(in: 0..<pirConfig.entryCount)
Expand All @@ -228,7 +222,6 @@ struct IndexPirBenchmarkContext<Server: IndexPirServer, Client: IndexPirClient>
self.queryCiphertextCount = query.ciphertexts.count
self.responseSize = try response.size()
self.responseCiphertextCount = response.ciphertexts.count
self.noiseBudget = try response.scaledNoiseBudget(using: secretKey)
}
}

Expand All @@ -238,6 +231,7 @@ public func indexPirBenchmark<PirUtil: PirUtilProtocol>(
// swiftlint:disable:next force_try
config: PirBenchmarkConfig<PirUtil.Scheme.Scalar> = try! .init()) -> () -> Void
{
// swiftlint:disable:next closure_body_length
{
let benchmarkName = [
"IndexPir",
Expand All @@ -251,27 +245,45 @@ public func indexPirBenchmark<PirUtil: PirUtilProtocol>(
Benchmark(benchmarkName, configuration: config.benchmarkConfig) { (
benchmark,
benchmarkContext: IndexPirBenchmarkContext<MulPirServer<PirUtil>, MulPirClient<PirUtil>>) in
let context = benchmarkContext.context
for _ in benchmark.scaledIterations {
try await blackHole(benchmarkContext.server.computeResponse(to: benchmarkContext.query,
using: benchmarkContext
.evaluationKey))
let secretKey = try context.generateSecretKey()
let evaluationKey = try benchmarkContext.client.generateEvaluationKey(using: secretKey)
let queryIndex = Int.random(in: 0..<benchmarkContext.server.parameter.entryCount)
let query = try benchmarkContext.client.generateQuery(at: [queryIndex], using: secretKey)
let serializedQuery = try query.proto()
let serializedEvaluationKey = evaluationKey.serialize().proto()

benchmark.startMeasurement()

let deserializedQuery: Query<PirUtil.Scheme> = try serializedQuery.native(context: context)
let deserializedEvalKey: PirUtil.Scheme.EvaluationKey = try serializedEvaluationKey
.native(context: context)
let response = try await benchmarkContext.server.computeResponse(
to: deserializedQuery,
using: deserializedEvalKey)
try blackHole(response.proto())

benchmark.stopMeasurement()

let noiseBudget = try response.scaledNoiseBudget(using: secretKey)
benchmark.measurement(.noiseBudget, noiseBudget)
}

benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize)
benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount)
benchmark.measurement(.querySize, benchmarkContext.querySize)
benchmark.measurement(.queryCiphertextCount, benchmarkContext.queryCiphertextCount)
benchmark.measurement(.responseSize, benchmarkContext.responseSize)
benchmark.measurement(.responseCiphertextCount, benchmarkContext.responseCiphertextCount)
benchmark.measurement(.noiseBudget, benchmarkContext.noiseBudget)
}
// swiftlint:enable closure_parameter_position
setup: {
} setup: {
try await IndexPirBenchmarkContext(
server: MulPirServer<PirUtil>.self,
client: MulPirClient<PirUtil>.self,
pirConfig: config.indexPirConfig,
encryptionConfig: config.encryptionConfig)
}
// swiftlint:enable closure_parameter_position
}
}

Expand All @@ -280,23 +292,21 @@ struct KeywordPirBenchmarkContext<IndexServer: IndexPirServer, IndexClient: Inde
{
typealias Server = KeywordPirServer<IndexServer>
typealias Client = KeywordPirClient<IndexClient>
typealias Scheme = IndexServer.Scheme
let server: Server
let client: Client
let secretKey: SecretKey<Client.Scheme>
let evaluationKey: Server.Scheme.EvaluationKey
let query: Client.Query
let context: Scheme.Context
let evaluationKeySize: Int
let evaluationKeyCount: Int
let querySize: Int
let queryCiphertextCount: Int
let responseSize: Int
let responseCiphertextCount: Int
let noiseBudget: Int

init(config: PirBenchmarkConfig<Server.Scheme.Scalar>) async throws {
let encryptParameter: EncryptionParameters<Server.Scheme.Scalar> =
init(config: PirBenchmarkConfig<Scheme.Scalar>) async throws {
let encryptParameter: EncryptionParameters<Scheme.Scalar> =
try EncryptionParameters(from: config.encryptionConfig)
let context = try Server.Scheme.Context(encryptionParameters: encryptParameter)
self.context = try Server.Scheme.Context(encryptionParameters: encryptParameter)
let rows = (0..<config.databaseConfig.entryCount).map { index in KeywordValuePair(
keyword: [UInt8](String(index).utf8),
value: (0..<config.databaseConfig.entrySizeInBytes).map { _ in UInt8.random(in: 0..<UInt8.max) })
Expand Down Expand Up @@ -336,9 +346,8 @@ struct KeywordPirBenchmarkContext<IndexServer: IndexPirServer, IndexClient: Inde
keywordParameter: keywordPirConfig.parameter,
pirParameter: processed.pirParameter,
context: context)
self.secretKey = try context.generateSecretKey()
self.evaluationKey = try client.generateEvaluationKey(using: secretKey)
self.query = try client.generateQuery(at: [UInt8]("0".utf8), using: secretKey)
let secretKey = try context.generateSecretKey()
let evaluationKey = try client.generateEvaluationKey(using: secretKey)

// Validate correctness
let queryIndex = Int.random(in: 0..<config.databaseConfig.entryCount)
Expand All @@ -361,7 +370,6 @@ struct KeywordPirBenchmarkContext<IndexServer: IndexPirServer, IndexClient: Inde
self.queryCiphertextCount = query.ciphertexts.count
self.responseSize = try response.size()
self.responseCiphertextCount = response.ciphertexts.count
self.noiseBudget = try response.scaledNoiseBudget(using: secretKey)
}
}

Expand All @@ -371,6 +379,7 @@ public func keywordPirBenchmark<PirUtil: PirUtilProtocol>(
// swiftlint:disable:next force_try
config: PirBenchmarkConfig<PirUtil.Scheme.Scalar> = try! .init()) -> () -> Void
{
// swiftlint:disable:next closure_body_length
{
let benchmarkName = [
"KeywordPir",
Expand All @@ -380,21 +389,46 @@ public func keywordPirBenchmark<PirUtil: PirUtilProtocol>(
"entrySize=\(config.databaseConfig.entrySizeInBytes)",
"keyCompression=\(config.keywordPirConfig.keyCompression)",
].joined(separator: "/")
Benchmark(benchmarkName, configuration: config.benchmarkConfig) { benchmark, benchmarkContext in
// swiftlint:disable closure_parameter_position
Benchmark(benchmarkName, configuration: config.benchmarkConfig) { (
benchmark,
benchmarkContext: KeywordPirBenchmarkContext<MulPirServer<PirUtil>, MulPirClient<PirUtil>>) in
let context = benchmarkContext.context
for _ in benchmark.scaledIterations {
try await blackHole(benchmarkContext.server.computeResponse(to: benchmarkContext.query,
using: benchmarkContext.evaluationKey))
let secretKey = try context.generateSecretKey()
let evaluationKey = try benchmarkContext.client.generateEvaluationKey(using: secretKey)
let queryIndex = Int.random(in: 0..<config.databaseConfig.entryCount)
let query = try benchmarkContext.client.generateQuery(
at: [UInt8](String(describing: queryIndex).utf8),
using: secretKey)
let serializedQuery = try query.proto()
let serializedEvaluationKey = evaluationKey.serialize().proto()

benchmark.startMeasurement()

let deserializedQuery: Query<PirUtil.Scheme> = try serializedQuery.native(context: context)
let deserializedEvalKey: PirUtil.Scheme.EvaluationKey = try serializedEvaluationKey
.native(context: context)
let response = try await benchmarkContext.server.computeResponse(
to: deserializedQuery,
using: deserializedEvalKey)
try blackHole(response.proto())

benchmark.stopMeasurement()

let noiseBudget = try response.scaledNoiseBudget(using: secretKey)
benchmark.measurement(.noiseBudget, noiseBudget)
}
benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize)
benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount)
benchmark.measurement(.querySize, benchmarkContext.querySize)
benchmark.measurement(.queryCiphertextCount, benchmarkContext.queryCiphertextCount)
benchmark.measurement(.responseSize, benchmarkContext.responseSize)
benchmark.measurement(.responseCiphertextCount, benchmarkContext.responseCiphertextCount)
benchmark.measurement(.noiseBudget, benchmarkContext.noiseBudget)
} setup: {
try await KeywordPirBenchmarkContext<MulPirServer<PirUtil>, MulPirClient<PirUtil>>(
config: config)
}
// swiftlint:enable closure_parameter_position
}
}
45 changes: 31 additions & 14 deletions Sources/_BenchmarkUtilities/PnnsBenchmarkUtilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ public func cosineSimilarityBenchmark<Scheme: HeScheme>(_: Scheme.Type,
config: PnnsBenchmarkConfig = try! .init(),
queryCount: Int = 1) -> () -> Void
{
// swiftlint:disable:next closure_body_length
{
let benchmarkName = [
"CosineSimilarity",
Expand All @@ -145,19 +146,39 @@ public func cosineSimilarityBenchmark<Scheme: HeScheme>(_: Scheme.Type,
Benchmark(benchmarkName, configuration: config.benchmarkConfig) { (
benchmark,
benchmarkContext: PnnsBenchmarkContext<Scheme>) in
let context = benchmarkContext.server.contexts[0]
let vectorDimension = benchmarkContext.server.config.vectorDimension
for _ in benchmark.scaledIterations {
try await blackHole(
benchmarkContext.server.computeResponse(
to: benchmarkContext.query,
using: benchmarkContext.evaluationKey))
let secretKey = try context.generateSecretKey()
let evaluationKey = try benchmarkContext.client.generateEvaluationKey(using: secretKey)
let serializedEvaluationKey = evaluationKey.serialize().proto()
let data = getDatabaseForTesting(config: PnnsDatabaseConfig(
rowCount: queryCount,
vectorDimension: vectorDimension))
let queryVectors = Array2d(data: data.rows.map { row in row.vector })
let query = try benchmarkContext.client.generateQuery(for: queryVectors, using: secretKey)
let serializedQuery = try query.proto()

benchmark.startMeasurement()

let deserializedEvalKey: EvaluationKey<Scheme> = try serializedEvaluationKey.native(context: context)
let deserializedQuery: Query<Scheme> = try serializedQuery.native(context: context)
let response = try await benchmarkContext.server.computeResponse(
to: deserializedQuery,
using: deserializedEvalKey)
try blackHole(response.proto())

benchmark.stopMeasurement()

let noiseBudget = try response.scaledNoiseBudget(using: secretKey)
benchmark.measurement(.noiseBudget, noiseBudget)
}
benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize)
benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount)
benchmark.measurement(.querySize, benchmarkContext.querySize)
benchmark.measurement(.queryCiphertextCount, benchmarkContext.queryCiphertextCount)
benchmark.measurement(.responseSize, benchmarkContext.responseSize)
benchmark.measurement(.responseCiphertextCount, benchmarkContext.responseCiphertextCount)
benchmark.measurement(.noiseBudget, benchmarkContext.noiseBudget)
} setup: {
try await PnnsBenchmarkContext<Scheme>(
databaseConfig: config.databaseConfig,
Expand Down Expand Up @@ -236,16 +257,13 @@ struct PnnsBenchmarkContext<Scheme: HeScheme> {
let processedDatabase: ProcessedDatabase<Scheme>
let server: Server<Scheme>
let client: Client<Scheme>
let secretKey: SecretKey<Scheme>
let evaluationKey: Scheme.EvaluationKey
let contexts: [Scheme.Context]
let evaluationKeyCount: Int
let query: Query<Scheme>
let evaluationKeySize: Int
let querySize: Int
let queryCiphertextCount: Int
let responseSize: Int
let responseCiphertextCount: Int
let noiseBudget: Int

init(databaseConfig: PnnsDatabaseConfig,
encryptionConfig: EncryptionParametersConfig,
Expand Down Expand Up @@ -293,18 +311,18 @@ struct PnnsBenchmarkContext<Scheme: HeScheme> {
databasePacking: .diagonal(babyStepGiantStep: babyStepGiantStep))

let database = getDatabaseForTesting(config: databaseConfig)
let contexts = try clientConfig.encryptionParameters
self.contexts = try clientConfig.encryptionParameters
.map { encryptionParameters in try Scheme.Context(encryptionParameters: encryptionParameters) }
self.processedDatabase = try await database.process(config: serverConfig, contexts: contexts)
self.client = try Client(config: clientConfig, contexts: contexts)
self.server = try Server(database: processedDatabase)
self.secretKey = try client.generateSecretKey()
self.evaluationKey = try client.generateEvaluationKey(using: secretKey)
let secretKey = try client.generateSecretKey()
let evaluationKey = try client.generateEvaluationKey(using: secretKey)

// We query exact matches from rows in the database
let databaseVectors = Array2d(data: database.rows.map { row in row.vector })
let queryVectors = Array2d(data: database.rows.prefix(queryCount).map { row in row.vector })
self.query = try client.generateQuery(for: queryVectors, using: secretKey)
let query = try client.generateQuery(for: queryVectors, using: secretKey)

let response = try await server.computeResponse(to: query, using: evaluationKey)
let decrypted = try client.decrypt(response: response, using: secretKey)
Expand All @@ -324,6 +342,5 @@ struct PnnsBenchmarkContext<Scheme: HeScheme> {
self.responseSize = try response.size()
self.responseCiphertextCount = response.ciphertextMatrices
.map { matrix in matrix.ciphertexts.count }.sum()
self.noiseBudget = try response.scaledNoiseBudget(using: secretKey)
}
}