diff --git a/.swiftlint.yml b/.swiftlint.yml index e55e445c..d703b9f7 100644 --- a/.swiftlint.yml +++ b/.swiftlint.yml @@ -17,6 +17,9 @@ excluded: - "**/DoubleWidthUInt.swift" - "**/.build/" +type_name: + allowed_symbols: "_" + line_length: warning: 120 ignores_urls: true diff --git a/Benchmarks/PrivateInformationRetrievalBenchmark/PirBenchmark.swift b/Benchmarks/PrivateInformationRetrievalBenchmark/PirBenchmark.swift index 9d4ab66b..840c202f 100644 --- a/Benchmarks/PrivateInformationRetrievalBenchmark/PirBenchmark.swift +++ b/Benchmarks/PrivateInformationRetrievalBenchmark/PirBenchmark.swift @@ -18,14 +18,15 @@ import _BenchmarkUtilities import HomomorphicEncryption +import PrivateInformationRetrieval nonisolated(unsafe) let benchmarks: () -> Void = { - pirProcessBenchmark(Bfv.self)() - pirProcessBenchmark(Bfv.self)() + pirProcessBenchmark(PirUtil>.self)() + pirProcessBenchmark(PirUtil>.self)() - indexPirBenchmark(Bfv.self)() - indexPirBenchmark(Bfv.self)() + indexPirBenchmark(PirUtil>.self)() + indexPirBenchmark(PirUtil>.self)() - keywordPirBenchmark(Bfv.self)() - keywordPirBenchmark(Bfv.self)() + keywordPirBenchmark(PirUtil>.self)() + keywordPirBenchmark(PirUtil>.self)() } diff --git a/Benchmarks/RlweBenchmark/RlweBenchmark.swift b/Benchmarks/RlweBenchmark/RlweBenchmark.swift index c5fb2294..770943f6 100644 --- a/Benchmarks/RlweBenchmark/RlweBenchmark.swift +++ b/Benchmarks/RlweBenchmark/RlweBenchmark.swift @@ -43,7 +43,7 @@ func getRandomPlaintextData(count: Int, in range: Range) -> [T struct RlweBenchmarkContext: Sendable { var encryptionParameters: EncryptionParameters - var context: Context + var context: Scheme.Context let data: [Scheme.Scalar] let signedData: [Scheme.SignedScalar] @@ -69,7 +69,7 @@ struct RlweBenchmarkContext: Sendable { Scheme.Scalar.self), errorStdDev: ErrorStdDev.stdDev32, securityLevel: SecurityLevel.quantum128) - self.context = try Context(encryptionParameters: encryptionParameters) + self.context = try Scheme.Context(encryptionParameters: encryptionParameters) self.secretKey = try context.generateSecretKey() let columnElement = GaloisElement.swappingRows(degree: polyDegree) let rowElement = try GaloisElement.rotatingColumns(by: rotateColumnsStep, degree: polyDegree) diff --git a/Package.resolved b/Package.resolved index 3e153878..fab18367 100644 --- a/Package.resolved +++ b/Package.resolved @@ -28,6 +28,24 @@ "version" : "1.4.0" } }, + { + "identity" : "swift-async-algorithms", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-async-algorithms.git", + "state" : { + "revision" : "042e1c4d9d19748c9c228f8d4ebc97bb1e339b0b", + "version" : "1.0.4" + } + }, + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "8c0c0a8b49e080e54e5e328cc552821ff07cd341", + "version" : "1.2.1" + } + }, { "identity" : "swift-crypto", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index af6b28a4..2cb9c04c 100644 --- a/Package.swift +++ b/Package.swift @@ -27,6 +27,22 @@ let executableSettings: [SwiftSetting] = let benchmarkSettings: [SwiftSetting] = [.unsafeFlags(["-cross-module-optimization"], .when(configuration: .release))] +let enableFlags = "SWIFT_HOMOMORPHIC_ENCRYPTION_MODULAR_ARITHMETIC_EXTRA_SWIFT_FLAGS" +func shouldEnableFlags() -> Bool { + if let flag = ProcessInfo.processInfo.environment[enableFlags], flag != "0", flag != "false" { + return true + } + return false +} + +var flags: [SwiftSetting] = [] +let enableFlagsBool = shouldEnableFlags() +if enableFlagsBool { + print("Building with additional flags. To disable, unset \(enableFlags) in your environment.") + let flagsAsString = (ProcessInfo.processInfo.environment[enableFlags] ?? "") as String + flags += [.unsafeFlags(flagsAsString.components(separatedBy: ","))] +} + let package = Package( name: "swift-homomorphic-encryption", products: [ @@ -60,6 +76,7 @@ let package = Package( dependencies: [ .package(url: "https://github.com/apple/swift-algorithms", from: "1.2.0"), .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.0"), + .package(url: "https://github.com/apple/swift-async-algorithms.git", from: "1.0.2"), .package(url: "https://github.com/apple/swift-crypto.git", from: "3.10.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-numerics", from: "1.0.0"), @@ -71,7 +88,7 @@ let package = Package( .target( name: "ModularArithmetic", dependencies: [], - swiftSettings: librarySettings), + swiftSettings: librarySettings + flags), .target( name: "CUtil", dependencies: [], @@ -100,12 +117,14 @@ let package = Package( .target( name: "PrivateInformationRetrieval", dependencies: ["HomomorphicEncryption", + .product(name: "AsyncAlgorithms", package: "swift-async-algorithms"), .product(name: "Numerics", package: "swift-numerics")], swiftSettings: librarySettings), .target( name: "PrivateNearestNeighborSearch", dependencies: [ .product(name: "Algorithms", package: "swift-algorithms"), + .product(name: "AsyncAlgorithms", package: "swift-async-algorithms"), "HomomorphicEncryption", "_HomomorphicEncryptionExtras", ], diff --git a/Snippets/HomomorphicEncryption/EncryptionParametersSnippet.swift b/Snippets/HomomorphicEncryption/EncryptionParametersSnippet.swift index ef12c2f7..0af120a2 100644 --- a/Snippets/HomomorphicEncryption/EncryptionParametersSnippet.swift +++ b/Snippets/HomomorphicEncryption/EncryptionParametersSnippet.swift @@ -85,7 +85,7 @@ func summarize( parameters: EncryptionParameters, _: Scheme.Type) throws { let values = (0..<8).map { Scheme.Scalar($0) } - let context = try Context(encryptionParameters: parameters) + let context = try Scheme.Context(encryptionParameters: parameters) let plaintext: Scheme.CoeffPlaintext = try context.encode( values: values, format: .coefficient) diff --git a/Sources/ApplicationProtobuf/PirConversion.swift b/Sources/ApplicationProtobuf/PirConversion.swift index 86c3277f..8410a0f5 100644 --- a/Sources/ApplicationProtobuf/PirConversion.swift +++ b/Sources/ApplicationProtobuf/PirConversion.swift @@ -24,7 +24,7 @@ extension Apple_SwiftHomomorphicEncryption_Pir_V1_EncryptedIndices { /// - Parameter context: Context to associate with the native type. /// - Returns: The converted native type. /// - Throws: Error upon invalid protobuf object. - public func native(context: Context) throws -> Query { + public func native(context: Scheme.Context) throws -> Query { let ciphertexts: [Scheme.CanonicalCiphertext] = try ciphertexts.map { ciphertext in let serializedCiphertext: SerializedCiphertext = try ciphertext.native() return try Ciphertext( @@ -54,7 +54,7 @@ extension ProcessedDatabaseWithParameters { /// - Parameter context: The context that was used to create processed database. /// - Returns: The PIR parameters protobuf object. /// - Throws: Error when the parameters cannot be represented as a protobuf object. - public func proto(context: Context) throws -> Apple_SwiftHomomorphicEncryption_Pir_V1_PirParameters { + public func proto(context: Scheme.Context) throws -> Apple_SwiftHomomorphicEncryption_Pir_V1_PirParameters { let encryptionParameters = context.encryptionParameters return try Apple_SwiftHomomorphicEncryption_Pir_V1_PirParameters.with { params in params.encryptionParameters = try encryptionParameters.proto(scheme: Scheme.self) diff --git a/Sources/ApplicationProtobuf/PirConversionApi.swift b/Sources/ApplicationProtobuf/PirConversionApi.swift index f9b00b8a..2ceb6d53 100644 --- a/Sources/ApplicationProtobuf/PirConversionApi.swift +++ b/Sources/ApplicationProtobuf/PirConversionApi.swift @@ -21,7 +21,7 @@ extension Apple_SwiftHomomorphicEncryption_Api_Pir_V1_PIRResponse { /// - Parameter context: Context to associate with the native type. /// - Returns: The converted native type. /// - Throws: Error upon invalid protobuf object. - public func native(context: Context) throws -> Response { + public func native(context: Scheme.Context) throws -> Response { let ciphertexts: [[Scheme.CoeffCiphertext]] = try replies.map { reply in let serializedCiphertexts: [SerializedCiphertext] = try reply.native() return try serializedCiphertexts.map { serialized in diff --git a/Sources/ApplicationProtobuf/PnnsConversion.swift b/Sources/ApplicationProtobuf/PnnsConversion.swift index aaebf4c8..5c410748 100644 --- a/Sources/ApplicationProtobuf/PnnsConversion.swift +++ b/Sources/ApplicationProtobuf/PnnsConversion.swift @@ -315,7 +315,7 @@ extension [Apple_SwiftHomomorphicEncryption_Pnns_V1_SerializedCiphertextMatrix] /// Converts the native object into a protobuf object. /// - Returns: The converted protobuf object. /// - Throws: Error upon unsupported object. - public func native(context: Context) throws -> Query { + public func native(context: Scheme.Context) throws -> Query { let matrices: [CiphertextMatrix] = try map { matrix in let native: SerializedCiphertextMatrix = try matrix.native() return try CiphertextMatrix(deserialize: native, context: context) diff --git a/Sources/ApplicationProtobuf/PnnsConversionApi.swift b/Sources/ApplicationProtobuf/PnnsConversionApi.swift index 2aa3d6db..eab0c600 100644 --- a/Sources/ApplicationProtobuf/PnnsConversionApi.swift +++ b/Sources/ApplicationProtobuf/PnnsConversionApi.swift @@ -23,7 +23,7 @@ extension Apple_SwiftHomomorphicEncryption_Api_Pnns_V1_PNNSShardResponse { /// - Parameter contexts: Contexts to associate with the native type; one context per plaintext modulus. /// - Returns: The converted native type. /// - Throws: Error upon invalid protobuf object. - public func native(contexts: [Context]) throws -> Response { + public func native(contexts: [Scheme.Context]) throws -> Response { precondition(contexts.count == reply.count) let matrices: [CiphertextMatrix] = try zip(reply, contexts).map { matrix, context in let serialized: SerializedCiphertextMatrix = try matrix.native() diff --git a/Sources/HomomorphicEncryption/Array2d.swift b/Sources/HomomorphicEncryption/Array2d.swift index 66993d02..c4097d8a 100644 --- a/Sources/HomomorphicEncryption/Array2d.swift +++ b/Sources/HomomorphicEncryption/Array2d.swift @@ -75,6 +75,38 @@ public struct Array2d: Equatable, rowCount: rowCount, columnCount: columnCount) } + + /// Provides scoped access to the underlying buffer storing the array's data. + /// + /// Use this method when you need temporary read-only access to the array's contiguous storage. + /// The buffer pointer is only valid for the duration of the closure's execution. + /// + /// - Parameter body: A closure that takes an `UnsafeBufferPointer` to the array's data. + /// The buffer pointer argument is valid only for the duration of the closure's execution. + /// - Returns: The return value of the `body` closure. + /// - Throws: Rethrows any error thrown by the `body` closure. + public func withUnsafeData(_ body: (UnsafeBufferPointer) throws -> Return) rethrows -> Return { + try data.withUnsafeBufferPointer { pointer in + try body(pointer) + } + } + + /// Provides scoped access to the underlying buffer storing the array's data for mutation. + /// + /// Use this method when you need temporary read-write access to the array's contiguous storage. + /// The buffer pointer is only valid for the duration of the closure's execution. + /// + /// - Parameter body: A closure that takes an `UnsafeMutableBufferPointer` to the array's data. + /// The buffer pointer argument is valid only for the duration of the closure's execution. + /// - Returns: The return value of the `body` closure. + /// - Throws: Rethrows any error thrown by the `body` closure. + public mutating func withUnsafeMutableData(_ body: (UnsafeMutableBufferPointer) throws + -> Return) rethrows -> Return + { + try data.withUnsafeMutableBufferPointer { pointer in + try body(pointer) + } + } } extension Array2d { diff --git a/Sources/HomomorphicEncryption/Bfv/Bfv+Decrypt.swift b/Sources/HomomorphicEncryption/Bfv/Bfv+Decrypt.swift index 258af3ed..037e0b12 100644 --- a/Sources/HomomorphicEncryption/Bfv/Bfv+Decrypt.swift +++ b/Sources/HomomorphicEncryption/Bfv/Bfv+Decrypt.swift @@ -35,7 +35,7 @@ extension Bfv { let rnsTool = ciphertext.context.getRnsTool(moduliCount: dotProduct.moduli.count) let plaintext = try rnsTool.scaleAndRound(poly: dotProduct, scalingFactor: scalingFactor) - return CoeffPlaintext(context: ciphertext.context, poly: plaintext) + return try CoeffPlaintext(context: ciphertext.context, poly: plaintext) } /// Calculates the number of least significant bits (LSBs) per polynomial that can be excluded diff --git a/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift b/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift index a40433d2..ba44dab0 100644 --- a/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift +++ b/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +import ModularArithmetic + extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes @@ -24,7 +26,7 @@ extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes - public static func encode(context: Context>, values: some Collection, + public static func encode(context: Context, values: some Collection, format: EncodeFormat) throws -> CoeffPlaintext { try context.encode(values: values, format: format) @@ -32,7 +34,7 @@ extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes - public static func encode(context: Context>, signedValues: some Collection, + public static func encode(context: Context, signedValues: some Collection, format: EncodeFormat) throws -> CoeffPlaintext { try context.encode(signedValues: signedValues, format: format) @@ -40,7 +42,7 @@ extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes - public static func encode(context: Context>, values: some Collection, format: EncodeFormat, + public static func encode(context: Context, values: some Collection, format: EncodeFormat, moduliCount: Int?) throws -> EvalPlaintext { let coeffPlaintext = try Self.encode(context: context, values: values, format: format) @@ -50,7 +52,7 @@ extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes public static func encode( - context: Context>, + context: Context, signedValues: some Collection, format: EncodeFormat, moduliCount: Int?) throws -> EvalPlaintext @@ -82,4 +84,69 @@ extension Bfv { public static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [SignedScalar] { try plaintext.convertToCoeffFormat().decode(format: format) } + + /// Calculates the number of least significant bits (LSBs) per polynomial that can be excluded + /// from serialization of a single-modulus ciphertext, when decryption is performed immediately after + /// deserialization. + /// + /// In BFV, the LSB bits of each polynomial may be excluded from the serialization, + /// since they are rarely used in decryption. This yields a smaller + /// serialization size, at the cost of a small chance of decryption + /// error. + /// - seealso: Section 5.2 of . + @inlinable + public static func skipLSBsForDecryption(for parameters: EncryptionParameters) -> [Int] { + let q0 = parameters.coefficientModuli[0] + let t = parameters.plaintextModulus + // Note, Appendix F of the paper claims the low `l' = floor(log2(q/t))` bits of + // a message are unused during decryption. This is off by one, due to + // also needing the MSB decimal bit for correct rounding. + // Concretely, let x=7, t=5, q=64. Then, floor(log2(q/t)) = 3. + // Decrypting `x` yields `round(x * t / q) = round(0.546875) = 1`, + // whereas decrypting `(x >> 3) >> 3) = 0` yields `round(0 * t / q) = 0`. + // Hence, we subtract one from the definition of `l'` compared to the paper. + // + // Also, Appendix F fails to address ciphertext error. If the error + // is less than q/4t, then we have the error introduced by the dropped + // bits be less than q/4t so we can correctly decrypt. + let lPrime = if q0 >= 2 * t { + (q0 / t).log2 - 3 + } else { + 0 + } + + // Then, we want the error introduced by dropping + // bits to be `<= q/4p` since it is additive with the + // ciphertext error. Set number of bits dropped + // in `b` to `floor(log(q/8p))`. Next, estimate + // how many bits to drop from a so that + // w.h.p., the introduced error is less + // than q/8p. + // + // The paper uses `z_score * sqrt(2N/9) * 2^{l_a} + 2^{l_b} < 2^{l'}` + // Setting `l_b = l' - 1`, this yields + // `z_score * sqrt(2N/9) * 2^{l_a} < 2^{l'-1}`, iff + // `log2(z_score * sqrt(2N/9)) + l_a < l' - 1`, iff + // `l_a < l' - 1 - log2(z_score * sqrt(2N/9))`, which is true for + // `l_a = floor(l' - 1 - log2(z_score * sqrt(2N/9)))` + // The paper uses z_score = 7; we use a larger z_score since we are decrypting N + // coefficients, rather than a single LWE coefficient. This yields a + // a per-coefficient decryption error `Pr(|x| > z_score)`, where `x ~ N(0, 1)`. + // This yields a < 2^-49.5 per-coefficient decryption error. + // By union bound, the message decryption error is + // `< 2^(log2(N) - 49.5) = 2^-36.5` for `N=8192` + // + // We also add a check: if we're only dropping at most one bit in `a`, then + // it is safer to drop that bit in `b` instead since the error's + // dependence on `b` is deterministic. + var poly0SkipLSBs = max(lPrime, 0) + let zScore = 8.0 + let tmp = Int(zScore * (2.0 * Double(parameters.polyDegree) / 9.0).squareRoot()) + var poly1SkipLSBs = lPrime - (tmp == 0 ? 0 : tmp.ceilLog2) + if poly1SkipLSBs <= 1 { + poly0SkipLSBs = max(lPrime + 1, 0) + poly1SkipLSBs = 0 + } + return [poly0SkipLSBs, poly1SkipLSBs] + } } diff --git a/Sources/HomomorphicEncryption/Bfv/Bfv+Encrypt.swift b/Sources/HomomorphicEncryption/Bfv/Bfv+Encrypt.swift index b93359f5..d757f61e 100644 --- a/Sources/HomomorphicEncryption/Bfv/Bfv+Encrypt.swift +++ b/Sources/HomomorphicEncryption/Bfv/Bfv+Encrypt.swift @@ -1,4 +1,4 @@ -// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,24 +23,24 @@ extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes - public static func zeroCiphertextCoeff(context: Context, moduliCount: Int?) throws -> CoeffCiphertext { + public static func zeroCiphertextCoeff(context: Context, moduliCount: Int?) throws -> CoeffCiphertext { let moduliCount = moduliCount ?? context.ciphertextContext.moduli.count let zeroPoly = try PolyRq.zero( context: context.ciphertextContext .getContext(moduliCount: moduliCount)) let polys = [PolyRq](repeating: zeroPoly, count: Bfv.freshCiphertextPolyCount) - return Bfv.CoeffCiphertext(context: context, polys: polys, correctionFactor: 1) + return try Bfv.CoeffCiphertext(context: context, polys: polys, correctionFactor: 1) } @inlinable // swiftlint:disable:next missing_docs attributes - public static func zeroCiphertextEval(context: Context, moduliCount: Int?) throws -> EvalCiphertext { + public static func zeroCiphertextEval(context: Context, moduliCount: Int?) throws -> EvalCiphertext { let moduliCount = moduliCount ?? context.ciphertextContext.moduli.count let zeroPoly = try PolyRq.zero( context: context.ciphertextContext .getContext(moduliCount: moduliCount)) let polys = [PolyRq](repeating: zeroPoly, count: Bfv.freshCiphertextPolyCount) - return Bfv.EvalCiphertext(context: context, polys: polys, correctionFactor: 1) + return try Bfv.EvalCiphertext(context: context, polys: polys, correctionFactor: 1) } @inlinable @@ -143,7 +143,7 @@ extension Bfv { } @inlinable - static func encryptZero(for context: Context>, + static func encryptZero(for context: Context, using secretKey: SecretKey>) throws -> CanonicalCiphertext { let ciphertextContext = context.ciphertextContext @@ -151,7 +151,7 @@ extension Bfv { } @inlinable - static func encryptZero(for context: Context>, + static func encryptZero(for context: Context, using secretKey: SecretKey>, with ciphertextContext: PolyContext) throws -> CanonicalCiphertext { @@ -177,7 +177,7 @@ extension Bfv { errorPoly.zeroize() let aCoeff = try a.inverseNtt() - return CanonicalCiphertext( + return try CanonicalCiphertext( context: context, polys: [-c0, aCoeff], correctionFactor: 1, diff --git a/Sources/HomomorphicEncryption/Bfv/Bfv+Keys.swift b/Sources/HomomorphicEncryption/Bfv/Bfv+Keys.swift index 49c05ac7..f0d5306a 100644 --- a/Sources/HomomorphicEncryption/Bfv/Bfv+Keys.swift +++ b/Sources/HomomorphicEncryption/Bfv/Bfv+Keys.swift @@ -1,4 +1,4 @@ -// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ import ModularArithmetic extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes - public static func generateSecretKey(context: Context) throws -> SecretKey> { + public static func generateSecretKey(context: Context) throws -> SecretKey> { var s = PolyRq.zero(context: context.secretKeyContext) var rng = SystemRandomNumberGenerator() s.randomizeTernary(using: &rng) @@ -28,26 +28,26 @@ extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes public static func generateEvaluationKey( - context: Context>, + context: Context, config: EvaluationKeyConfig, using secretKey: borrowing SecretKey>) throws -> EvaluationKey> { guard context.supportsEvaluationKey else { throw HeError.unsupportedHeOperation() } - var galoisKeys: [Int: KeySwitchKey] = [:] + var galoisKeys: [Int: Self.KeySwitchKey] = [:] for element in config.galoisElements where !galoisKeys.keys.contains(element) { let switchedKey = try secretKey.poly.applyGalois(element: element) - galoisKeys[element] = try generateKeySwitchKey( + galoisKeys[element] = try _generateKeySwitchKey( context: context, currentKey: switchedKey, targetKey: secretKey) } - var galoisKey: GaloisKey? + var galoisKey: GaloisKey? if !galoisKeys.isEmpty { galoisKey = GaloisKey(keys: galoisKeys) } - var relinearizationKey: RelinearizationKey? + var relinearizationKey: _RelinearizationKey? if config.hasRelinearizationKey { relinearizationKey = try Self.generateRelinearizationKey(context: context, secretKey: secretKey) } @@ -55,19 +55,20 @@ extension Bfv { } @inlinable - static func generateRelinearizationKey(context: Context, + static func generateRelinearizationKey(context: Context, secretKey: borrowing SecretKey) throws - -> RelinearizationKey + -> _RelinearizationKey { let s2 = secretKey.poly * secretKey.poly - let keySwitchingKey = try generateKeySwitchKey(context: context, currentKey: s2, targetKey: secretKey) - return RelinearizationKey(keySwitchKey: keySwitchingKey) + let keySwitchingKey = try _generateKeySwitchKey(context: context, currentKey: s2, targetKey: secretKey) + return _RelinearizationKey(keySwitchKey: keySwitchingKey) } + /// Generate the key switching key from current key to target key. @inlinable - static func generateKeySwitchKey(context: Context>, - currentKey: consuming PolyRq, - targetKey: borrowing SecretKey>) throws -> KeySwitchKey> + public static func _generateKeySwitchKey(context: Context, + currentKey: consuming PolyRq, + targetKey: borrowing SecretKey>) throws -> _KeySwitchKey> { guard let keyModulus = context.coefficientModuli.last else { throw HeError.invalidEncryptionParameters(context.encryptionParameters) @@ -98,7 +99,7 @@ extension Bfv { currentKey.zeroize() _ = consume currentKey - return KeySwitchKey(context: context, ciphers: ciphers) + return KeySwitchKey(context: context, ciphertexts: ciphers) } /// Computes the key-switching update of a target polynomial. @@ -119,10 +120,10 @@ extension Bfv { /// - Throws: Error upon failure to compute key-switching update. /// - seealso: ``Bfv/generateEvaluationKey(context:config:using:)``. @inlinable - static func computeKeySwitchingUpdate( - context: Context>, + public static func _computeKeySwitchingUpdate( + context: Context, target: PolyRq, - keySwitchingKey: KeySwitchKey) throws -> [PolyRq] + keySwitchingKey: Self.KeySwitchKey) throws -> [PolyRq] { // The implementation loosely follows the outline on page 36 of . // The inner product is computed in an extended base `q_0, q_1, ..., q_l, q_{ks}`, where `q_{ks}` is the special @@ -138,16 +139,16 @@ extension Bfv { } let keySwitchingModuli = keySwitchingContext.reduceModuli - let keyComponentCount = keySwitchingKey.ciphers[0].polys.count + let keyComponentCount = keySwitchingKey.ciphertexts[0].polys.count let polys = [PolyRq]( repeating: PolyRq.zero(context: keySwitchingContext), count: keyComponentCount) - var ciphertextProd: EvalCiphertext = Ciphertext(context: context, - polys: polys, - correctionFactor: 1) + var ciphertextProd: EvalCiphertext = try Ciphertext(context: context, + polys: polys, + correctionFactor: 1) let targetCoeff = try target.convertToCoeffFormat() - let keyCiphers = keySwitchingKey.ciphers + let keyCiphers = keySwitchingKey.ciphertexts for rnsIndex in 0..: HeScheme { + public typealias CiphertextAuxiliaryData = EmptyAuxiliary + public typealias PlaintextAuxiliaryData = EmptyAuxiliary + + public typealias Context = HomomorphicEncryption.Context + public typealias KeySwitchKey = HomomorphicEncryption._KeySwitchKey + public typealias GaloisKey = HomomorphicEncryption._GaloisKey + public typealias Scalar = T + public typealias SignedScalar = T.SignedScalar public typealias CanonicalCiphertextFormat = Coeff + public static var cryptosystem: HeCryptoSystem { .bfv } + public static var freshCiphertextPolyCount: Int { 2 } @@ -174,7 +184,7 @@ public enum Bfv: HeScheme { } ciphertext.polys[0] = ciphertext.polys[0].applyGalois(element: element) let tempC1 = ciphertext.polys[1].applyGalois(element: element) - let update = try Self.computeKeySwitchingUpdate( + let update = try Self._computeKeySwitchingUpdate( context: ciphertext.context, target: tempC1, keySwitchingKey: keySwitchingKey) @@ -194,7 +204,7 @@ public enum Bfv: HeScheme { guard let relinearizationKey = key.relinearizationKey else { throw HeError.missingRelinearizationKey } - let update = try Self.computeKeySwitchingUpdate( + let update = try Self._computeKeySwitchingUpdate( context: ciphertext.context, target: poly2, keySwitchingKey: relinearizationKey.keySwitchKey) @@ -251,7 +261,7 @@ public enum Bfv: HeScheme { reduceInPlace(accumulator: &accumulator, polyContext: rnsTool.qBskContext) } } - var sum = EvalCiphertext( + var sum = try EvalCiphertext( context: firstCiphertext.context, polys: Array(repeating: .zero(context: rnsTool.qBskContext), count: 3), correctionFactor: 1) @@ -341,20 +351,30 @@ public enum Bfv: HeScheme { } @inlinable - public static func forwardNtt(_ ciphertext: CoeffCiphertext) throws -> EvalCiphertext { + public static func forwardNtt(_ ciphertext: inout CoeffCiphertext) throws -> EvalCiphertext { let polys = try ciphertext.polys.map { try $0.forwardNtt() } - return Ciphertext, Eval>(context: ciphertext.context, - polys: polys, - correctionFactor: ciphertext.correctionFactor, - seed: ciphertext.seed) + return try Ciphertext, Eval>(context: ciphertext.context, + polys: polys, + correctionFactor: ciphertext.correctionFactor, + seed: ciphertext.seed) } @inlinable - public static func inverseNtt(_ ciphertext: EvalCiphertext) throws -> CoeffCiphertext { + public static func inverseNtt(_ ciphertext: inout EvalCiphertext) throws -> CoeffCiphertext { let polys = try ciphertext.polys.map { try $0.inverseNtt() } - return Ciphertext, Coeff>(context: ciphertext.context, - polys: polys, - correctionFactor: ciphertext.correctionFactor, - seed: ciphertext.seed) + return try Ciphertext, Coeff>(context: ciphertext.context, + polys: polys, + correctionFactor: ciphertext.correctionFactor, + seed: ciphertext.seed) + } + + /// Returns the dimension counts for ``EncodeFormat/simd`` encoding, or `nil` if the HE scheme does + /// not support SIMD encoding for the given parameters. + @inlinable + public static func simdDimensions(for encryptionParameter: EncryptionParameters) -> SimdEncodingDimensions? { + guard encryptionParameter.supportsSimdEncoding else { + return nil + } + return SimdEncodingDimensions(rowCount: 2, columnCount: encryptionParameter.polyDegree / 2) } } diff --git a/Sources/HomomorphicEncryption/Ciphertext.swift b/Sources/HomomorphicEncryption/Ciphertext.swift index 9628ad7c..8751460b 100644 --- a/Sources/HomomorphicEncryption/Ciphertext.swift +++ b/Sources/HomomorphicEncryption/Ciphertext.swift @@ -17,10 +17,12 @@ public struct Ciphertext: Equatable, Senda public typealias Scalar = Scheme.Scalar /// Context for HE computation. - public let context: Context - @usableFromInline package var polys: [PolyRq] - @usableFromInline var correctionFactor: Scalar - @usableFromInline var seed: [UInt8] = [] + public let context: Scheme.Context + public var polys: [PolyRq] + public var correctionFactor: Scalar + public var seed: [UInt8] = [] + + public var auxiliaryData: Scheme.CiphertextAuxiliaryData /// The number of polynomials in the ciphertext. /// @@ -33,15 +35,50 @@ public struct Ciphertext: Equatable, Senda @inlinable init( - context: Context, + context: Scheme.Context, polys: [PolyRq], correctionFactor: Scalar, - seed: [UInt8] = []) + seed: [UInt8] = []) throws + { + try self.init( + _context: context, + _polys: polys, + _correctionFactor: correctionFactor, + _auxiliaryData: nil, + _seed: seed) + } + + /// Create a ciphertext with the given content. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + /// - Parameters: + /// - _context: context of the ciphertext. + /// - _polys: polys of the ciphertext. + /// - _correctionFactor: correction factor of the ciphertext. + /// - _seed: seed of the ciphertext. + /// - _auxiliaryData: optionally provided auxiliary ciphertext data, but explicitly put `nil` if one should be + /// created. + /// - Throws: error occurred when creating the auxiliary data. + @inlinable + public init( + _context context: Scheme.Context, + _polys polys: [PolyRq], + _correctionFactor correctionFactor: Scheme.Scalar, + _auxiliaryData auxiliaryData: Scheme.CiphertextAuxiliaryData?, + _seed seed: [UInt8] = []) throws { self.context = context self.polys = polys self.correctionFactor = correctionFactor self.seed = seed + if let auxData = auxiliaryData { + self.auxiliaryData = auxData + } else { + self.auxiliaryData = try Scheme.CiphertextAuxiliaryData( + context: context, + polys: polys, + correctionFactor: correctionFactor, + seed: seed) + } } /// Generates a ciphertext of zeros. @@ -64,7 +101,9 @@ public struct Ciphertext: Equatable, Senda /// ``` /// - seelaso: ``Ciphertext/isTransparent()`` @inlinable - public static func zero(context: Context, moduliCount: Int? = nil) throws -> Ciphertext { + public static func zero(context: Scheme.Context, + moduliCount: Int? = nil) throws -> Ciphertext + { try Scheme.zero(context: context, moduliCount: moduliCount) } @@ -163,23 +202,23 @@ public struct Ciphertext: Equatable, Senda } @inlinable - package func forwardNtt() throws -> Ciphertext where Format == Coeff { - try Scheme.forwardNtt(self) + package consuming func forwardNtt() throws -> Ciphertext where Format == Coeff { + try Scheme.forwardNtt(&self) } @inlinable - package func inverseNtt() throws -> Ciphertext where Format == Eval { - try Scheme.inverseNtt(self) + package consuming func inverseNtt() throws -> Ciphertext where Format == Eval { + try Scheme.inverseNtt(&self) } /// Converts the ciphertext to a ``HeScheme/CoeffCiphertext``. /// - Returns: The converted ciphertext. /// - Throws: Error upon failure to convert the ciphertext. @inlinable - public func convertToCoeffFormat() throws -> Ciphertext { + public consuming func convertToCoeffFormat() throws -> Ciphertext { if Format.self == Eval.self { - if let ciphertext = self as? Ciphertext { - return try ciphertext.inverseNtt() + if var ciphertext = self as? Ciphertext { + return try Scheme.inverseNtt(&ciphertext) } throw HeError.errorCastingPolyFormat(from: Format.self, to: Eval.self) } @@ -193,10 +232,10 @@ public struct Ciphertext: Equatable, Senda /// - Returns: The converted ciphertext. /// - Throws: Error upon failure to convert the ciphertext. @inlinable - public func convertToEvalFormat() throws -> Ciphertext { + public consuming func convertToEvalFormat() throws -> Ciphertext { if Format.self == Coeff.self { - if let ciphertext = self as? Ciphertext { - return try ciphertext.forwardNtt() + if var ciphertext = self as? Ciphertext { + return try Scheme.forwardNtt(&ciphertext) } throw HeError.errorCastingPolyFormat(from: Format.self, to: Coeff.self) } @@ -210,7 +249,7 @@ public struct Ciphertext: Equatable, Senda /// - Returns: The converted ciphertext. /// - Throws: Error upon failure to convert the ciphertext. @inlinable - public func convertToCanonicalFormat() throws -> Ciphertext { + public consuming func convertToCanonicalFormat() throws -> Ciphertext { if Scheme.CanonicalCiphertextFormat.self == Coeff.self { // swiftlint:disable:next force_cast return try convertToCoeffFormat() as! Scheme.CanonicalCiphertext @@ -463,6 +502,7 @@ extension Ciphertext where Format == Coeff { /// - Throws: Error upon failure to compute the inverse. @inlinable public mutating func multiplyInversePowerOfX(power: Int) throws { + precondition(power >= 0) try Scheme.multiplyInversePowerOfX(&self, power: power) } } @@ -554,3 +594,82 @@ extension Collection { try Scheme.innerProduct(self, ciphertexts) } } + +/// Async ciphertext functions. +extension Ciphertext { + /// Converts the ciphertext to coefficient format asynchronously. + /// + /// This method performs an asynchronous conversion of the ciphertext to ``Coeff`` format. + /// If the ciphertext is already in coefficient format, it returns the ciphertext unchanged. + /// If the ciphertext is in evaluation (``Eval``) format, it performs an inverse NTT (Number Theoretic Transform) + /// to convert it to coefficient format. + /// + /// - Returns: A ciphertext in coefficient format. + /// - Throws: ``HeError/errorCastingPolyFormat(_:)`` if the format conversion fails. + /// - SeeAlso: ``convertToEvalFormat()-9jgqm`` to convert to evaluation format. + /// - SeeAlso: ``convertToCanonicalFormat()-wgss`` to convert to the scheme's canonical format. + @inlinable + public consuming func convertToCoeffFormat() async throws -> Ciphertext { + if Format.self == Eval.self { + if var ciphertext = self as? Ciphertext { + return try await Scheme.inverseNttAsync(&ciphertext) + } + throw HeError.errorCastingPolyFormat(from: Format.self, to: Eval.self) + } + if let ciphertext = self as? Ciphertext { + return ciphertext + } + throw HeError.errorCastingPolyFormat(from: Format.self, to: Coeff.self) + } + + /// Converts the ciphertext to evaluation format asynchronously. + /// + /// This method performs an asynchronous conversion of the ciphertext to ``Eval`` format. + /// If the ciphertext is already in evaluation format, it returns the ciphertext unchanged. + /// If the ciphertext is in coefficient (``Coeff``) format, it performs a forward NTT (Number Theoretic Transform) + /// to convert it to evaluation format. + /// + /// - Returns: A ciphertext in evaluation format. + /// - Throws: ``HeError/errorCastingPolyFormat(_:)`` if the format conversion fails. + /// - SeeAlso: ``convertToCoeffFormat()-a2ay`` to convert to coefficient format. + /// - SeeAlso: ``convertToCanonicalFormat()-wgss`` to convert to the scheme's canonical format. + @inlinable + public consuming func convertToEvalFormat() async throws -> Ciphertext { + if Format.self == Coeff.self { + if var ciphertext = self as? Ciphertext { + return try await Scheme.forwardNttAsync(&ciphertext) + } + throw HeError.errorCastingPolyFormat(from: Format.self, to: Coeff.self) + } + if let ciphertext = self as? Ciphertext { + return ciphertext + } + throw HeError.errorCastingPolyFormat(from: Format.self, to: Eval.self) + } + + /// Converts the ciphertext to the scheme's canonical format. + /// + /// If the ciphertext is already in the canonical format, it returns the ciphertext unchanged. + /// Otherwise, it performs the necessary Number Theoretic Transform (NTT) conversion: + /// - Forward NTT if converting from coefficient to evaluation format + /// - Inverse NTT if converting from evaluation to coefficient format + /// + /// - Returns: A ciphertext in the canonical format. + /// - Throws: ``HeError/errorCastingPolyFormat(_:)`` if the format conversion fails. + /// - SeeAlso: ``convertToCoeffFormat()-a2ay`` to convert to coefficient format. + /// - seeAlso: ``convertToEvalFormat()-9jgqm`` to convert to evaluation format. + @inlinable + public consuming func convertToCanonicalFormat() async throws + -> Ciphertext + { + if Scheme.CanonicalCiphertextFormat.self == Coeff.self { + // swiftlint:disable:next force_cast + return try await convertToCoeffFormat() as! Scheme.CanonicalCiphertext + } + if Scheme.CanonicalCiphertextFormat.self == Eval.self { + // swiftlint:disable:next force_cast + return try await convertToEvalFormat() as! Scheme.CanonicalCiphertext + } + throw HeError.errorCastingPolyFormat(from: Format.self, to: Scheme.CanonicalCiphertextFormat.self) + } +} diff --git a/Sources/HomomorphicEncryption/Context.swift b/Sources/HomomorphicEncryption/Context.swift index c5ee84be..ddfb0b7d 100644 --- a/Sources/HomomorphicEncryption/Context.swift +++ b/Sources/HomomorphicEncryption/Context.swift @@ -16,33 +16,33 @@ /// /// HE operations are typically only supported between objects, such as ``Ciphertext``, ``Plaintext``, /// ``EvaluationKey``, ``SecretKey``, with the same context. -public final class Context: Equatable, Sendable { +public final class Context: HeContext, Equatable, Sendable { public typealias Scalar = Scheme.Scalar /// Encryption parameters. public let encryptionParameters: EncryptionParameters /// Plaintext context, with modulus `t`, the plaintext modulus. - @usableFromInline let plaintextContext: PolyContext + public let plaintextContext: PolyContext - /// Encoding matrix for ``Encoding.simd`` encoding. - @usableFromInline let simdEncodingMatrix: [Int] + /// Encoding matrix for SIMD encoding. + public let simdEncodingMatrix: [Int] /// Context for the secret key. - @usableFromInline let secretKeyContext: PolyContext + public let secretKeyContext: PolyContext /// Top-level ciphertext context. - @usableFromInline package let ciphertextContext: PolyContext + public let ciphertextContext: PolyContext /// Contexts for key-switching keys. /// /// The i'th context contains `q_0, ..., q_i, q_{L-1}`, and has next context dropping `q_{L-1}` /// E.g., `keySwitchingContexts[0].context.moduli = [q_0, q_1, q_L]`, and /// `keySwitchingContexts[0].next.moduli = [q_0, q_1]` - @usableFromInline let keySwitchingContexts: [PolyContext] + public let keySwitchingContexts: [PolyContext] /// The rns tools for each level of ciphertexts, with number of moduli in descending order. - @usableFromInline let rnsTools: [RnsTool] + public let _rnsTools: [_RnsTool] /// The plaintext modulus,`t`. public var plaintextModulus: Scalar { encryptionParameters.plaintextModulus } @@ -85,7 +85,7 @@ public final class Context: Equatable, Sendable { } else { nil } - var rnsTools = [RnsTool]() + var rnsTools = [_RnsTool]() rnsTools.reserveCapacity(ciphertextModuli.count) let ciphertextContext = try PolyContext( degree: encryptionParameters.polyDegree, @@ -110,16 +110,16 @@ public final class Context: Equatable, Sendable { degree: encryptionParameters.polyDegree, moduli: [encryptionParameters.plaintextModulus]) - let rnsToolContext = try RnsTool.RnsToolContext( + let rnsToolContext = try _RnsTool.RnsToolContext( inputContext: ciphertextContext, outputContext: plaintextContext) var rnsToolsCiphertextContext = ciphertextContext - try rnsTools.append(RnsTool(from: ciphertextContext, to: plaintextContext, rnsToolContext: rnsToolContext)) + try rnsTools.append(_RnsTool(from: ciphertextContext, to: plaintextContext, rnsToolContext: rnsToolContext)) while let nextContext = rnsToolsCiphertextContext.next { - try rnsTools.append(RnsTool(from: nextContext, to: plaintextContext, rnsToolContext: rnsToolContext)) + try rnsTools.append(_RnsTool(from: nextContext, to: plaintextContext, rnsToolContext: rnsToolContext)) rnsToolsCiphertextContext = nextContext } - self.rnsTools = rnsTools + self._rnsTools = rnsTools } /// Returns a boolean value indicating whether two contexts are equal. @@ -133,9 +133,9 @@ public final class Context: Equatable, Sendable { } @inlinable - func getRnsTool(moduliCount: Int) -> RnsTool { - precondition(moduliCount <= rnsTools.count && moduliCount > 0, "Invalid number of moduli") - return rnsTools[rnsTools.count - moduliCount] + public func getRnsTool(moduliCount: Int) -> _RnsTool { + precondition(moduliCount <= _rnsTools.count && moduliCount > 0, "Invalid number of moduli") + return _rnsTools[_rnsTools.count - moduliCount] } } diff --git a/Sources/HomomorphicEncryption/CrtComposer.swift b/Sources/HomomorphicEncryption/CrtComposer.swift index 4579f42f..9b76beac 100644 --- a/Sources/HomomorphicEncryption/CrtComposer.swift +++ b/Sources/HomomorphicEncryption/CrtComposer.swift @@ -14,14 +14,16 @@ import ModularArithmetic -/// Performs Chinese remainder theorem (CRT) composition of coefficients. @usableFromInline -package struct CrtComposer: Sendable { +package typealias CrtComposer = _CrtComposer + +/// Performs Chinese remainder theorem (CRT) composition of coefficients. +public struct _CrtComposer: Sendable { /// Context for the CRT moduli `q_i`. - @usableFromInline let polyContext: PolyContext + public let polyContext: PolyContext /// i'th entry stores `(q_i / q) % q_i`. - @usableFromInline let inversePuncturedProducts: [MultiplyConstantModulus] + public let inversePuncturedProducts: [MultiplyConstantModulus] /// Creates a new ``CrtComposer``. /// - Parameter polyContext: Context for the CRT moduli. diff --git a/Sources/HomomorphicEncryption/Encoding.swift b/Sources/HomomorphicEncryption/Encoding.swift index 61b98907..7770ffa2 100644 --- a/Sources/HomomorphicEncryption/Encoding.swift +++ b/Sources/HomomorphicEncryption/Encoding.swift @@ -17,7 +17,7 @@ import ModularArithmetic -extension Context { +extension HeContext { /// Encodes `values` in the given format. /// /// Encoding will use the top-level ciphertext context with all moduli. @@ -29,6 +29,7 @@ extension Context { @inlinable public func encode(values: some Collection, format: EncodeFormat) throws -> Plaintext + where Scheme.Context == Self { try validDataForEncoding(values: values) switch format { @@ -48,8 +49,9 @@ extension Context { /// - Returns: The plaintext encoding `signedValues`. /// - Throws: Error upon failure to encode. @inlinable - public func encode(signedValues: some Collection, - format: EncodeFormat) throws -> Plaintext + public func encode(signedValues: some Collection, + format: EncodeFormat) throws -> Plaintext + where Scheme.Context == Self { let signedModulus = Scheme.SignedScalar(plaintextModulus) let bounds = -(signedModulus >> 1)...((signedModulus - 1) >> 1) @@ -71,10 +73,9 @@ extension Context { /// - Returns: The plaintext encoding `values`. /// - Throws: Error upon failure to encode. @inlinable - public func encode( - values: some Collection, - format: EncodeFormat, - moduliCount: Int? = nil) throws -> Plaintext + public func encode(values: some Collection, format: EncodeFormat, + moduliCount: Int? = nil) throws -> Plaintext + where Scheme.Context == Self { try Scheme.encode(context: self, values: values, format: format, moduliCount: moduliCount) } @@ -88,8 +89,9 @@ extension Context { /// - Returns: The plaintext encoding `signedValues`. /// - Throws: Error upon failure to encode. @inlinable - public func encode(signedValues: some Collection, format: EncodeFormat, - moduliCount: Int? = nil) throws -> Plaintext + public func encode(signedValues: some Collection, format: EncodeFormat, + moduliCount: Int? = nil) throws -> Plaintext + where Scheme.Context == Self { try Scheme.encode(context: self, signedValues: signedValues, format: format, moduliCount: moduliCount) } @@ -134,7 +136,10 @@ extension Context { /// - Returns: The decoded signed values. /// - Throws: Error upon failure to decode. @inlinable - func decode(plaintext: Plaintext, format: EncodeFormat) throws -> [Scheme.SignedScalar] { + func decode(plaintext: Plaintext, + format: EncodeFormat) throws -> [Scheme.SignedScalar] where Scheme.Scalar == Scalar, + Context == Self + { try Scheme.decodeEval(plaintext: plaintext, format: format) } @@ -152,7 +157,7 @@ extension Context { } // functions for coefficient encoding/decoding -extension Context { +extension HeContext { /// Encodes a polynomial element-wise in coefficient format. /// /// Encodes the polynomial @@ -161,18 +166,17 @@ extension Context { @inlinable func encodeCoefficient(values: some Collection) throws -> Plaintext + where Scheme.Context == Self { if values.isEmpty { - return Plaintext(context: self, poly: PolyRq.zero(context: plaintextContext)) + return try Plaintext(context: self, poly: PolyRq.zero(context: plaintextContext)) } var valuesArray = Array(values) if valuesArray.count < degree { valuesArray.append(contentsOf: repeatElement(0, count: degree - valuesArray.count)) } let array: Array2d = Array2d(data: valuesArray, rowCount: 1, columnCount: valuesArray.count) - return Plaintext( - context: self, - poly: PolyRq(context: plaintextContext, data: array)) + return try Plaintext(context: self, poly: PolyRq(context: plaintextContext, data: array)) } /// Decodes a polynomial element-wise in coefficient format. @@ -188,7 +192,7 @@ extension Context { } // code for SIMD encoding/decoding -extension Context { +extension HeContext { @inlinable static func generateEncodingMatrix(encryptionParameters: EncryptionParameters) -> [Int] { guard encryptionParameters.plaintextModulus.isNttModulus(for: encryptionParameters.polyDegree) else { @@ -215,7 +219,9 @@ extension Context { } @inlinable - func encodeSimd(values: some Collection) throws -> Plaintext { + func encodeSimd(values: some Collection) throws -> Plaintext + where Scheme.Context == Self + { guard !simdEncodingMatrix.isEmpty else { throw HeError.simdEncodingNotSupported(for: encryptionParameters) } let polyDegree = encryptionParameters.polyDegree var array = Array2d.zero(rowCount: 1, columnCount: polyDegree) @@ -224,7 +230,7 @@ extension Context { } let poly = PolyRq<_, Eval>(context: plaintextContext, data: array) let coeffPoly = try poly.inverseNtt() - return Plaintext(context: self, poly: coeffPoly) + return try Plaintext(context: self, poly: coeffPoly) } @inlinable diff --git a/Sources/HomomorphicEncryption/Error.swift b/Sources/HomomorphicEncryption/Error.swift index 2392d4bb..49764d3f 100644 --- a/Sources/HomomorphicEncryption/Error.swift +++ b/Sources/HomomorphicEncryption/Error.swift @@ -109,7 +109,7 @@ extension HeError { } @inlinable - static func invalidContext(_ context: Context) -> Self { + static func invalidContext(_ context: some HeContext) -> Self { .invalidContext("\(context.description)") } @@ -168,7 +168,7 @@ extension HeError { } @inlinable - static func unequalContexts(got: Context, expected: Context) -> Self { + static func unequalContexts(got: some HeContext, expected: some HeContext) -> Self { .unequalContexts("Unequal contexts: \(got.description) is not equal to \(expected.description)") } diff --git a/Sources/HomomorphicEncryption/HeScheme.swift b/Sources/HomomorphicEncryption/HeScheme.swift index e1d09c4a..e9564f03 100644 --- a/Sources/HomomorphicEncryption/HeScheme.swift +++ b/Sources/HomomorphicEncryption/HeScheme.swift @@ -85,16 +85,99 @@ public struct SimdEncodingDimensions: Codable, Equatable, Hashable, Sendable { } } +public protocol HeContext: Equatable, Sendable, CustomStringConvertible { + associatedtype Scheme: HeScheme + associatedtype Scalar where Scalar == Scheme.Scalar + + var encryptionParameters: EncryptionParameters { get } + var ciphertextContext: PolyContext { get } + var plaintextContext: PolyContext { get } + var secretKeyContext: PolyContext { get } + var simdEncodingMatrix: [Int] { get } + var simdDimensions: SimdEncodingDimensions? { get } + + init(encryptionParameters: EncryptionParameters) throws + func getRnsTool(moduliCount: Int) throws -> _RnsTool +} + +extension HeContext { + /// The RLWE polynomial degree `N`. + public var degree: Int { encryptionParameters.polyDegree } + /// The plaintext modulus,`t`. + public var plaintextModulus: Scalar { encryptionParameters.plaintextModulus } + /// The coefficient moduli, `q_0, ..., q_L`. + public var coefficientModuli: [Scalar] { encryptionParameters.coefficientModuli } + /// Whether or not the context supports ``EncodeFormat/simd`` encoding. + public var supportsSimdEncoding: Bool { encryptionParameters.supportsSimdEncoding } + /// Whether or not the context supports use of an ``EvaluationKey``. + public var supportsEvaluationKey: Bool { encryptionParameters.supportsEvaluationKey } + /// The number of bits that can be encoded in a single ``Plaintext``. + public var bitsPerPlaintext: Int { encryptionParameters.bitsPerPlaintext } + /// The number of bytes that can be encoded in a single ``Plaintext``. + public var bytesPerPlaintext: Int { encryptionParameters.bytesPerPlaintext } +} + +public protocol HeKeySwitchKey: Equatable, Sendable { + associatedtype Scheme: HeScheme + var _context: Scheme.Context { get } + var _ciphertexts: [Ciphertext] { get } + + init(_context: Scheme.Context, _ciphertexts: [Ciphertext]) throws +} + +extension HeKeySwitchKey { + @usableFromInline var context: Scheme.Context { _context } + @usableFromInline var ciphertexts: [Ciphertext] { _ciphertexts } +} + +public protocol HeGaloisKey: Equatable, Sendable { + associatedtype Scheme: HeScheme + + var _keys: [Int: Scheme.KeySwitchKey] { get } + + init(_keys: [Int: Scheme.KeySwitchKey]) throws +} + +extension HeGaloisKey { + @usableFromInline var keys: [Int: Scheme.KeySwitchKey] { _keys } +} + +public protocol CiphertextAuxiliary: Equatable, Sendable { + associatedtype Scheme: HeScheme + init(context: Scheme.Context, + polys: [PolyRq], + correctionFactor: Scheme.Scalar, + seed: [UInt8]) throws +} + +public protocol PlaintextAuxiliary: Equatable, Sendable { + associatedtype Scheme: HeScheme + init(context: Scheme.Context, + poly: PolyRq) throws +} + /// Protocol for HE schemes. /// /// The protocol should be implemented when adding a new HE scheme. /// However, several functions have an alternative API which is more ergonomic and should be preferred. public protocol HeScheme: Sendable { + /// Associated auxiliary data for ciphertexts + associatedtype CiphertextAuxiliaryData: CiphertextAuxiliary where CiphertextAuxiliaryData.Scheme == Self + /// Associated auxiliary data for plaintexts + associatedtype PlaintextAuxiliaryData: PlaintextAuxiliary where PlaintextAuxiliaryData.Scheme == Self + /// Coefficient type for each polynomial. associatedtype Scalar: ScalarType /// Coefficient type for signed encoding/decoding. typealias SignedScalar = Scalar.SignedScalar + /// The context for the HE scheme. + associatedtype Context: HeContext where Context.Scheme == Self + /// The key switching key for the HE scheme. + associatedtype KeySwitchKey: HeKeySwitchKey where KeySwitchKey.Scheme == Self + /// The GaloisKey for the HE scheme + associatedtype GaloisKey: HeGaloisKey where GaloisKey.Scheme == Self + /// Polynomial format for the . associatedtype CanonicalCiphertextFormat: PolyFormat @@ -106,18 +189,19 @@ public protocol HeScheme: Sendable { /// Ciphertext in ``Coeff`` format. /// - /// ``Ciphertext/convertToCoeffFormat()`` can be used to convert a ciphertext to a ``CoeffCiphertext``. + /// ``Ciphertext/convertToCoeffFormat()-35q3d`` can be used to convert a ciphertext to a ``CoeffCiphertext``. typealias CoeffCiphertext = Ciphertext /// Ciphertext in ``Eval`` format. /// - /// ``Ciphertext/convertToEvalFormat()`` can be used to convert a ciphertext to an ``EvalCiphertext``. + /// ``Ciphertext/convertToEvalFormat()-8msby`` can be used to convert a ciphertext to an ``EvalCiphertext``. typealias EvalCiphertext = Ciphertext /// The canonical representation of a ciphertext. /// /// The canonical representation is the default ciphertext representation. - /// ``Ciphertext/convertToCanonicalFormat()`` can be used to convert a ciphertext to a ``CanonicalCiphertext``. + /// ``Ciphertext/convertToCanonicalFormat()-90lbz`` can be used to convert a ciphertext to a + /// ``CanonicalCiphertext``. /// However, some operations may require a specific format, such as ``CoeffCiphertext`` or ``EvalCiphertext``. typealias CanonicalCiphertext = Ciphertext @@ -127,6 +211,9 @@ public protocol HeScheme: Sendable { /// Evaluation key type. typealias EvaluationKey = HomomorphicEncryption.EvaluationKey + /// Underlying HE scheme. + static var cryptosystem: HeCryptoSystem { get } + /// The number of polynomials in a freshly encrypted ciphertext. /// /// Some operations such as ciphertext-ciphertext multiplication, or relinearization may change the number of @@ -139,12 +226,17 @@ public protocol HeScheme: Sendable { /// - seealso: ``Ciphertext/noiseBudget(using:variableTime:)``. static var minNoiseBudget: Double { get } + /// The (row, column) dimension counts for ``EncodeFormat/simd`` encoding. + /// + /// If the HE scheme does not support ``EncodeFormat/simd`` encoding, returns `nil`. + static func simdDimensions(for encryptionParameter: EncryptionParameters) -> SimdEncodingDimensions? + /// Generates a ``SecretKey``. /// - Parameter context: Context for HE computation. /// - Returns: A freshly generated secret key. /// - Throws: Error upon failure to generate a secret key. /// - seealso: ``Context/generateSecretKey()`` for an alternative API. - static func generateSecretKey(context: Context) throws -> SecretKey + static func generateSecretKey(context: Context) throws -> SecretKey /// Generates an ``EvaluationKey``. /// - Parameters: @@ -155,7 +247,7 @@ public protocol HeScheme: Sendable { /// - Throws: Error upon failure to generate an evaluation key. /// - seealso: ``Context/generateEvaluationKey(config:using:)`` for an alternative API. static func generateEvaluationKey( - context: Context, + context: Context, config: EvaluationKeyConfig, using secretKey: SecretKey) throws -> EvaluationKey @@ -173,7 +265,7 @@ public protocol HeScheme: Sendable { /// - Throws: Error upon failure to encode. /// - seealso: ``Context/encode(values:format:)`` for an alternative API. /// - seealso: ``HeScheme/encode(context:signedValues:format:)`` to encode signed values. - static func encode(context: Context, values: some Collection, format: EncodeFormat) throws + static func encode(context: Context, values: some Collection, format: EncodeFormat) throws -> CoeffPlaintext /// Encodes signed values into a plaintext with coefficient format. @@ -186,7 +278,10 @@ public protocol HeScheme: Sendable { /// - Throws: Error upon failure to encode. /// - seealso: ``Context/encode(signedValues:format:)`` for an alternative API. /// - seealso: ``HeScheme/encode(context:values:format:)`` to encode unsigned values. - static func encode(context: Context, signedValues: some Collection, format: EncodeFormat) throws + static func encode( + context: Context, + signedValues: some Collection, + format: EncodeFormat) throws -> CoeffPlaintext /// Encodes values into a plaintext with evaluation format. @@ -202,7 +297,7 @@ public protocol HeScheme: Sendable { /// - Throws: Error upon failure to encode. /// - seealso: ``Context/encode(values:format:moduliCount:)`` for an alternative API. /// - seealso: ``HeScheme/encode(context:signedValues:format:moduliCount:)`` to encode signed values. - static func encode(context: Context, values: some Collection, format: EncodeFormat, + static func encode(context: Context, values: some Collection, format: EncodeFormat, moduliCount: Int?) throws -> EvalPlaintext /// Encodes signed values into a plaintext with evaluation format. @@ -218,8 +313,11 @@ public protocol HeScheme: Sendable { /// - Throws: Error upon failure to encode. /// - seealso: ``Context/encode(signedValues:format:moduliCount:)`` for an alternative API. /// - seealso: ``HeScheme/encode(context:values:format:moduliCount:)`` to encode unsigned values. - static func encode(context: Context, signedValues: some Collection, format: EncodeFormat, - moduliCount: Int?) throws -> EvalPlaintext + static func encode( + context: Context, + signedValues: some Collection, + format: EncodeFormat, + moduliCount: Int?) throws -> EvalPlaintext /// Decodes a plaintext in ``Coeff`` format. /// - Parameters: @@ -257,6 +355,14 @@ public protocol HeScheme: Sendable { /// - seealso: ``Plaintext/decode(format:)-2agje`` for an alternative API. static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [SignedScalar] + /// Calculates the number of least significant bits (LSBs) per polynomial that can be excluded + /// from serialization of a single-modulus ciphertext, when decryption is performed immediately after + /// deserialization. + /// + /// - Parameter parameter: the concrete encryption parameter + /// - Returns: the lsbs to skip when decrypting a ciphertext + static func skipLSBsForDecryption(for parameter: EncryptionParameters) -> [Int] + /// Symmetric secret key encryption of a plaintext. /// - Parameters: /// - plaintext: Plaintext to encrypt. @@ -286,7 +392,7 @@ public protocol HeScheme: Sendable { /// ``` /// - seealso: ``HeScheme/isTransparent(ciphertext:)`` /// - seealso: ``Ciphertext/zero(context:moduliCount:)`` for an alternative API. - static func zeroCiphertextCoeff(context: Context, moduliCount: Int?) throws -> CoeffCiphertext + static func zeroCiphertextCoeff(context: Context, moduliCount: Int?) throws -> CoeffCiphertext /// Generates a ciphertext of zeros in ``Eval`` format. /// @@ -308,7 +414,7 @@ public protocol HeScheme: Sendable { /// ``` /// - seealso: ``HeScheme/isTransparent(ciphertext:)`` /// - seealso: ``Ciphertext/zero(context:moduliCount:)`` for an alternative API. - static func zeroCiphertextEval(context: Context, moduliCount: Int?) throws -> EvalCiphertext + static func zeroCiphertextEval(context: Context, moduliCount: Int?) throws -> EvalCiphertext /// Computes whether a ciphertext is transparent. /// @@ -824,32 +930,32 @@ public protocol HeScheme: Sendable { /// The async version of ``HeScheme/relinearize(_:using:)``. static func relinearizeAsync(_ ciphertext: inout CanonicalCiphertext, using key: EvaluationKey) async throws - /// Run the forward NTT algorithm on a given ciphertext in Coeff format + /// Run the forward NTT algorithm on a given ciphertext in Coeff format, the input may be consumed. /// - Parameter ciphertext: The ciphertext to run forward NTT on /// - Returns: The corresponding ciphertext in Eval format /// - Throws: Error upon failure to run forward NTT on the ciphertext. /// - seealso: ``forwardNttAsync(_:)`` for an async version of this API - static func forwardNtt(_ ciphertext: CoeffCiphertext) throws -> EvalCiphertext + static func forwardNtt(_ ciphertext: inout CoeffCiphertext) throws -> EvalCiphertext /// The async version of ``HeScheme/forwardNtt(_:)``. - static func forwardNttAsync(_ ciphertext: CoeffCiphertext) async throws -> EvalCiphertext + static func forwardNttAsync(_ ciphertext: inout CoeffCiphertext) async throws -> EvalCiphertext - /// Run the inverse NTT algorithm on a given ciphertext in Eval format + /// Run the inverse NTT algorithm on a given ciphertext in Eval format, the input may be consumed. /// - Parameter ciphertext: The ciphertext to run inverse NTT on /// - Returns: The corresponding ciphertext in Coeff format /// - Throws: Error upon failure to run inverse NTT on the ciphertext. /// - seealso: ``inverseNttAsync(_:)`` for an async version of this API - static func inverseNtt(_ ciphertext: EvalCiphertext) throws -> CoeffCiphertext + static func inverseNtt(_ ciphertext: inout EvalCiphertext) throws -> CoeffCiphertext /// The async version of ``HeScheme/inverseNtt(_:)``. - static func inverseNttAsync(_ ciphertext: EvalCiphertext) async throws -> CoeffCiphertext + static func inverseNttAsync(_ ciphertext: inout EvalCiphertext) async throws -> CoeffCiphertext /// Validates the equality of two contexts. /// - Parameters: /// - lhs: A Context to compare. /// - rhs: Another context to compare. /// - Throws: Error upon unequal contexts. - static func validateEquality(of lhs: Context, and rhs: Context) throws + static func validateEquality(of lhs: Context, and rhs: Context) throws /// Computes the noise budget of a ciphertext. /// @@ -900,6 +1006,18 @@ public protocol HeScheme: Sendable { static func multiplyInversePowerOfXAsync(_ ciphertext: inout CoeffCiphertext, power: Int) async throws } +/// Codify different HE schemes. +public enum HeCryptoSystem: String { + /// Brakerski-Fan-Vercauteren, as implemented in ``Bfv`` + case bfv + + /// NoOp encryption, ciphertexts are simply plaintexts. + case noOpScheme + + /// Other. + case unspecified +} + extension HeScheme { @inlinable // swiftlint:disable:next missing_docs attributes @@ -1187,7 +1305,7 @@ extension HeScheme { /// ``` /// - seelaso: ``Ciphertext/isTransparent()`` @inlinable - public static func zero(context: Context, + public static func zero(context: Context, moduliCount: Int? = nil) throws -> Ciphertext { if Format.self == Coeff.self { @@ -1257,7 +1375,7 @@ extension HeScheme { extension HeScheme { @inlinable // swiftlint:disable:next missing_docs attributes - public static func validateEquality(of lhs: Context, and rhs: Context) throws { + public static func validateEquality(of lhs: Context, and rhs: Context) throws { guard lhs == rhs else { throw HeError.unequalContexts(got: lhs, expected: rhs) } @@ -1357,15 +1475,13 @@ extension HeScheme { } } -// MARK: forwarding to Context - -extension Context { +extension HeContext { /// Generates a ``SecretKey``. /// - Returns: A freshly generated secret key. /// - Throws: Error upon failure to generate a secret key. /// - seealso: ``HeScheme/generateSecretKey(context:)`` for an alternative API. @inlinable - public func generateSecretKey() throws -> SecretKey { + public func generateSecretKey() throws -> SecretKey where Scheme.Context == Self { try Scheme.generateSecretKey(context: self) } @@ -1377,10 +1493,10 @@ extension Context { /// - Throws: Error upon failure to generate an evaluation key. /// - seealso: ``HeScheme/generateEvaluationKey(context:config:using:)`` for an alternative API. @inlinable - public func generateEvaluationKey( + public func generateEvaluationKey( config: EvaluationKeyConfig, using secretKey: SecretKey) throws - -> EvaluationKey + -> EvaluationKey where Scheme.Context == Self { try Scheme.generateEvaluationKey(context: self, config: config, using: secretKey) } diff --git a/Sources/HomomorphicEncryption/HeSchemeAsync.swift b/Sources/HomomorphicEncryption/HeSchemeAsync.swift index c220fbfe..8ef5ae44 100644 --- a/Sources/HomomorphicEncryption/HeSchemeAsync.swift +++ b/Sources/HomomorphicEncryption/HeSchemeAsync.swift @@ -206,13 +206,13 @@ extension HeScheme { } @inlinable - public static func forwardNttAsync(_ ciphertext: CoeffCiphertext) async throws -> EvalCiphertext { - try forwardNtt(ciphertext) + public static func forwardNttAsync(_ ciphertext: inout CoeffCiphertext) async throws -> EvalCiphertext { + try forwardNtt(&ciphertext) } @inlinable - public static func inverseNttAsync(_ ciphertext: EvalCiphertext) async throws -> CoeffCiphertext { - try inverseNtt(ciphertext) + public static func inverseNttAsync(_ ciphertext: inout EvalCiphertext) async throws -> CoeffCiphertext { + try inverseNtt(&ciphertext) } @inlinable diff --git a/Sources/HomomorphicEncryption/Keys.swift b/Sources/HomomorphicEncryption/Keys.swift index fab26f82..1b7c05cb 100644 --- a/Sources/HomomorphicEncryption/Keys.swift +++ b/Sources/HomomorphicEncryption/Keys.swift @@ -21,6 +21,18 @@ public final class SecretKey: Equatable, @unchecked Sendable { @usableFromInline var poly: PolyRq + /// public access to poly. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + public var _poly: PolyRq { poly } + + /// Create a secret key by providing its content. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + /// - Parameter _poly: the polynomial for the secret key. + @inlinable + public convenience init(_poly: consuming PolyRq) { + self.init(poly: _poly) + } + @inlinable init(poly: consuming PolyRq) { self.poly = poly @@ -49,65 +61,103 @@ extension SecretKey: PolyCollection { /// Key-switching operations include relinearization and Galois transformations. /// - seealso: ``HeScheme/relinearize(_:using:)`` and ``HeScheme/applyGalois(ciphertext:element:using:)`` for more /// details. -@usableFromInline -package struct KeySwitchKey: Equatable, Sendable { +public struct _KeySwitchKey: HeKeySwitchKey { /// The context used for key-switching operations. - @usableFromInline let context: Context + @usableFromInline let context: Scheme.Context /// The ciphertexts of the key-switching key. - @usableFromInline let ciphers: [Ciphertext] + @usableFromInline let ciphertexts: [Ciphertext] + + /// public access to context. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + public var _context: Scheme.Context { context } + /// public access to ciphertexts. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + public var _ciphertexts: [Ciphertext] { ciphertexts } + /// Create a key-switching key by providing its ontent. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + /// - Parameters: + /// - _context: the context of key switching key. + /// - _ciphertexts: the ciphertexts of key switching key. @inlinable - init(context: Context, ciphers: [Ciphertext]) { + public init(_context: Scheme.Context, _ciphertexts: [Ciphertext]) { + self.init(context: _context, ciphertexts: _ciphertexts) + } + + @inlinable + init(context: Scheme.Context, ciphertexts: [Ciphertext]) { self.context = context - self.ciphers = ciphers + self.ciphertexts = ciphertexts } } -extension KeySwitchKey: PolyCollection { +extension _KeySwitchKey: PolyCollection { public typealias Scalar = Scheme.Scalar @inlinable public func polyContext() -> PolyContext { - ciphers[0].polyContext() + ciphertexts[0].polyContext() } } -@usableFromInline -package struct RelinearizationKey: Equatable, Sendable { - @usableFromInline let keySwitchKey: KeySwitchKey +/// A cryptographic key used for relinearization operations. +public struct _RelinearizationKey: Equatable, Sendable { + @usableFromInline let keySwitchKey: Scheme.KeySwitchKey + /// public access to key-switching key. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + public var _keySwitchKey: Scheme.KeySwitchKey { keySwitchKey } + + /// Create a relinearization key by providing its content. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + /// - Parameter _keySwitchKey: the key-switching key for relinearization key. + @inlinable + public init(_keySwitchKey: Scheme.KeySwitchKey) { + self.init(keySwitchKey: _keySwitchKey) + } @inlinable - init(keySwitchKey: KeySwitchKey) { + init(keySwitchKey: Scheme.KeySwitchKey) { self.keySwitchKey = keySwitchKey } } -extension RelinearizationKey: PolyCollection { +extension _RelinearizationKey: PolyCollection { public typealias Scalar = Scheme.Scalar @inlinable public func polyContext() -> PolyContext { - keySwitchKey.ciphers[0].polyContext() + keySwitchKey.ciphertexts[0].polyContext() } } -@usableFromInline -package struct GaloisKey: Equatable, Sendable { - @usableFromInline package let keys: [Int: KeySwitchKey] +/// A cryptographic key used for ciphertext rotation operation. +public struct _GaloisKey: HeGaloisKey { + @usableFromInline let keys: [Int: Scheme.KeySwitchKey] + /// public access to key-switching keys. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + public var _keys: [Int: Scheme.KeySwitchKey] { keys } + + /// Create a Galois key by providing its content. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + /// - Parameter _keys: the key-switching keys of Galois key. + @inlinable + public init(_keys: [Int: Scheme.KeySwitchKey]) { + self.init(keys: _keys) + } @inlinable - init(keys: [Int: KeySwitchKey]) { + init(keys: [Int: Scheme.KeySwitchKey]) { self.keys = keys } } -extension GaloisKey: PolyCollection { +extension _GaloisKey: PolyCollection { public typealias Scalar = Scheme.Scalar @inlinable public func polyContext() -> PolyContext { if let firstKey = keys.values.first { - firstKey.ciphers[0].polyContext() + firstKey.ciphertexts[0].polyContext() } else { preconditionFailure("Empty Galois key") } @@ -118,8 +168,15 @@ extension GaloisKey: PolyCollection { /// /// Associated with a ``SecretKey``. public struct EvaluationKey: Equatable, Sendable { - @usableFromInline package let galoisKey: GaloisKey? - @usableFromInline package let relinearizationKey: RelinearizationKey? + @usableFromInline package let galoisKey: _GaloisKey? + @usableFromInline package let relinearizationKey: _RelinearizationKey? + + /// public access to Galois key. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + public var _galoisKey: _GaloisKey? { galoisKey } + /// public access to relineraization key. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + public var _relinearizationKey: _RelinearizationKey? { relinearizationKey } /// Returns the configuration for the evaluation key. public var config: EvaluationKeyConfig { @@ -128,11 +185,18 @@ public struct EvaluationKey: Equatable, Sendable { hasRelinearizationKey: relinearizationKey != nil) } + /// Create a evaluation key by providing its content. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + /// - Parameters: + /// - _galoisKey: the Galois key of the evaluation key. + /// - _relinearizationKey: the relinearization key of the evaluation key. + @inlinable + public init(_galoisKey: _GaloisKey?, _relinearizationKey: _RelinearizationKey?) { + self.init(galoisKey: _galoisKey, relinearizationKey: _relinearizationKey) + } + @inlinable - init( - galoisKey: GaloisKey?, - relinearizationKey: RelinearizationKey?) - { + init(galoisKey: _GaloisKey?, relinearizationKey: _RelinearizationKey?) { self.galoisKey = galoisKey self.relinearizationKey = relinearizationKey } diff --git a/Sources/HomomorphicEncryption/NoOpScheme.swift b/Sources/HomomorphicEncryption/NoOpScheme.swift index 0cb9cb73..e22dc14e 100644 --- a/Sources/HomomorphicEncryption/NoOpScheme.swift +++ b/Sources/HomomorphicEncryption/NoOpScheme.swift @@ -12,14 +12,39 @@ // See the License for the specific language governing permissions and // limitations under the License. +public struct EmptyAuxiliary: CiphertextAuxiliary, PlaintextAuxiliary { + public init( + context _: Scheme.Context, + polys _: [PolyRq], + correctionFactor _: Scheme.Scalar, + seed _: [UInt8]) throws {} + + public init( + context _: Scheme.Context, + poly _: PolyRq) {} + + public static func == (_: EmptyAuxiliary, _: EmptyAuxiliary) -> Bool { + true + } +} + /// This is a no-op scheme for development and testing. /// /// The scheme simply takes the plaintext as a "ciphertext" and /// ignores any ciphertext coefficient moduli. public enum NoOpScheme: HeScheme { + public typealias CiphertextAuxiliaryData = EmptyAuxiliary + public typealias PlaintextAuxiliaryData = EmptyAuxiliary + + public typealias Context = HomomorphicEncryption.Context + public typealias KeySwitchKey = HomomorphicEncryption._KeySwitchKey + public typealias GaloisKey = HomomorphicEncryption._GaloisKey + public typealias Scalar = UInt64 public typealias CanonicalCiphertextFormat = Coeff + public static var cryptosystem: HeCryptoSystem { .noOpScheme } + public static var freshCiphertextPolyCount: Int { 1 } @@ -28,22 +53,22 @@ public enum NoOpScheme: HeScheme { 0 } - public static func generateSecretKey(context: Context) -> SecretKey { + public static func generateSecretKey(context: Context) -> SecretKey { let poly = PolyRq.zero(context: context.secretKeyContext) return SecretKey(poly: poly) } public static func generateEvaluationKey( - context: Context, + context: Context, config: EvaluationKeyConfig, using _: SecretKey) throws -> EvaluationKey { - let keySwitchKey = KeySwitchKey(context: context, ciphers: []) - let galoisKeys = [Int: KeySwitchKey]( + let keySwitchKey = KeySwitchKey(context: context, ciphertexts: []) + let galoisKeys = [Int: KeySwitchKey]( config.galoisElements .map { g in (g, keySwitchKey) }) { first, _ in first } return EvaluationKey( galoisKey: GaloisKey(keys: galoisKeys), - relinearizationKey: RelinearizationKey(keySwitchKey: keySwitchKey)) + relinearizationKey: _RelinearizationKey(keySwitchKey: keySwitchKey)) } @inlinable @@ -56,19 +81,19 @@ public enum NoOpScheme: HeScheme { return SimdEncodingDimensions(rowCount: 2, columnCount: parameters.polyDegree / 2) } - public static func encode(context: Context, values: some Collection, + public static func encode(context: Context, values: some Collection, format: EncodeFormat) throws -> CoeffPlaintext { try context.encode(values: values, format: format) } - public static func encode(context: Context, signedValues: some Collection, + public static func encode(context: Context, signedValues: some Collection, format: EncodeFormat) throws -> CoeffPlaintext { try context.encode(signedValues: signedValues, format: format) } - public static func encode(context: Context, values: some Collection, + public static func encode(context: Context, values: some Collection, format: EncodeFormat, moduliCount _: Int?) throws -> EvalPlaintext { let coeffPlaintext = try Self.encode(context: context, values: values, format: format) @@ -76,7 +101,7 @@ public enum NoOpScheme: HeScheme { } public static func encode( - context: Context, + context: Context, signedValues: some Collection, format: EncodeFormat, moduliCount _: Int?) throws -> EvalPlaintext @@ -101,16 +126,20 @@ public enum NoOpScheme: HeScheme { try plaintext.inverseNtt().decode(format: format) } - public static func zeroCiphertextCoeff(context: Context, moduliCount _: Int?) throws -> CoeffCiphertext { - NoOpScheme + public static func skipLSBsForDecryption(for _: EncryptionParameters) -> [Int] { + Array(repeating: 0, count: freshCiphertextPolyCount) + } + + public static func zeroCiphertextCoeff(context: Context, moduliCount _: Int?) throws -> CoeffCiphertext { + try NoOpScheme .CoeffCiphertext( context: context, polys: [PolyRq.zero(context: context.plaintextContext)], correctionFactor: 1) } - public static func zeroCiphertextEval(context: Context, moduliCount _: Int?) throws -> EvalCiphertext { - NoOpScheme + public static func zeroCiphertextEval(context: Context, moduliCount _: Int?) throws -> EvalCiphertext { + try NoOpScheme .EvalCiphertext( context: context, polys: [PolyRq.zero(context: context.plaintextContext)], @@ -130,7 +159,7 @@ public enum NoOpScheme: HeScheme { public static func encrypt(_ plaintext: CoeffPlaintext, using _: SecretKey) throws -> CanonicalCiphertext { - NoOpScheme.CanonicalCiphertext( + try NoOpScheme.CanonicalCiphertext( context: plaintext.context, polys: [plaintext.poly], correctionFactor: 1) } @@ -138,7 +167,7 @@ public enum NoOpScheme: HeScheme { public static func decryptCoeff(_ ciphertext: CoeffCiphertext, using _: SecretKey) throws -> CoeffPlaintext { - NoOpScheme.CoeffPlaintext( + try NoOpScheme.CoeffPlaintext( context: ciphertext.context, poly: ciphertext.polys[0]) } @@ -303,19 +332,31 @@ public enum NoOpScheme: HeScheme { minNoiseBudget } - public static func forwardNtt(_ ciphertext: CoeffCiphertext) throws -> EvalCiphertext { + public static func forwardNtt(_ ciphertext: inout CoeffCiphertext) throws -> EvalCiphertext { let polys = try ciphertext.polys.map { try $0.forwardNtt() } - return Ciphertext(context: ciphertext.context, - polys: polys, - correctionFactor: ciphertext.correctionFactor, - seed: ciphertext.seed) + return try Ciphertext(context: ciphertext.context, + polys: polys, + correctionFactor: ciphertext.correctionFactor, + seed: ciphertext.seed) } - public static func inverseNtt(_ ciphertext: EvalCiphertext) throws -> CoeffCiphertext { + public static func inverseNtt(_ ciphertext: inout EvalCiphertext) throws -> CoeffCiphertext { let polys = try ciphertext.polys.map { try $0.inverseNtt() } - return Ciphertext(context: ciphertext.context, - polys: polys, - correctionFactor: ciphertext.correctionFactor, - seed: ciphertext.seed) + return try Ciphertext(context: ciphertext.context, + polys: polys, + correctionFactor: ciphertext.correctionFactor, + seed: ciphertext.seed) + } + + /// Returns the dimension counts for ``EncodeFormat/simd`` encoding, or `nil` if the HE scheme does + /// not support SIMD encoding for the given parameters. + @inlinable + public static func simdDimensions(for encryptionParameter: EncryptionParameters) + -> SimdEncodingDimensions? + { + guard encryptionParameter.supportsSimdEncoding else { + return nil + } + return SimdEncodingDimensions(rowCount: 2, columnCount: encryptionParameter.polyDegree / 2) } } diff --git a/Sources/HomomorphicEncryption/Plaintext.swift b/Sources/HomomorphicEncryption/Plaintext.swift index 9f324085..7ae20ca9 100644 --- a/Sources/HomomorphicEncryption/Plaintext.swift +++ b/Sources/HomomorphicEncryption/Plaintext.swift @@ -1,4 +1,4 @@ -// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,14 +19,27 @@ public struct Plaintext: Equatable, Sendab public typealias SignedScalar = Scheme.SignedScalar /// Context for HE computation. - public let context: Context + public let context: Scheme.Context @usableFromInline var poly: PolyRq + /// Public access to poly. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + public var _poly: PolyRq { poly } + + /// The auxiliary data is scheme-specific. + public var auxiliaryData: Scheme.PlaintextAuxiliaryData + + @inlinable + public init(_context: Scheme.Context, _poly: PolyRq) throws { + try self.init(context: _context, poly: _poly) + } + @inlinable - package init(context: Context, poly: PolyRq) { + package init(context: Scheme.Context, poly: PolyRq) throws { self.context = context self.poly = poly + self.auxiliaryData = try Scheme.PlaintextAuxiliaryData(context: context, poly: poly) } /// In-place plaintext addition: `lhs += rhs`. @@ -62,7 +75,7 @@ public struct Plaintext: Equatable, Sendab @inlinable public func forwardNtt() throws -> Plaintext where Format == Coeff { let poly = try poly.forwardNtt() - return Plaintext(context: context, poly: poly) + return try Plaintext(context: context, poly: poly) } /// Computes the inverse number-theoretic transform (NTT) on the plaintext. @@ -72,7 +85,7 @@ public struct Plaintext: Equatable, Sendab @inlinable public func inverseNtt() throws -> Plaintext where Format == Eval { let poly = try poly.inverseNtt() - return Plaintext(context: context, poly: poly) + return try Plaintext(context: context, poly: poly) } @inlinable @@ -136,7 +149,7 @@ extension Plaintext { return plaintext } let moduliCount = moduliCount ?? context.ciphertextContext.moduli.count - let rnsTool = context.getRnsTool(moduliCount: moduliCount) + let rnsTool = try context.getRnsTool(moduliCount: moduliCount) let polyContext = try context.ciphertextContext.getContext(moduliCount: moduliCount) var poly: PolyRq = PolyRq.zero(context: polyContext) @@ -162,7 +175,7 @@ extension Plaintext { if let plaintext = self as? Plaintext { return plaintext } - let rnsTool = context.getRnsTool(moduliCount: moduli.count) + let rnsTool = try context.getRnsTool(moduliCount: moduli.count) var plaintextData = try poly.convertToCoeffFormat().data for index in plaintextData.rowIndices(row: 0) { let condition = plaintextData[index].constantTimeGreaterThanOrEqual(rnsTool.tThreshold) @@ -175,7 +188,7 @@ extension Plaintext { let coeffPoly: PolyRq = PolyRq( context: context.plaintextContext, data: Array2d(array: plaintextData)) - return Plaintext(context: context, poly: coeffPoly) + return try Plaintext(context: context, poly: coeffPoly) } /// Decodes a plaintext. diff --git a/Sources/HomomorphicEncryption/PolyRq/Galois.swift b/Sources/HomomorphicEncryption/PolyRq/Galois.swift index 0d5a6dbf..cc0d2085 100644 --- a/Sources/HomomorphicEncryption/PolyRq/Galois.swift +++ b/Sources/HomomorphicEncryption/PolyRq/Galois.swift @@ -1,4 +1,4 @@ -// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -99,8 +99,8 @@ struct GaloisEvalIterator: IteratorProtocol { } extension FixedWidthInteger { - @inlinable - func isValidGaloisElement(for degree: Int) -> Bool { + /// Check if `self` is a valid Galois element for a group of `degree`. + public func isValidGaloisElement(for degree: Int) -> Bool { degree.isPowerOfTwo && !isMultiple(of: 2) && (self < (degree &<< 1)) && (self > 1) } } @@ -167,8 +167,8 @@ extension PolyRq where F == Eval { } } -@usableFromInline -enum GaloisElementGenerator { +/// The generator for Galois group. +public enum GaloisElementGenerator { @usableFromInline static let value: UInt32 = 3 } @@ -238,7 +238,7 @@ public enum GaloisElement { /// ``EncryptionParameters/polyDegree``. /// - Returns: Dictionary mapping Galois elements to their corresponding rotation steps. @inlinable - package static func stepsFor(elements: [Int], degree: Int) -> [Int: Int?] { + public static func stepsFor(elements: [Int], degree: Int) -> [Int: Int?] { var result: [Int: Int?] = Dictionary(elements.map { ($0, nil) }) { first, _ in first } @@ -271,7 +271,7 @@ public enum GaloisElement { /// - Returns: Dictionary mapping rotation steps to their counts, and `nil` if no plan was found. /// - Throws: Error upon invalid step or degree. @inlinable - package static func planMultiStep(supportedSteps: [Int], step: Int, degree: Int) throws -> [Int: Int]? { + public static func _planMultiStep(supportedSteps: [Int], step: Int, degree: Int) throws -> [Int: Int]? { guard abs(step) < degree else { throw HeError.invalidRotationStep(step: step, degree: degree) } diff --git a/Sources/HomomorphicEncryption/PolyRq/PolyContext.swift b/Sources/HomomorphicEncryption/PolyRq/PolyContext.swift index 33b8cd60..1b78344e 100644 --- a/Sources/HomomorphicEncryption/PolyRq/PolyContext.swift +++ b/Sources/HomomorphicEncryption/PolyRq/PolyContext.swift @@ -24,15 +24,15 @@ public final class PolyContext: Sendable { /// The modulus `Q = product_{i=0}^{L-1} q_i`, if representable by a `Width32` @usableFromInline let modulus: Width32? /// Next context, typically formed by dropping `q_{L-1}`. - @usableFromInline package let next: PolyContext? + public let next: PolyContext? /// Operations mod `q_0` up to `q_{L-1}`. - @usableFromInline let reduceModuli: [Modulus] + public let reduceModuli: [Modulus] /// Operations mod `UInt64(q_0), ..., UInt64(q_{L-1})`. - @usableFromInline let reduceModuliUInt64: [Modulus] + public let reduceModuliUInt64: [Modulus] /// Multiply by `q_{L-1}^{-1} mod q_i`, `mod q_i`. - @usableFromInline let inverseQLast: [MultiplyConstantModulus] + public let inverseQLast: [MultiplyConstantModulus] /// Precomputation for the NTT, for modulus `q_{L-1}`. - @usableFromInline let nttContext: NttContext? + public let nttContext: _NttContext? /// Initializes a ``PolyContext``. /// - Parameters: @@ -42,7 +42,7 @@ public final class PolyContext: Sendable { /// - nttContext: The NTT context for the last modulus `q_{L-1}`. /// - Throws: Error upon failure to initialize the context. @inlinable - required init(degree: Int, moduli: [T], next: PolyContext?, nttContext: NttContext? = nil) throws { + required init(degree: Int, moduli: [T], next: PolyContext?, nttContext: _NttContext? = nil) throws { guard degree.isPowerOfTwo else { throw HeError.invalidDegree(degree) } @@ -110,12 +110,12 @@ public final class PolyContext: Sendable { return MultiplyConstantModulus(multiplicand: inverse, modulus: modulus, variableTime: true) } if let nttContext { - precondition(nttContext.degree == degree, "Wrong degree in NttContext") - precondition(nttContext.modulus == qLast, "Wrong modulus in NttContext") + precondition(nttContext.degree == degree, "Wrong degree in _NttContext") + precondition(nttContext.modulus == qLast, "Wrong modulus in _NttContext") self.nttContext = nttContext } else { if !qLast.isPowerOfTwo, qLast.isNttModulus(for: degree) { - self.nttContext = try NttContext(degree: degree, modulus: qLast) + self.nttContext = try _NttContext(degree: degree, modulus: qLast) } else { self.nttContext = nil } @@ -148,7 +148,7 @@ public final class PolyContext: Sendable { /// - nttContexts: Maps moduli to their corresponding NTT context. /// - Throws: Error upon failure to initialize the context. @inlinable - convenience init(degree: Int, moduli: [T], child: PolyContext?, nttContexts: [T: NttContext] = [:]) throws { + convenience init(degree: Int, moduli: [T], child: PolyContext?, nttContexts: [T: _NttContext] = [:]) throws { guard let qLast = moduli.last else { throw HeError.emptyModulus } @@ -219,6 +219,12 @@ public final class PolyContext: Sendable { self == context || isParent(of: context) } + /// Returns the context with the right amount of moduli. + /// - Warning: This API is not subject to semantic versioning: these APIs may change without warning. + public func _getContext(moduliCount: Int) throws -> PolyContext { + try getContext(moduliCount: moduliCount) + } + @inlinable func getContext(moduliCount: Int) throws -> PolyContext { precondition(moduliCount > 0 && moduliCount <= moduli.count, "Invalid number of moduli") diff --git a/Sources/HomomorphicEncryption/PolyRq/PolyRq+Ntt.swift b/Sources/HomomorphicEncryption/PolyRq/PolyRq+Ntt.swift index de7da726..b664082d 100644 --- a/Sources/HomomorphicEncryption/PolyRq/PolyRq+Ntt.swift +++ b/Sources/HomomorphicEncryption/PolyRq/PolyRq+Ntt.swift @@ -105,16 +105,14 @@ extension ScalarType { } } -@usableFromInline -struct NttContext: Sendable { - @usableFromInline let rootOfUnityPowers: MultiplyConstantArrayModulus - @usableFromInline let inverseRootOfUnityPowers: MultiplyConstantArrayModulus - /// `degree^{-1} mod modulus`. - @usableFromInline let inverseDegree: MultiplyConstantModulus - /// `(degree)^{-1} * w^{-N} mod modulus` for `w` a root of unity mod modulus. - @usableFromInline let inverseDegreeRootOfUnity: MultiplyConstantModulus - @usableFromInline let degree: Int - @usableFromInline let modulus: T +public struct _NttContext: Sendable { + public let rootOfUnityPowers: MultiplyConstantArrayModulus + public let inverseRootOfUnityPowers: MultiplyConstantArrayModulus + public let inverseDegree: MultiplyConstantModulus // degree^{-1} mod modulus + // (degree)^{-1} * w^{-N} mod modulus for `w` a root of unity mod modulus + public let inverseDegreeRootOfUnity: MultiplyConstantModulus + public let degree: Int + public let modulus: T @inlinable init(degree: Int, modulus: T) throws { @@ -137,7 +135,6 @@ struct NttContext: Sendable { inverseRootOfUnityPowers[previousIdx]) previousIdx = reverseIdx } - self.degree = degree self.modulus = modulus self.rootOfUnityPowers = MultiplyConstantArrayModulus( @@ -203,46 +200,42 @@ func forwardButterfly( return (xOut, yOut) } -extension PolyRq where F == Coeff { - /// Performs the forward number-theoretic transform (NTT). +extension PolyContext { + /// Performs the forward number-theoretic transform (NTT) in this context. + /// - Parameter poly: the polynomial to run forward NTT on. /// - Returns: The ``Eval`` representation of the polynomial. /// - Throws: Error upon failure to compute the forward NTT. @inlinable - public consuming func forwardNtt() throws -> PolyRq { - try context.validateNttModuli() - var currentContext: PolyContext? = context + func forwardNtt(poly: consuming PolyRq) throws -> PolyRq { + assert(self === poly.context || self == poly.context) + try validateNttModuli() + var currentContext: PolyContext? = self while let context = currentContext, let modulus = context.moduli.last { - let rowOffset = data.index(row: context.moduli.count - 1, column: 0) - try data.data.withUnsafeMutableBufferPointer { dataPtr in + let rowOffset = poly.data.index(row: context.moduli.count - 1, column: 0) + try poly.data.data.withUnsafeMutableBufferPointer { dataPtr in // swiftlint:disable:next force_unwrapping try context.forwardNtt(dataPtr: dataPtr.baseAddress! + rowOffset, modulus: modulus) } currentContext = context.next } - return PolyRq(context: context, data: data) + return PolyRq(context: self, data: poly.data) } } -extension PolyContext { - /// Performs the forward number-theoretic transform (NTT) on a single modulus. - /// - Parameters: - /// - dataPtr: Pointer to the coefficients mod `modulus`. - /// - modulus: Modulus. +extension PolyRq where F == Coeff { + /// Performs the forward number-theoretic transform (NTT). + /// - Returns: The ``Eval`` representation of the polynomial. /// - Throws: Error upon failure to compute the forward NTT. @inlinable - func forwardNtt(dataPtr: UnsafeMutablePointer, modulus: T) throws { - // We modify Harvey's approach with delayed modular reduction. - var context = self - while modulus != context.moduli.last, let nextContext = context.next { - context = nextContext - } - guard modulus == context.moduli.last else { - throw HeError.invalidPolyContext(context) - } - guard let nttContext = context.nttContext, let modulusReduceFactor = context.reduceModuli.last - else { - throw HeError.invalidPolyContext(context) - } + public consuming func forwardNtt() throws -> PolyRq { + try context.forwardNtt(poly: self) + } +} + +extension _NttContext { + @inlinable + func forwardNtt(dataPtr: UnsafeMutablePointer, modulus: T, modulusReduceFactor: Modulus, degree: Int) throws { + let nttContext = self let n = degree let twiceModulus = modulus << 1 @@ -326,6 +319,34 @@ extension PolyContext { } } +extension PolyContext { + /// Performs the forward number-theoretic transform (NTT) on a single modulus. + /// - Parameters: + /// - dataPtr: Pointer to the coefficients mod `modulus`. + /// - modulus: Modulus. + /// - Throws: Error upon failure to compute the forward NTT. + @inlinable + func forwardNtt(dataPtr: UnsafeMutablePointer, modulus: T) throws { + // We modify Harvey's approach with delayed modular reduction. + var context = self + while modulus != context.moduli.last, let nextContext = context.next { + context = nextContext + } + guard modulus == context.moduli.last else { + throw HeError.invalidPolyContext(context) + } + guard let nttContext = context.nttContext, let modulusReduceFactor = context.reduceModuli.last + else { + throw HeError.invalidPolyContext(context) + } + try nttContext.forwardNtt( + dataPtr: dataPtr, + modulus: modulus, + modulusReduceFactor: modulusReduceFactor, + degree: degree) + } +} + /// Computes a lazy inverse NTT butterfly. /// - Parameters: /// - x: In `[0, kModulus)`. @@ -353,142 +374,169 @@ func inverseButterfly( return (x, y) } -extension PolyRq where F == Eval { - /// Performs the inverse number-theoretic transform (NTT). - /// - Returns: The ``Coeff`` representation of the polynomial. - /// - Throws: Error upon failure to compute the inverse NTT. - @inlinable - public consuming func inverseNtt() throws -> PolyRq { - try context.validateNttModuli() - var currentContext: PolyContext? = context - while let context = currentContext { - try inverseNtt(using: context) - currentContext = context.next - } - return PolyRq(context: context, data: data) - } - - /// Computes the inverse number-theoretic transform (NTT) on the last modulus in the context. - /// - Parameter context: Context whose last modulus to use for the NTT. - /// - Throws: Error upon failure to compute the inverse NTT. +extension _NttContext { @inlinable - mutating func inverseNtt(using context: PolyContext) throws { - // We modify Harvey's approach with delayed modular reduction. - let moduli = context.moduli - guard let modulus = moduli.last else { - throw HeError.emptyModulus - } - let rnsIndex = moduli.count &- 1 - let n = degree - - let rowOffset = data.rowIndices(row: rnsIndex).first - guard let rowOffset, let nttContext = context.nttContext - else { - throw HeError.invalidPolyContext(context) - } - let inverseRootOfUnityPowers = nttContext.inverseRootOfUnityPowers - let inverseDegree = nttContext.inverseDegree - let inverseDegreeRootOfUnity = nttContext.inverseDegreeRootOfUnity - + func inverseNtt( + dataPtr: UnsafeMutablePointer, + modulus: T, + reduceModulus: Modulus, + degree: Int, + rowOffset: Int) + { let modulusMultiplesCount = min(degree.log2 &+ 1, modulus.leadingZeroBitCount) - let reduceModulus = context.reduceModuli[rnsIndex] - var rootIdx = 1 var lazyReductionCounter = -1 - let nDiv2 = n &>> 1 - // swiftlint:disable:next closure_body_length - data.data.withUnsafeMutableBufferPointer { dataPtr in - // swiftlint:disable:next force_unwrapping - let dataPtr = dataPtr.baseAddress! + rowOffset + let nDiv2 = degree &>> 1 - for log2m in (0..> (log2m &+ 1) - lazyReductionCounter &+= 1 - let timeToReduce = lazyReductionCounter == modulusMultiplesCount - if timeToReduce { - if m == 1 { - lazyReductionCounter &-= 1 - } else { - lazyReductionCounter = 0 - } - } - let kTimesModulus = modulus &<< lazyReductionCounter + let dataPtr = dataPtr + rowOffset + for log2m in (0..> (log2m &+ 1) + lazyReductionCounter &+= 1 + let timeToReduce = lazyReductionCounter == modulusMultiplesCount + if timeToReduce { if m == 1 { - // Final stage, folding in multiplication by n^{-1} and modular reduction - func applyOp(_ op: (_ x: inout T, _ y: inout T) -> Void) { - for xIdx in 0.. Void) { + for xIdx in 0.. Void) { + for i in 0.. Void) { - for i in 0.. Void) { + for i in 0.. Void) { - for i in 0..) throws { + // We modify Harvey's approach with delayed modular reduction. + guard let modulus = moduli.last else { + throw HeError.emptyModulus + } + let rnsIndex = moduli.count &- 1 + let n = data.columnCount + + let rowOffset = data.rowIndices(row: rnsIndex).first + guard let rowOffset, let nttContext + else { + throw HeError.invalidPolyContext(self) + } + + let reduceModulus = reduceModuli[rnsIndex] + + data.data.withUnsafeMutableBufferPointer { dataPtr in + // swiftlint:disable:next force_unwrapping + let dataPtr = dataPtr.baseAddress! + nttContext.inverseNtt( + dataPtr: dataPtr, + modulus: modulus, + reduceModulus: reduceModulus, + degree: n, + rowOffset: rowOffset) + } + } + + /// Performs the inverse number-theoretic transform (NTT) in this context. + /// - Parameter poly: the polynomial to run inverse NTT on. + /// - Returns: The ``Coeff`` representation of the polynomial. + /// - Throws: Error upon failure to compute the inverse NTT. + @inlinable + func inverseNtt(poly: consuming PolyRq) throws -> PolyRq { + assert(self === poly.context || self == poly.context) + try validateNttModuli() + var currentContext: PolyContext? = self + while let context = currentContext { + try context.inverseNtt(data: &poly.data) + currentContext = context.next + } + return PolyRq(context: self, data: poly.data) + } +} + +extension PolyRq where F == Eval { + /// Performs the inverse number-theoretic transform (NTT). + /// - Returns: The ``Coeff`` representation of the polynomial. + /// - Throws: Error upon failure to compute the inverse NTT. + @inlinable + public consuming func inverseNtt() throws -> PolyRq { + try context.inverseNtt(poly: self) + } +} diff --git a/Sources/HomomorphicEncryption/RnsBaseConverter.swift b/Sources/HomomorphicEncryption/RnsBaseConverter.swift index 74b032b9..bc3d9ed9 100644 --- a/Sources/HomomorphicEncryption/RnsBaseConverter.swift +++ b/Sources/HomomorphicEncryption/RnsBaseConverter.swift @@ -1,4 +1,4 @@ -// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,17 +16,16 @@ import ModularArithmetic /// Enables base conversion from an input RNS basis `q = q_0, ..., q_{L-1}` to an /// output RNS basis `t = t_0, ..., t_{M-1}`. -@usableFromInline -struct RnsBaseConverter: Sendable { +public struct _RnsBaseConverter: Sendable { /// `q_0, ..., q_{L-1}`. - @usableFromInline let inputContext: PolyContext + public let inputContext: PolyContext /// `t_0, ..., t_{M-1}``. - @usableFromInline let outputContext: PolyContext + public let outputContext: PolyContext /// (i, j)'th entry stores `(q / q_i) % t_j`. - @usableFromInline let puncturedProducts: Array2d + public let puncturedProducts: Array2d /// Composes polynomials with `inputContext`. - @usableFromInline let crtComposer: CrtComposer + public let crtComposer: _CrtComposer /// i'th entry stores `(q_i / q) % q_i`. @usableFromInline var inversePuncturedProducts: [MultiplyConstantModulus] { diff --git a/Sources/HomomorphicEncryption/RnsTool.swift b/Sources/HomomorphicEncryption/RnsTool.swift index 238e339f..e0cfd15f 100644 --- a/Sources/HomomorphicEncryption/RnsTool.swift +++ b/Sources/HomomorphicEncryption/RnsTool.swift @@ -14,14 +14,14 @@ import ModularArithmetic -@usableFromInline -package struct RnsTool: Sendable { +/// Helpers for operation under RNS. +public struct _RnsTool: Sendable { @usableFromInline struct RnsToolContext { @usableFromInline let inputContext: PolyContext @usableFromInline let bSkMTildeContext: PolyContext @usableFromInline let mSkContext: PolyContext - @usableFromInline let nttContexts: [T: NttContext] + @usableFromInline let nttContexts: [T: _NttContext] @usableFromInline let tGammaContext: PolyContext @inlinable @@ -42,7 +42,7 @@ package struct RnsTool: Sendable { throw HeError.emptyModulus } - var nttContexts = [T: NttContext]() + var nttContexts = [T: _NttContext]() for moduliCount in 1..: Sendable { moduli: [outputContext.moduli[0], T.rnsCorrectionFactor], child: outputContext) } - func getbSkMTildeNttContext(moduliCount: Int) throws -> NttContext? { + func getbSkMTilde_NttContext(moduliCount: Int) throws -> _NttContext? { try bSkMTildeContext.getContext(moduliCount: moduliCount).nttContext } } /// `Q = q_0, ..., q_{L-1}`. - @usableFromInline let inputContext: PolyContext + public let inputContext: PolyContext /// `t_0, ..., t_{M-1}`. - @usableFromInline let outputContext: PolyContext + public let outputContext: PolyContext /// `[q, B_sk]`. - @usableFromInline let qBskContext: PolyContext + public let qBskContext: PolyContext /// `Q mod t_0`. - @usableFromInline let qModT: T + public let qModT: T /// reduction by `t_0`. - @usableFromInline let t: Modulus + public let t: Modulus /// Multiplication by `gamma^{-1}` mod `t`, mod `t`. - @usableFromInline let inverseGammaModT: MultiplyConstantModulus + public let inverseGammaModT: MultiplyConstantModulus /// Multiplication by `-(Q^{-1})` mod `m_tilde`, mod `t`. - @usableFromInline let negInverseQModMTilde: MultiplyConstantModulus + public let negInverseQModMTilde: MultiplyConstantModulus /// Multiplication by `B^{-1} mod m_sk`, mod `m_sk`. - @usableFromInline let inverseBModMSk: MultiplyConstantModulus + public let inverseBModMSk: MultiplyConstantModulus /// i'th entry stores `q_i - t_0`. - @usableFromInline let tIncrement: [T] + public let tIncrement: [T] /// i'th entry stores `\tilde{m} mod qi`. - @usableFromInline let mTildeModQ: [T] + public let mTildeModQ: [T] /// `-(Q^{-1}) mod {t, gamma}`. - @usableFromInline let negInverseQModTGamma: [T] + public let negInverseQModTGamma: [T] /// `|gamma * t|_qi`. - @usableFromInline let prodGammaTModQ: [T] + public let prodGammaTModQ: [T] /// Multiplication by `m_tilde^{-1} mod B_sk`, mod `B_sk`. - @usableFromInline let inverseMTildeModBSk: [MultiplyConstantModulus] + public let inverseMTildeModBSk: [MultiplyConstantModulus] /// Multiplication by `Q^{-1} mod B_sk`, mod `B_sk`. - @usableFromInline let inverseQModBSk: [MultiplyConstantModulus] + public let inverseQModBSk: [MultiplyConstantModulus] /// i'th entry stores modulus for multiplication by `floor(Q / t_0) % q_i`, mod `q_i` /// Also called `delta` in the literature. - @usableFromInline let qDivT: [MultiplyConstantModulus] + public let qDivT: [MultiplyConstantModulus] /// Multiplication by `Q mod B_sk`, mod `B_sk`. - @usableFromInline let qModBSk: [MultiplyConstantModulus] + public let qModBSk: [MultiplyConstantModulus] /// Multiplication by `-B mod q_i`, mod `q_i`. - @usableFromInline let negBModQ: [MultiplyConstantModulus] + public let negBModQ: [MultiplyConstantModulus] /// Multiplication by `B mod q_i`, mod `q_i`. - @usableFromInline let bModQ: [MultiplyConstantModulus] + public let bModQ: [MultiplyConstantModulus] /// Base conversion from `Q` to `B_sk`. - @usableFromInline let rnsConvertQToBSk: RnsBaseConverter + public let rnsConvertQToBSk: _RnsBaseConverter /// Base conversion from `B` to `M_sk`. - @usableFromInline let rnsConvertBtoMSk: RnsBaseConverter + public let rnsConvertBtoMSk: _RnsBaseConverter /// Base conversion from `B` to `Q`. - @usableFromInline let rnsConvertBtoQ: RnsBaseConverter + public let rnsConvertBtoQ: _RnsBaseConverter /// Base conversion matrix from `Q` to `[B_sk, m_tilde]`, where /// `B` is an auxiliary base, `m_sk` is an extra modulus, and the /// `B_sk = [B, m_sk]` is an extended base. - @usableFromInline let rnsConvertQToBSkMTilde: RnsBaseConverter + public let rnsConvertQToBSkMTilde: _RnsBaseConverter /// Base conversion from `Q` to `[t, gamma]`. - @usableFromInline let rnsConvertQToTGamma: RnsBaseConverter + public let rnsConvertQToTGamma: _RnsBaseConverter @inlinable var tThreshold: T { (outputContext.moduli[0] + 1) / 2 @@ -150,7 +150,7 @@ package struct RnsTool: Sendable { variableTime: true) let tGammaContext = rnsToolContext.tGammaContext - self.rnsConvertQToTGamma = try RnsBaseConverter(from: inputContext, to: tGammaContext) + self.rnsConvertQToTGamma = try _RnsBaseConverter(from: inputContext, to: tGammaContext) self.negInverseQModTGamma = try tGammaContext.reduceModuli.map { modulus in let qMod = inputContext.qRemainder(dividingBy: modulus) return try qMod.inverseMod(modulus: modulus.modulus, variableTime: true).negateMod(modulus: modulus.modulus) @@ -244,17 +244,17 @@ package struct RnsTool: Sendable { let multiplicand = try bModMSk.inverseMod(modulus: mSk, variableTime: true) return MultiplyConstantModulus(multiplicand: multiplicand, modulus: mSk, variableTime: true) }() - self.rnsConvertQToBSk = try RnsBaseConverter(from: inputContext, to: bSkContext) - self.rnsConvertQToBSkMTilde = try RnsBaseConverter(from: inputContext, to: bSkMTildeContext) - self.rnsConvertBtoMSk = try RnsBaseConverter(from: bContext, to: mSkContext) - self.rnsConvertBtoQ = try RnsBaseConverter(from: bContext, to: inputContext) + self.rnsConvertQToBSk = try _RnsBaseConverter(from: inputContext, to: bSkContext) + self.rnsConvertQToBSkMTilde = try _RnsBaseConverter(from: inputContext, to: bSkMTildeContext) + self.rnsConvertBtoMSk = try _RnsBaseConverter(from: bContext, to: mSkContext) + self.rnsConvertBtoQ = try _RnsBaseConverter(from: bContext, to: inputContext) } @inlinable init(from inputContext: PolyContext, to outputContext: PolyContext) throws { - let rnsToolContext = try RnsTool.RnsToolContext(inputContext: inputContext, outputContext: outputContext) + let rnsToolContext = try _RnsTool.RnsToolContext(inputContext: inputContext, outputContext: outputContext) try self.init(from: inputContext, to: outputContext, rnsToolContext: rnsToolContext) } diff --git a/Sources/HomomorphicEncryption/SerializedCiphertext.swift b/Sources/HomomorphicEncryption/SerializedCiphertext.swift index b54f2ccc..574e84fb 100644 --- a/Sources/HomomorphicEncryption/SerializedCiphertext.swift +++ b/Sources/HomomorphicEncryption/SerializedCiphertext.swift @@ -43,7 +43,7 @@ extension Ciphertext { @inlinable public init( deserialize serialized: SerializedCiphertext, - context: Context, + context: Scheme.Context, moduliCount: Int? = nil) throws { self.context = context @@ -51,9 +51,9 @@ extension Ciphertext { let polyContext = try context.secretKeyContext.getContext(moduliCount: moduliCount) switch serialized { case let .seeded(poly0: poly0, seed: seed): - let poly = try PolyRq<_, Format>(deserialize: poly0, context: polyContext) + let poly = try PolyRq(deserialize: poly0, context: polyContext) var rng = try NistAes128Ctr(seed: seed) - let a = PolyRq<_, Eval>.random(context: polyContext, using: &rng) + let a = PolyRq.random(context: polyContext, using: &rng) let poly1: PolyRq = try a.convertFormat() self.polys = [poly, poly1] self.correctionFactor = 1 @@ -62,6 +62,11 @@ extension Ciphertext { self.polys = try Serialize.deserializePolys(from: polys, context: polyContext, skipLSBs: skipLSBs) self.correctionFactor = correctionFactor } + self.auxiliaryData = try Scheme.CiphertextAuxiliaryData( + context: context, + polys: polys, + correctionFactor: correctionFactor, + seed: seed) } /// Serializes a ciphertext, retaining decryption correctness only at the given indices. diff --git a/Sources/HomomorphicEncryption/SerializedKeys.swift b/Sources/HomomorphicEncryption/SerializedKeys.swift index a3bda53a..24eae15b 100644 --- a/Sources/HomomorphicEncryption/SerializedKeys.swift +++ b/Sources/HomomorphicEncryption/SerializedKeys.swift @@ -30,7 +30,7 @@ extension SecretKey { /// - serialized: Serialized secret key. /// - context: Context to associate with the secret key. /// - Throws: ``HeError`` upon failure to deserialize. - public convenience init(deserialize serialized: SerializedSecretKey, context: Context) throws { + public convenience init(deserialize serialized: SerializedSecretKey, context: Scheme.Context) throws { let polys: [PolyRq] = try Serialize.deserializePolys( from: serialized.polys, context: context.secretKeyContext) @@ -50,20 +50,20 @@ extension SecretKey { } } -extension KeySwitchKey { +extension HeKeySwitchKey { @inlinable - init(deserialize ciphertexts: [SerializedCiphertext], context: Context) throws { - self.context = context - self.ciphers = try ciphertexts.map { serializedCiphertext in - try Ciphertext( + init(deserialize ciphertexts: [SerializedCiphertext], context: Scheme.Context) throws { + let ciphertexts = try ciphertexts.map { serializedCiphertext in + try Ciphertext( deserialize: serializedCiphertext, context: context, moduliCount: context.secretKeyContext.moduli.count) } + try self.init(_context: context, _ciphertexts: ciphertexts) } - func serialize() -> [SerializedCiphertext] { - ciphers.map { $0.serialize() } + func serialize() -> [SerializedCiphertext] { + ciphertexts.map { $0.serialize() } } } @@ -79,15 +79,16 @@ public struct SerializedGaloisKey: Hashable, Codable, Sendab } } -extension GaloisKey { +extension HeGaloisKey { @inlinable - init(deserialize serialized: SerializedGaloisKey, context: Context) throws { - self.keys = try serialized.galoisKey.mapValues { serializedKeySwitchKey in - try KeySwitchKey(deserialize: serializedKeySwitchKey, context: context) + init(deserialize serialized: SerializedGaloisKey, context: Scheme.Context) throws { + let keys = try serialized.galoisKey.mapValues { serializedKeySwitchKey in + try Scheme.KeySwitchKey(deserialize: serializedKeySwitchKey, context: context) } + try self.init(_keys: keys) } - func serialize() -> SerializedGaloisKey { + func serialize() -> SerializedGaloisKey { SerializedGaloisKey(galoisKey: keys.mapValues { $0.serialize() }) } } @@ -104,10 +105,10 @@ public struct SerializedRelinearizationKey: Hashable, Codabl } } -extension RelinearizationKey { +extension _RelinearizationKey { @inlinable - init(deserialize serialized: SerializedRelinearizationKey, context: Context) throws { - self.keySwitchKey = try KeySwitchKey(deserialize: serialized.relinKey, context: context) + init(deserialize serialized: SerializedRelinearizationKey, context: Scheme.Context) throws { + self.keySwitchKey = try Scheme.KeySwitchKey(deserialize: serialized.relinKey, context: context) } func serialize() -> SerializedRelinearizationKey { @@ -139,12 +140,14 @@ extension EvaluationKey { /// - context: Context to associate with the evaluation key. /// - Throws: ``HeError`` upon failure to deserialize. @inlinable - public init(deserialize serialized: SerializedEvaluationKey, context: Context) throws { + public init(deserialize serialized: SerializedEvaluationKey, + context: Scheme.Context) throws + { self.galoisKey = try serialized.galoisKey.map { serialized in - try GaloisKey(deserialize: serialized, context: context) + try _GaloisKey(deserialize: serialized, context: context) } self.relinearizationKey = try serialized.relinearizationKey.map { serialized in - try RelinearizationKey(deserialize: serialized, context: context) + try _RelinearizationKey(deserialize: serialized, context: context) } } diff --git a/Sources/HomomorphicEncryption/SerializedPlaintext.swift b/Sources/HomomorphicEncryption/SerializedPlaintext.swift index a3f86eba..b88a0a12 100644 --- a/Sources/HomomorphicEncryption/SerializedPlaintext.swift +++ b/Sources/HomomorphicEncryption/SerializedPlaintext.swift @@ -1,4 +1,4 @@ -// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -38,9 +38,10 @@ extension Plaintext where Format == Coeff { /// - context: Context to associate with the plaintext. /// - Throws: Error upon failure to deserialize. @inlinable - public init(deserialize serialized: SerializedPlaintext, context: Context) throws { + public init(deserialize serialized: SerializedPlaintext, context: Scheme.Context) throws { self.context = context self.poly = try PolyRq(deserialize: serialized.poly, context: context.plaintextContext) + self.auxiliaryData = try Scheme.PlaintextAuxiliaryData(context: context, poly: poly) } } @@ -53,10 +54,15 @@ extension Plaintext where Format == Eval { /// the top-level ciphertext context with all the moduli. /// - Throws: Error upon failure to deserialize. @inlinable - public init(deserialize serialized: SerializedPlaintext, context: Context, moduliCount: Int? = nil) throws { + public init( + deserialize serialized: SerializedPlaintext, + context: Scheme.Context, + moduliCount: Int? = nil) throws + { self.context = context let moduliCount = moduliCount ?? context.ciphertextContext.moduli.count let plaintextContext = try context.ciphertextContext.getContext(moduliCount: moduliCount) self.poly = try PolyRq(deserialize: serialized.poly, context: plaintextContext) + self.auxiliaryData = try Scheme.PlaintextAuxiliaryData(context: context, poly: poly) } } diff --git a/Sources/HomomorphicEncryptionProtobuf/ConversionHe.swift b/Sources/HomomorphicEncryptionProtobuf/ConversionHe.swift index 8c6eff75..1f5bec6e 100644 --- a/Sources/HomomorphicEncryptionProtobuf/ConversionHe.swift +++ b/Sources/HomomorphicEncryptionProtobuf/ConversionHe.swift @@ -216,7 +216,7 @@ extension Apple_SwiftHomomorphicEncryption_V1_SerializedEvaluationKey { /// - Parameter context: Context to associate with the native object. /// - Returns: The converted native type. /// - Throws: Error upon upon invalid object. - public func native(context: Context) throws -> EvaluationKey { + public func native(context: Scheme.Context) throws -> EvaluationKey { let serialized: SerializedEvaluationKey = try native() return try EvaluationKey(deserialize: serialized, context: context) } @@ -364,6 +364,7 @@ extension Apple_SwiftHomomorphicEncryption_V1_EncryptionParameters { extension EncryptionParameters { /// Converts the native object into a protobuf object. + /// - Parameter scheme: The HE scheme to use. /// - Returns: The converted protobuf object. /// - Throws: Error upon unsupported object. public func proto(scheme: (some HeScheme).Type) throws -> Apple_SwiftHomomorphicEncryption_V1_EncryptionParameters { diff --git a/Sources/ModularArithmetic/Modulus.swift b/Sources/ModularArithmetic/Modulus.swift index 64689168..33cd082f 100644 --- a/Sources/ModularArithmetic/Modulus.swift +++ b/Sources/ModularArithmetic/Modulus.swift @@ -417,9 +417,12 @@ public struct MultiplyConstantModulus: Equatable, Sendable { /// A modulus for multiplication by an array of constants. public struct MultiplyConstantArrayModulus: Equatable, Sendable { - @usableFromInline let multiplicands: [T] - @usableFromInline let factors: [T] - @usableFromInline let modulus: T + /// Array of numbers to be used as multiplicands. + public let multiplicands: [T] + /// Array of Barrett factors. + public let factors: [T] + /// Associated modulus. + public let modulus: T @inlinable public init(multiplicands: [T], factors: [T], modulus: T) { diff --git a/Sources/PIRProcessDatabase/main.swift b/Sources/PIRProcessDatabase/main.swift index 4e5a419f..e2e566aa 100644 --- a/Sources/PIRProcessDatabase/main.swift +++ b/Sources/PIRProcessDatabase/main.swift @@ -397,11 +397,12 @@ struct ProcessDatabase: AsyncParsableCommand { /// - scheme: The HE scheme. /// - Throws: Error upon processing the database. @inlinable - mutating func process(config: Arguments, scheme: Scheme.Type) async throws { + mutating func process(config: Arguments, pirUtil _: PirUtil.Type) async throws { + typealias Scalar = PirUtil.Scheme.Scalar let database: [KeywordValuePair] = try Apple_SwiftHomomorphicEncryption_Pir_V1_KeywordDatabase(from: config.inputDatabase).native() - let config = try config.resolve(for: database, scheme: scheme) + let config = try config.resolve(for: database, scheme: PirUtil.Scheme.self) ProcessDatabase.logger.info("Processing database with configuration: \(config)") let keywordConfig = try KeywordPirConfig(dimensionCount: 2, cuckooTableConfig: config.cuckooTableConfig, @@ -414,15 +415,15 @@ struct ProcessDatabase: AsyncParsableCommand { sharding: config.sharding, keywordPirConfig: keywordConfig) - let encryptionParameters = try EncryptionParameters(from: config.rlweParameters) - let processArgs = try ProcessKeywordDatabase.Arguments( - databaseConfig: databaseConfig, - encryptionParameters: encryptionParameters, - algorithm: config.algorithm, - keyCompression: config.keyCompression, - trialsPerShard: config.trialsPerShard, - symmetricPirConfig: config.symmetricPirConfig) - let context = try Context(encryptionParameters: processArgs.encryptionParameters) + let encryptionParameters = try EncryptionParameters(from: config.rlweParameters) + let processArgs = try ProcessKeywordDatabase.Arguments(databaseConfig: databaseConfig, + encryptionParameters: encryptionParameters, + algorithm: config.algorithm, + keyCompression: config.keyCompression, + trialsPerShard: config.trialsPerShard, + symmetricPirConfig: config.symmetricPirConfig) + let context = try PirUtil.Scheme.Context(encryptionParameters: processArgs.encryptionParameters) + let keywordDatabase = try KeywordDatabase( rows: database, sharding: processArgs.databaseConfig.sharding, @@ -441,7 +442,8 @@ struct ProcessDatabase: AsyncParsableCommand { shard: shard, config: config, context: context, - processArgs: processArgs) + processArgs: processArgs, + pirUtil: PirUtil.self) } } @@ -455,7 +457,8 @@ struct ProcessDatabase: AsyncParsableCommand { shardID: shardID, shard: shard, config: config, context: context, - processArgs: processArgs) + processArgs: processArgs, + pirUtil: PirUtil.self) evaluationKeyConfig = [evaluationKeyConfig, processedEvaluationKeyConfig].union() } } @@ -463,18 +466,20 @@ struct ProcessDatabase: AsyncParsableCommand { if let evaluationKeyConfigFile = config.outputEvaluationKeyConfig { let protoEvaluationKeyConfig = try evaluationKeyConfig.proto( encryptionParameters: encryptionParameters, - scheme: Scheme.self) + scheme: PirUtil.Scheme.self) try protoEvaluationKeyConfig.save(to: evaluationKeyConfigFile) ProcessDatabase.logger.info("Saved evaluation key configuration to \(evaluationKeyConfigFile)") } } - private func processShard( + // swiftlint:disable:next function_parameter_count + private func processShard( shardID: String, shard: KeywordDatabaseShard, config: ResolvedArguments, - context: Context, - processArgs: ProcessKeywordDatabase.Arguments) async throws -> EvaluationKeyConfig + context: PirUtil.Scheme.Context, + processArgs: ProcessKeywordDatabase.Arguments, + pirUtil _: PirUtil.Type) async throws -> EvaluationKeyConfig { var logger = ProcessDatabase.logger logger[metadataKey: "shardID"] = .string(shardID) @@ -501,9 +506,10 @@ struct ProcessDatabase: AsyncParsableCommand { } logger.info("Processing shard with \(shard.rows.count) rows") - let processed: ProcessedDatabaseWithParameters = try ProcessKeywordDatabase.processShard( + let processed: ProcessedDatabaseWithParameters = try await ProcessKeywordDatabase.processShard( shard: shard, with: processArgs, + using: PirUtil.self, onEvent: logEvent) if config.trialsPerShard > 0 { @@ -511,10 +517,10 @@ struct ProcessDatabase: AsyncParsableCommand { throw PirError.emptyDatabase } logger.info("Validating shard") - let validationResults = try ProcessKeywordDatabase + let validationResults = try await ProcessKeywordDatabase .validateShard(shard: processed, row: KeywordValuePair(keyword: row.key, value: row.value), - trials: config.trialsPerShard, context: context) + trials: config.trialsPerShard, context: context, using: PirUtil.self) let description = try validationResults.description() logger.info("ValidationResults \(description)") } @@ -540,9 +546,9 @@ struct ProcessDatabase: AsyncParsableCommand { let configData = try Data(contentsOf: configURL) let config = try JSONDecoder().decode(Arguments.self, from: configData) if config.rlweParameters.supportsScalar(UInt32.self) { - try await process(config: config, scheme: Bfv.self) + try await process(config: config, pirUtil: PirUtil>.self) } else { - try await process(config: config, scheme: Bfv.self) + try await process(config: config, pirUtil: PirUtil>.self) } } } diff --git a/Sources/PNNSProcessDatabase/ProcessDatabase.swift b/Sources/PNNSProcessDatabase/ProcessDatabase.swift index e3f24c81..ad0c8fb7 100644 --- a/Sources/PNNSProcessDatabase/ProcessDatabase.swift +++ b/Sources/PNNSProcessDatabase/ProcessDatabase.swift @@ -205,7 +205,7 @@ struct ProcessDatabase: ParsableCommand { /// - scheme: The HE scheme. /// - Throws: Error upon processing the database. @inlinable - mutating func process(config: Arguments, scheme _: Scheme.Type) throws { + mutating func process(config: Arguments, scheme _: Scheme.Type) async throws { let database = try Database(from: config.inputDatabase) guard let vectorDimension = database.rows.first?.vector.count else { throw PnnsError.emptyDatabase @@ -232,7 +232,7 @@ struct ProcessDatabase: ParsableCommand { let serverConfig = ServerConfig( clientConfig: clientConfig, databasePacking: config.databasePacking) - let processed = try database.process(config: serverConfig) + let processed = try await database.process(config: serverConfig) ProcessDatabase.logger.info("Processed database") if config.trials > 0 { @@ -242,7 +242,7 @@ struct ProcessDatabase: ParsableCommand { count: vectorDimension * (config.batchSize - queryRows.rowCount))) ProcessDatabase.logger.info("Validating") - let validationResult = try processed.validate(query: queryRows, trials: config.trials) + let validationResult = try await processed.validate(query: queryRows, trials: config.trials) for row in 0...self) + try await process(config: config, scheme: Bfv.self) } else { - try process(config: config, scheme: Bfv.self) + try await process(config: config, scheme: Bfv.self) } } } diff --git a/Sources/PrivateInformationRetrieval/IndexPir/IndexPirProtocol.swift b/Sources/PrivateInformationRetrieval/IndexPir/IndexPirProtocol.swift index 8be40a02..8e5e7439 100644 --- a/Sources/PrivateInformationRetrieval/IndexPir/IndexPirProtocol.swift +++ b/Sources/PrivateInformationRetrieval/IndexPir/IndexPirProtocol.swift @@ -167,7 +167,7 @@ public struct ProcessedDatabase: Equatable, Sendable { /// - context: Context for HE computation. /// - Throws: Error upon failure to load the database. @inlinable - public init(from path: String, context: Context) throws { + public init(from path: String, context: Scheme.Context) throws { let loadedFile = try [UInt8](Data(contentsOf: URL(fileURLWithPath: path))) try self.init(from: loadedFile, context: context) } @@ -178,7 +178,7 @@ public struct ProcessedDatabase: Equatable, Sendable { /// - context: Context for HE computation. /// - Throws: Error upon failure to deserialize. @inlinable - public init(from buffer: [UInt8], context: Context) throws { + public init(from buffer: [UInt8], context: Scheme.Context) throws { var offset = buffer.startIndex let versionNumber = buffer[offset] offset += MemoryLayout.size @@ -344,7 +344,7 @@ public protocol IndexPirProtocol { /// - config: Database configuration. /// - context: Context for HE computation. /// - Returns: The PIR parameters for the database. - static func generateParameter(config: IndexPirConfig, with context: Context) -> IndexPirParameter + static func generateParameter(config: IndexPirConfig, with context: Scheme.Context) -> IndexPirParameter } /// Protocol for a server hosting index PIR databases for lookup. @@ -380,7 +380,7 @@ public protocol IndexPirServer: Sendable { /// - context: Context for HE computation. /// - database: Integer-indexed database. /// - Throws: Error upon failure to initialize the server. - init(parameter: IndexPirParameter, context: Context, database: Database) throws + init(parameter: IndexPirParameter, context: Scheme.Context, database: Database) throws /// Initializes an ``IndexPirServer`` with databases. /// @@ -389,7 +389,7 @@ public protocol IndexPirServer: Sendable { /// - context: Context for HE computation. /// - databases: Integer-indexed databases, each compatible with the given `parameter`. /// - Throws: Error upon failure to initialize the server. - init(parameter: IndexPirParameter, context: Context, databases: [Database]) throws + init(parameter: IndexPirParameter, context: Scheme.Context, databases: [Database]) throws /// Processes the database to prepare for PIR queries. /// @@ -401,8 +401,8 @@ public protocol IndexPirServer: Sendable { /// - Returns: A processed database. /// - Throws: Error upon failure to process the database. static func process(database: some Collection<[UInt8]>, - with context: Context, - using parameter: IndexPirParameter) throws -> Database + with context: Scheme.Context, + using parameter: IndexPirParameter) async throws -> Database /// Compute the encrypted response to a query lookup. /// - Parameters: @@ -411,7 +411,7 @@ public protocol IndexPirServer: Sendable { /// - Returns: The encrypted response. /// - Throws: Error upon failure to compute a response. func computeResponse(to query: Query, - using evaluationKey: EvaluationKey) throws -> Response + using evaluationKey: EvaluationKey) async throws -> Response } extension IndexPirServer { @@ -422,7 +422,7 @@ extension IndexPirServer { /// - database: Database. /// - Throws: Error upon failure to initialize the server. @inlinable - public init(parameter: IndexPirParameter, context: Context, database: Database) throws { + public init(parameter: IndexPirParameter, context: Scheme.Context, database: Database) throws { try self.init(parameter: parameter, context: context, databases: [database]) } @@ -432,7 +432,9 @@ extension IndexPirServer { /// - context: Context for HE computation. /// - Returns: The PIR parameters for the database. @inlinable - public static func generateParameter(config: IndexPirConfig, with context: Context) -> IndexPirParameter { + public static func generateParameter(config: IndexPirConfig, + with context: Scheme.Context) -> IndexPirParameter + { IndexPir.generateParameter(config: config, with: context) } } @@ -460,7 +462,7 @@ public protocol IndexPirClient: Sendable { /// - Parameters: /// - parameter: Parameters for the database. /// - context: Context for HE computation. - init(parameter: IndexPirParameter, context: Context) + init(parameter: IndexPirParameter, context: Scheme.Context) /// Generates an encrypted query. /// - Parameters: @@ -537,3 +539,11 @@ extension Response { }.min() ?? -Double.infinity } } + +extension Response { + /// Returns `true` if all ciphertexts are transparent. + public func isTransparent() -> Bool { + ciphertexts.flatMap(\.self).allSatisfy + { ciphertext in ciphertext.isTransparent() } + } +} diff --git a/Sources/PrivateInformationRetrieval/IndexPir/MulPir.swift b/Sources/PrivateInformationRetrieval/IndexPir/MulPir.swift index f350e6a8..65053764 100644 --- a/Sources/PrivateInformationRetrieval/IndexPir/MulPir.swift +++ b/Sources/PrivateInformationRetrieval/IndexPir/MulPir.swift @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +import AsyncAlgorithms +import DequeModule import Foundation import HomomorphicEncryption import ModularArithmetic @@ -30,7 +32,9 @@ public enum MulPir: IndexPirProtocol { public static var algorithm: PirAlgorithm { .mulPir } - public static func generateParameter(config: IndexPirConfig, with context: Context) -> IndexPirParameter { + public static func generateParameter(config: IndexPirConfig, + with context: Scheme.Context) -> IndexPirParameter + { let entrySizeInBytes = config.entrySizeInBytes let perChunkPlaintextCount = if entrySizeInBytes <= context.bytesPerPlaintext { config.entryCount.dividingCeil(context.bytesPerPlaintext / entrySizeInBytes, variableTime: true) @@ -47,7 +51,7 @@ public enum MulPir: IndexPirProtocol { break } } - if config.unevenDimensions, config.dimensionCount == 2, Scheme.self == Bfv.self { + if config.unevenDimensions, config.dimensionCount == 2, Scheme.cryptosystem == .bfv { // BFV ciphertext-ciphertext multiply is a runtime bottleneck. // To improve runtime, we reduce the second dimension and // increase the first dimension while keeping the total expansion length @@ -108,9 +112,10 @@ public enum MulPir: IndexPirProtocol { } /// Client which can compute queries and decrypt responses using the ``PirAlgorithm/mulPir`` algorithm. -public final class MulPirClient: IndexPirClient { +public final class MulPirClient: IndexPirClient { @usableFromInline typealias Scalar = Scheme.Scalar - + /// Underlying HE scheme + public typealias Scheme = PirUtil.Scheme /// IndexPir protocol type. public typealias IndexPir = MulPir /// Encrypted query type. @@ -121,7 +126,7 @@ public final class MulPirClient: IndexPirClient { public let parameter: IndexPirParameter /// Context for HE computation. - public let context: HomomorphicEncryption.Context + public let context: Scheme.Context public var evaluationKeyConfig: EvaluationKeyConfig { parameter.evaluationKeyConfig @@ -140,7 +145,7 @@ public final class MulPirClient: IndexPirClient { IndexPir.computePerChunkPlaintextCount(for: parameter) } - public init(parameter: IndexPirParameter, context: Context) { + public init(parameter: IndexPirParameter, context: Scheme.Context) { self.parameter = parameter self.context = context } @@ -151,10 +156,8 @@ public final class MulPirClient: IndexPirClient { /// - Throws: Error upon failure to generate an evaluation key. /// - Warning: The evaluation key is only valid for use with the given `secretKey`. public func generateEvaluationKey(using secretKey: SecretKey) throws -> EvaluationKey { - try Scheme.generateEvaluationKey( - context: context, - config: evaluationKeyConfig, - using: secretKey) + try context.generateEvaluationKey(config: evaluationKeyConfig, + using: secretKey) } } @@ -194,15 +197,15 @@ extension MulPirClient { return coordinate } } - return try Query(ciphertexts: PirUtil.compressInputs( + return try Query(ciphertexts: PirUtil.compressBinaryInputs( totalInputCount: parameter.expandedQueryCount * indices.count, - nonZeroInputs: nonZeroPositions, + oneIndices: nonZeroPositions, context: context, using: secretKey), indicesCount: indices.count) } @inlinable - func plaintextIndex(entryIndex: Int) -> Int { + package func plaintextIndex(entryIndex: Int) -> Int { entryIndex / entryChunksPerPlaintext } } @@ -264,7 +267,9 @@ extension MulPirClient { } /// Server which can compute responses using the ``PirAlgorithm/mulPir`` algorithm. -public final class MulPirServer: IndexPirServer { +public final class MulPirServer: IndexPirServer { + /// Underlying HE scheme + public typealias Scheme = PirUtil.Scheme /// Index PIR type backing the keyword PIR computation. public typealias IndexPir = MulPir /// Encrypted query type. @@ -285,8 +290,7 @@ public final class MulPirServer: IndexPirServer { /// Context for HE computation. /// /// Must be the same between client and server. - public let context: HomomorphicEncryption.Context - + public let context: Scheme.Context /// Evaluation key configuration. public var evaluationKeyConfig: EvaluationKeyConfig { parameter.evaluationKeyConfig @@ -317,7 +321,7 @@ public final class MulPirServer: IndexPirServer { /// - context: Context for HE computation. /// - databases: Databases, each compatible with the given `parameter`. /// - Throws: Error upon failure to initialize the server. - public init(parameter: IndexPirParameter, context: Context, databases: [Database]) throws { + public init(parameter: IndexPirParameter, context: Scheme.Context, databases: [Database]) throws { self.parameter = parameter self.context = context self.databases = databases @@ -331,17 +335,18 @@ public final class MulPirServer: IndexPirServer { } @inlinable - package static func chunkCount(parameter: IndexPirParameter, context: Context) -> Int { + package static func chunkCount(parameter: IndexPirParameter, context: Scheme.Context) -> Int { parameter.entrySizeInBytes.dividingCeil(context.bytesPerPlaintext, variableTime: true) } } extension MulPirServer { @inlinable - func computeResponseForOneChunk(expandedDim0Query: [Ciphertext], - expandedRemainingQuery: ExpandedQueries, - dataChunk: DataChunk, - using evaluationKey: EvaluationKey) throws + func computeResponseForOneChunk( + expandedDim0Query: [Ciphertext], + expandedRemainingQuery: ExpandedQueries, + dataChunk: DataChunk, + using evaluationKey: EvaluationKey) async throws -> Ciphertext where ExpandedQueries: Collection, DataChunk: Collection?>, ExpandedQueries.Index == Int, DataChunk.Index == Int @@ -349,72 +354,78 @@ extension MulPirServer { let databaseColumnsCount = perChunkPlaintextCount / parameter.dimensions[0] precondition(databaseColumnsCount == 1 || databaseColumnsCount == expandedRemainingQuery.count) - var startIndex = dataChunk.startIndex - var intermediateResults: [CanonicalCiphertext] = try (0..] = + try await .init((0..) throws -> Response + using evaluationKey: EvaluationKey) async throws -> Response { guard databases.count == 1 || databases.count >= query.indicesCount else { throw PirError.invalidBatchSize(queryCount: query.indicesCount, databaseCount: databases.count) } - let expandedQueries = try PirUtil.expandCiphertexts( + let expandedQueries = try await PirUtil.expand(ciphertexts: query.ciphertexts, outputCount: parameter.expandedQueryCount * query.indicesCount, using: evaluationKey) + // This is a deque where remove first is a constant time op. + var ciphertextForEachQuery = expandedQueries.chunk(by: parameter.expandedQueryCount) + var responseCiphertexts: [[Scheme.CoeffCiphertext]] = [] - // Note that `parameter.expandedQueryCount` is the sum of all dimension sizes. We process the expanded - // queries in chunks of `parameter.expandedQueryCount`. In each chunk, we firstly convert the first - // `parameter.dimensions[0]` ciphertexts into eval format as they will multiply with plaintexts. The rest are - // queries for the remaining dimensions, multiplying with ciphertexts, thus can stay in canonical format. Then - // we simply use these queries to process every chunk of the database. The first iteration is looping over each - // PIR call. The second iteration is looping over chunks of entries. - return try Response(ciphertexts: (0.., with context: Context, - using parameter: IndexPirParameter) throws -> Database + public static func process(database: some Collection<[UInt8]>, with context: Scheme.Context, + using parameter: IndexPirParameter) async throws -> Database { guard database.count == parameter.entryCount else { throw PirError @@ -437,36 +448,37 @@ extension MulPirServer { } let chunkCount = parameter.entrySizeInBytes.dividingCeil(context.bytesPerPlaintext, variableTime: true) if chunkCount > 1 { - return try processSplitLargeEntries(database: database, with: context, using: parameter) + return try await processSplitLargeEntries(database: database, with: context, using: parameter) } - return try processPackEntries(database: database, with: context, using: parameter) + return try await processPackEntries(database: database, with: context, using: parameter) } @inlinable static func processSplitLargeEntries( database: some Collection<[UInt8]>, - with context: Context, - using parameter: IndexPirParameter) throws -> Database + with context: Scheme.Context, + using parameter: IndexPirParameter) async throws -> Database { let chunkCount = Self.chunkCount(parameter: parameter, context: context) - var plaintexts: [[Plaintext?]] = try database.map { entry in - try stride(from: 0, to: parameter.entrySizeInBytes, by: context.bytesPerPlaintext).map { startIndex in - let endIndex = min(startIndex + context.bytesPerPlaintext, entry.count) - // Avoid computing on padding plaintexts - guard startIndex < endIndex else { - return nil - } - let bytes = Array(entry[startIndex..?]] = try await .init(database.async.map { entry in + try await .init(stride(from: 0, to: parameter.entrySizeInBytes, by: context.bytesPerPlaintext).async + .map { startIndex in + let endIndex = min(startIndex + context.bytesPerPlaintext, entry.count) + // Avoid computing on padding plaintexts + guard startIndex < endIndex else { + return nil + } + let bytes = Array(entry[startIndex..?] = Array(repeatElement(nil, count: chunkCount)) @@ -490,8 +502,8 @@ extension MulPirServer { @inlinable static func processPackEntries( database: some Collection<[UInt8]>, - with context: Context, - using parameter: IndexPirParameter) throws -> Database + with context: Scheme.Context, + using parameter: IndexPirParameter) async throws -> Database { assert(database.count == parameter.entryCount) let flatDatabase: [UInt8] = database.flatMap { entry in @@ -502,7 +514,8 @@ extension MulPirServer { } let entriesPerPlaintext = context.bytesPerPlaintext / parameter.entrySizeInBytes let bytesPerPlaintext = entriesPerPlaintext * parameter.entrySizeInBytes - var plaintexts: [Plaintext?] = try stride(from: 0, to: flatDatabase.count, by: bytesPerPlaintext) + let plaintextIndices = stride(from: 0, to: flatDatabase.count, by: bytesPerPlaintext) + var plaintexts: [Plaintext?] = try await .init(plaintextIndices.async .map { startIndex in let endIndex = min(startIndex + bytesPerPlaintext, flatDatabase.count) let values = Array(flatDatabase[startIndex.. Deque> { + precondition(count.isMultiple(of: step)) + let shares = count / step + return Deque((0.. { - @usableFromInline typealias Scalar = Scheme.Scalar - @usableFromInline package typealias CanonicalCiphertext = Scheme.CanonicalCiphertext - typealias CoeffCiphertext = Scheme.CoeffCiphertext - typealias EvalCiphertext = Scheme.EvalCiphertext +/// A protocol outlining the auxiliary functionalities used in PIR. +public protocol PirUtilProtocol { + /// The underlying HE scheme. + associatedtype Scheme: HeScheme + /// The Scalar type used by the HE scheme. + associatedtype Scalar where Scalar == Scheme.Scalar + /// HE ciphertext in canonical format. + typealias CanonicalCiphertext = Scheme.CanonicalCiphertext + /// Expand a small number of ciphertexts to a large number of ciphertexts. + /// + /// Each output will be the encryption of a constant poly, where the constant of i-th output is the i-th coefficient + /// in the inputs. + /// - Parameters: + /// - ciphertexts: ciphertexts to expand + /// - outputCount: how many outputs are expected + /// - evaluationKey: evaluation key used for rotation and apply galois + /// - Returns: the expanded ciphertext + static func expand( + ciphertexts: consuming [CanonicalCiphertext], + outputCount: Int, + using evaluationKey: EvaluationKey) async throws -> [CanonicalCiphertext] + + /// Compress an binary array into ciphertexts such that the expanded ciphertexts is the original array. + /// + /// - Parameters: + /// - totalInputCount: the length of the binary array + /// - oneIndices: the position of 1s + /// - context: the context for HE + /// - secretKey: the secret key for encryption. + static func compressBinaryInputs( + totalInputCount: Int, + oneIndices: [Int], + context: Scheme.Context, + using secretKey: SecretKey) throws -> [CanonicalCiphertext] +} + +extension PirUtilProtocol { /// Convert one encrypted polynomial `c` to two encrypted polynomials, `p` and `q`. /// /// It is guaranteed that: @@ -41,7 +73,7 @@ package enum PirUtil { package static func expandCiphertextForOneStep( _ ciphertext: CanonicalCiphertext, logStep: Int, - using evaluationKey: EvaluationKey) throws -> (CanonicalCiphertext, CanonicalCiphertext) + using evaluationKey: EvaluationKey) async throws -> (CanonicalCiphertext, CanonicalCiphertext) { let degree = ciphertext.degree precondition(logStep <= degree.log2) @@ -57,16 +89,20 @@ package enum PirUtil { } let applyGaloisCount = 1 << ((targetElement - 1).log2 - (galoisElement - 1).log2) var currElement = 1 - for _ in 0.. { outputCount: Int, logStep: Int, expectedHeight: Int, - using evaluationKey: EvaluationKey) throws -> [CanonicalCiphertext] + using evaluationKey: EvaluationKey) async throws -> [CanonicalCiphertext] { precondition(outputCount >= 0 && outputCount <= ciphertext.degree) + var output = ciphertext if outputCount == 1 { if logStep > expectedHeight { return [ciphertext] } - return try [ciphertext + ciphertext] + try await Scheme.addAssignAsync(&output, ciphertext) + return [output] } let secondHalfCount = outputCount >> 1 let firstHalfCount = outputCount - secondHalfCount - let (p0, p1) = try expandCiphertextForOneStep( + let (p0, p1) = try await expandCiphertextForOneStep( ciphertext, logStep: logStep, using: evaluationKey) - let firstHalf = try expandCiphertext( + let firstHalf = try await expandCiphertext( p0, outputCount: firstHalfCount, logStep: logStep + 1, expectedHeight: expectedHeight, using: evaluationKey) - let secondHalf = try expandCiphertext( + let secondHalf = try await expandCiphertext( p1, outputCount: secondHalfCount, logStep: logStep + 1, @@ -119,32 +157,36 @@ package enum PirUtil { /// Expand a ciphertext array into given number of encrypted constant polynomials. @inlinable - package static func expandCiphertexts( - _ ciphertexts: [CanonicalCiphertext], - outputCount: Int, - using evaluationKey: EvaluationKey) throws -> [CanonicalCiphertext] + public static func expand(ciphertexts: consuming [CanonicalCiphertext], + outputCount: Int, + using evaluationKey: EvaluationKey) async throws -> [CanonicalCiphertext] { precondition((ciphertexts.count - 1) * ciphertexts[0].degree < outputCount) precondition(ciphertexts.count * ciphertexts[0].degree >= outputCount) var remainingOutputs = outputCount - return try ciphertexts.flatMap { ciphertext in + let lengths: [Int] = ciphertexts.compactMap { ciphertext in let outputToGenerate = min(remainingOutputs, ciphertext.degree) remainingOutputs -= outputToGenerate - return try expandCiphertext( - ciphertext, + return outputToGenerate + } + let expanded: [[CanonicalCiphertext]] = try await .init((0..) throws -> Plaintext + package static func compressInputsForOneCiphertext(totalInputCount: Int, oneIndices: [Int], + context: Scheme.Context) throws -> Plaintext { precondition(totalInputCount <= context.degree) var rawData: [Scalar] = Array(repeating: 0, count: context.degree) @@ -155,7 +197,7 @@ package enum PirUtil { modulus: context.plaintextModulus, variableTime: true).inverseMod(modulus: context.plaintextModulus, variableTime: true) - for index in nonZeroInputs { + for index in oneIndices { rawData[index] = inverseInputCountCeilLog } return try context.encode(values: rawData, format: .coefficient) @@ -163,10 +205,10 @@ package enum PirUtil { /// Generate the ciphertext based on the given non-zero positions. @inlinable - package static func compressInputs( + public static func compressBinaryInputs( totalInputCount: Int, - nonZeroInputs: [Int], - context: Context, + oneIndices: [Int], + context: Scheme.Context, using secretKey: SecretKey) throws -> [CanonicalCiphertext] { var remainingInputs = totalInputCount @@ -175,24 +217,18 @@ package enum PirUtil { while remainingInputs > 0 { let numberOfInputsToProcess = min(remainingInputs, context.degree) - let inputs = nonZeroInputs.filter { x in + let inputs = oneIndices.filter { x in (processedInputCount..<(processedInputCount + numberOfInputsToProcess)).contains(x) }.map { $0 - processedInputCount } try plaintexts.append(compressInputsForOneCiphertext( totalInputCount: numberOfInputsToProcess, - nonZeroInputs: inputs, + oneIndices: inputs, context: context)) processedInputCount += numberOfInputsToProcess remainingInputs -= numberOfInputsToProcess } return try plaintexts.map { plaintext in try plaintext.encrypt(using: secretKey) } } - - static func encodeDatabase(database: [[UInt8]], plaintextModulus: Scalar) throws -> [[Scalar]] { - try database.map { entry in - try CoefficientPacking.bytesToCoefficients(bytes: entry, - bitsPerCoeff: plaintextModulus.log2, - decode: false) - } - } } + +public enum PirUtil: PirUtilProtocol {} diff --git a/Sources/PrivateInformationRetrieval/KeywordPir/KeywordDatabase.swift b/Sources/PrivateInformationRetrieval/KeywordPir/KeywordDatabase.swift index 3d14c800..e8bb1c31 100644 --- a/Sources/PrivateInformationRetrieval/KeywordPir/KeywordDatabase.swift +++ b/Sources/PrivateInformationRetrieval/KeywordPir/KeywordDatabase.swift @@ -480,25 +480,29 @@ public enum ProcessKeywordDatabase { /// - Parameters: /// - shard: Shard of a keyword database. /// - arguments: Processing arguments. + /// - _: Type for auxiliary functionalities used in PIR. /// - onEvent: Function to call when a ``ProcessShardEvent`` happens. /// - Returns: The processed database. /// - Throws: Error upon failure to process the shard. @inlinable - public static func processShard(shard: KeywordDatabaseShard, - with arguments: Arguments, - onEvent: @escaping (ProcessShardEvent) throws -> Void = { _ in - }) throws - -> ProcessedDatabaseWithParameters + public static func processShard(shard: KeywordDatabaseShard, + with arguments: Arguments, + using _: PirUtil.Type, + onEvent: @escaping (ProcessShardEvent) throws + -> Void = { _ in + }) async throws + -> ProcessedDatabaseWithParameters { let keywordConfig = arguments.databaseConfig.keywordPirConfig - let context = try Context(encryptionParameters: arguments.encryptionParameters) + let context = try PirUtil.Scheme.Context(encryptionParameters: arguments.encryptionParameters) guard arguments.algorithm == .mulPir else { throw PirError.invalidPirAlgorithm(arguments.algorithm) } - return try KeywordPirServer>.process(database: shard, - config: keywordConfig, - with: context, onEvent: onEvent, - symmetricPirConfig: arguments.symmetricPirConfig) + return try await KeywordPirServer>.process(database: shard, + config: keywordConfig, + with: context, onEvent: onEvent, + symmetricPirConfig: arguments + .symmetricPirConfig) } /// Validates the correctness of processing on a shard. @@ -507,15 +511,18 @@ public enum ProcessKeywordDatabase { /// - row: Keyword-value pair to validate in a PIR query. /// - trials: How many PIR calls to validate. Must be > 0. /// - context: Context for HE computation. + /// - _: Type for auxiliary functionalities used in PIR. /// - Returns: The shard validation results. /// - Throws: Error upon failure to validate the sharding. - /// - seealso: ``ProcessKeywordDatabase/processShard(shard:with:onEvent:)`` to process a shard before validation. + /// - seealso: ``ProcessKeywordDatabase/processShard(shard:with:using:onEvent:)`` to process a shard before + /// validation. @inlinable - public static func validateShard( - shard: ProcessedDatabaseWithParameters, + public static func validateShard( + shard: ProcessedDatabaseWithParameters, row: KeywordValuePair, trials: Int, - context: Context) throws -> ShardValidationResult + context: PirUtil.Scheme.Context, + using _: PirUtil.Type) async throws -> ShardValidationResult { guard trials > 0 else { throw PirError.validationError("Invalid trialsPerShard: \(trials)") @@ -524,25 +531,26 @@ public enum ProcessKeywordDatabase { throw PirError.validationError("Shard missing keywordPirParameter") } - let server = try KeywordPirServer>( + let server = try KeywordPirServer>( context: context, processed: shard) - let client = KeywordPirClient>( + let client = KeywordPirClient>( keywordParameter: keywordPirParameter, pirParameter: shard.pirParameter, context: context) - var evaluationKey: EvaluationKey? - var query: Query? - var response = Response(ciphertexts: [[]]) + var evaluationKey: EvaluationKey? + var query: Query? + var response = Response(ciphertexts: [[]]) let clock = ContinuousClock() var minNoiseBudget = Double.infinity - let results = try (0..= Scheme.minNoiseBudget else { + guard noiseBudget >= PirUtil.Scheme.minNoiseBudget else { throw PirError.validationError("Insufficient noise budget \(noiseBudget)") } throw PirError.validationError("Incorrect PIR response") @@ -564,7 +572,7 @@ public enum ProcessKeywordDatabase { evaluationKey = trialEvaluationKey query = trialQuery } - return (computeTime, entryCount) + results[trial] = (computeTime, entryCount) } guard let evaluationKey, let query else { throw PirError.validationError("Empty evaluation key or query") @@ -586,30 +594,32 @@ public enum ProcessKeywordDatabase { /// - Parameters: /// - rows: Rows in the database. /// - arguments: Processing arguments. + /// - _: Type for auxiliary functionalities used in PIR. /// - Returns: The processed database. /// - Throws: Error upon failure to process the database. @inlinable - public static func process( + public static func process( rows: some Collection, - with arguments: Arguments) throws -> Processed + with arguments: Arguments, + using _: PirUtil.Type) async throws -> Processed { var evaluationKeyConfig = EvaluationKeyConfig() let keywordConfig = arguments.databaseConfig.keywordPirConfig - let context = try Context(encryptionParameters: arguments.encryptionParameters) + let context = try PirUtil.Scheme.Context(encryptionParameters: arguments.encryptionParameters) let keywordDatabase = try KeywordDatabase( rows: rows, sharding: arguments.databaseConfig.sharding, shardingFunction: keywordConfig.shardingFunction, symmetricPirConfig: arguments.symmetricPirConfig) - var processedShards = [String: ProcessedDatabaseWithParameters]() + var processedShards = [String: ProcessedDatabaseWithParameters]() for (shardID, shardedDatabase) in keywordDatabase.shards where !shardedDatabase.isEmpty { guard arguments.algorithm == .mulPir else { throw PirError.invalidPirAlgorithm(arguments.algorithm) } - let processed = try KeywordPirServer>.process(database: shardedDatabase, - config: keywordConfig, - with: context) + let processed = try await KeywordPirServer>.process(database: shardedDatabase, + config: keywordConfig, + with: context) evaluationKeyConfig = [evaluationKeyConfig, processed.pirParameter.evaluationKeyConfig] .union() diff --git a/Sources/PrivateInformationRetrieval/KeywordPir/KeywordPirProtocol.swift b/Sources/PrivateInformationRetrieval/KeywordPir/KeywordPirProtocol.swift index 034d3c1a..f1e2c8ac 100644 --- a/Sources/PrivateInformationRetrieval/KeywordPir/KeywordPirProtocol.swift +++ b/Sources/PrivateInformationRetrieval/KeywordPir/KeywordPirProtocol.swift @@ -154,7 +154,7 @@ public final class KeywordPirServer: KeywordPirProtoc /// - context: Context for HE computation. /// - processed: Processed database. /// - Throws: Error upon failure to initialize the server. - public required init(context: Context, + public required init(context: Scheme.Context, processed: ProcessedDatabaseWithParameters) throws { if let keywordPirParameter = processed.keywordPirParameter { @@ -186,10 +186,10 @@ public final class KeywordPirServer: KeywordPirProtoc @inlinable public static func process(database: some Collection, config: KeywordPirConfig, - with context: Context, + with context: Scheme.Context, onEvent: @escaping (ProcessKeywordDatabase.ProcessShardEvent) throws -> Void = { _ in }, symmetricPirConfig: SymmetricPirConfig? = nil) - throws -> ProcessedDatabaseWithParameters + async throws -> ProcessedDatabaseWithParameters { func onCuckooEvent(event: CuckooTable.Event) throws { try onEvent(ProcessKeywordDatabase.ProcessShardEvent.cuckooTableEvent(event)) @@ -221,16 +221,16 @@ public final class KeywordPirServer: KeywordPirProtoc unevenDimensions: config.unevenDimensions, keyCompression: config.keyCompression) let indexPirParameter = PirServer.generateParameter(config: indexPirConfig, with: context) - - let processedDb = try PirServer.Database(plaintexts: stride( - from: 0, - to: entryTable.count, - by: cuckooTable.bucketsPerTable).flatMap { startIndex in - try PirServer.process( + var plaintexts: [Plaintext?] = [] + let indices = Array(stride(from: 0, to: entryTable.count, by: cuckooTable.bucketsPerTable)) + for startIndex in indices { + let temp = try await PirServer.process( database: entryTable[startIndex..: KeywordPirProtoc /// - Throws: Error upon failure to compute a response. @inlinable public func computeResponse(to query: Query, - using evaluationKey: EvaluationKey) throws -> Response + using evaluationKey: EvaluationKey) async throws -> Response { - try indexPirServer.computeResponse(to: query, using: evaluationKey) + try await indexPirServer.computeResponse(to: query, using: evaluationKey) } } @@ -288,7 +288,7 @@ public final class KeywordPirClient: KeywordPirProtoc public required init( keywordParameter: KeywordPirParameter, pirParameter: IndexPirParameter, - context: Context) + context: Scheme.Context) { self.keywordParameter = keywordParameter self.indexPirClient = PirClient(parameter: pirParameter, context: context) diff --git a/Sources/PrivateInformationRetrieval/PrivateInformationRetrieval.docc/EncodingPipeline.md b/Sources/PrivateInformationRetrieval/PrivateInformationRetrieval.docc/EncodingPipeline.md index 8c78ef12..98d6f74c 100644 --- a/Sources/PrivateInformationRetrieval/PrivateInformationRetrieval.docc/EncodingPipeline.md +++ b/Sources/PrivateInformationRetrieval/PrivateInformationRetrieval.docc/EncodingPipeline.md @@ -52,7 +52,7 @@ data. Each data row needs to be converted to ``KeywordValuePair`` with ``Keyword `keyword` and `value` are both `[UInt8]`. Once you have a collection of ``KeywordValuePair``s you have two options: 1. You can use ``KeywordPirServer/process(database:config:with:onEvent:symmetricPirConfig:)`` to process the shard directly. 2. Or you can construct a ``KeywordDatabaseShard`` by using ``KeywordDatabaseShard/init(shardID:rows:)`` and then -``ProcessKeywordDatabase/processShard(shard:with:onEvent:)``. +``ProcessKeywordDatabase/processShard(shard:with:using:onEvent:)``. Both options give as output a ``ProcessedDatabaseWithParameters``. @@ -70,8 +70,8 @@ let pirParameters = try processedDatabaseWithParameters.proto(context: context) ## Loading processed shard To load a processed shard, one needs two parts: -1. ``ProcessedDatabase`` can be loaded using ``ProcessedDatabase/init(from:context:)-9ppkq`` or -``ProcessedDatabase/init(from:context:)-4pmcl``. +1. ``ProcessedDatabase`` can be loaded using ``ProcessedDatabase/init(from:context:)-(String,_)`` or +``ProcessedDatabase/init(from:context:)-([UInt8],_)``. 2. Use the `pirParameters` from protobuf and add them in like this: ```swift diff --git a/Sources/PrivateNearestNeighborSearch/CiphertextMatrix.swift b/Sources/PrivateNearestNeighborSearch/CiphertextMatrix.swift index e65f20b7..e584074b 100644 --- a/Sources/PrivateNearestNeighborSearch/CiphertextMatrix.swift +++ b/Sources/PrivateNearestNeighborSearch/CiphertextMatrix.swift @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +import AsyncAlgorithms import HomomorphicEncryption import ModularArithmetic @@ -32,7 +33,7 @@ public struct CiphertextMatrix: Equatable, @usableFromInline package var ciphertexts: [Ciphertext] /// The parameter context. - @usableFromInline var context: Context { + @usableFromInline var context: Scheme.Context { precondition(!ciphertexts.isEmpty, "Ciphertext array cannot be empty") return ciphertexts[0].context } @@ -114,6 +115,22 @@ extension CiphertextMatrix { ciphertexts: evalCiphertexts) } + /// Async version of ``convertToEvalFormat()``. + @inlinable + public func convertToEvalFormat() async throws -> CiphertextMatrix { + if Format.self == Eval.self { + // swiftlint:disable:next force_cast + return self as! CiphertextMatrix + } + let evalCiphertexts: [Ciphertext] = try await .init(ciphertexts.async.map { ciphertext in + try await ciphertext.convertToEvalFormat() + }) + return try CiphertextMatrix( + dimensions: dimensions, + packing: packing, + ciphertexts: evalCiphertexts) + } + /// Converts the ciphertext matrix to `Coeff` format. /// - Returns: The converted ciphertext matrix. /// - Throws: Error upon failure to convert the ciphertext matrix. @@ -130,6 +147,22 @@ extension CiphertextMatrix { ciphertexts: coeffCiphertexts) } + /// Async version of ``convertToCoeffFormat()``. + @inlinable + public func convertToCoeffFormat() async throws -> CiphertextMatrix { + if Format.self == Coeff.self { + // swiftlint:disable:next force_cast + return self as! CiphertextMatrix + } + let coeffCiphertexts: [Ciphertext] = try await .init(ciphertexts.async.map { ciphertext in + try ciphertext.convertToCoeffFormat() + }) + return try CiphertextMatrix( + dimensions: dimensions, + packing: packing, + ciphertexts: coeffCiphertexts) + } + /// Converts the ciphertext matrix to canonical format. /// - Returns: The converted ciphertext matrix. /// - Throws: Error upon failure to convert the ciphertext matrix. @@ -146,14 +179,28 @@ extension CiphertextMatrix { fatalError("Unsupported Format \(Format.description)") } + /// Async version of ``convertToCanonicalFormat()``. + @inlinable + public func convertToCanonicalFormat() async throws -> CiphertextMatrix { + if Scheme.CanonicalCiphertextFormat.self == Coeff.self { + // swiftlint:disable:next force_cast + return try await convertToCoeffFormat() as! CiphertextMatrix + } + if Scheme.CanonicalCiphertextFormat.self == Eval.self { + // swiftlint:disable:next force_cast + return try await convertToEvalFormat() as! CiphertextMatrix + } + fatalError("Unsupported Format \(Format.description)") + } + /// Performs modulus switching to a single modulus. /// /// If the ciphertexts already have a single modulus, this is a no-op. /// - Throws: Error upon failure to modulus switch. @inlinable - public mutating func modSwitchDownToSingle() throws where Format == Scheme.CanonicalCiphertextFormat { + public mutating func modSwitchDownToSingle() async throws where Format == Scheme.CanonicalCiphertextFormat { for index in 0..) throws -> Self + package func extractDenseRow(rowIndex: Int, evaluationKey: EvaluationKey) async throws -> Self where Format == Scheme.CanonicalCiphertextFormat { precondition((0.. { public let config: ClientConfig /// One context per plaintext modulus. - public let contexts: [Context] + public let contexts: [Scheme.Context] /// Performs composition of the plaintext CRT responses. @usableFromInline let crtComposer: CrtComposer @@ -43,7 +43,7 @@ public struct Client { /// - contexts: Contexts for HE computation, one per plaintext modulus. /// - Throws: Error upon failure to create the client. @inlinable - public init(config: ClientConfig, contexts: [Context] = []) throws { + public init(config: ClientConfig, contexts: [Scheme.Context] = []) throws { guard config.distanceMetric == .cosineSimilarity else { throw PnnsError.wrongDistanceMetric(got: config.distanceMetric, expected: .cosineSimilarity) } @@ -52,7 +52,7 @@ public struct Client { var contexts = contexts if contexts.isEmpty { contexts = try config.encryptionParameters.map { encryptionParameters in - try Context(encryptionParameters: encryptionParameters) + try Scheme.Context(encryptionParameters: encryptionParameters) } } try config.validateContexts(contexts: contexts) @@ -79,7 +79,7 @@ public struct Client { let matrices = try contexts.map { context in // For a single plaintext modulus, reduction isn't necessary let shouldReduce = contexts.count > 1 - let plaintextMatrix = try PlaintextMatrix( + let plaintextMatrix = try PlaintextMatrix( context: context, dimensions: MatrixDimensions(vectors.shape), packing: config.queryPacking, diff --git a/Sources/PrivateNearestNeighborSearch/Config.swift b/Sources/PrivateNearestNeighborSearch/Config.swift index 762603e3..b232e4ae 100644 --- a/Sources/PrivateNearestNeighborSearch/Config.swift +++ b/Sources/PrivateNearestNeighborSearch/Config.swift @@ -121,7 +121,7 @@ public struct ClientConfig: Codable, Equatable, Hashable, Send /// - Parameter contexts: Contexts; one per plaintext modulus. /// - Throws: Error if the contexts are not valid. @inlinable - func validateContexts(contexts: [Context]) throws { + func validateContexts(contexts: [Scheme.Context]) throws { guard contexts.count == encryptionParameters.count else { throw PnnsError.wrongContextsCount(got: contexts.count, expected: encryptionParameters.count) } @@ -190,7 +190,7 @@ public struct ServerConfig: Codable, Equatable, Hashable, Send /// - Parameter contexts: Contexts; one per plaintext modulus. /// - Throws: Error if the contexts are not valid. @inlinable - func validateContexts(contexts: [Context]) throws { + func validateContexts(contexts: [Scheme.Context]) throws { try clientConfig.validateContexts(contexts: contexts) } } diff --git a/Sources/PrivateNearestNeighborSearch/Error.swift b/Sources/PrivateNearestNeighborSearch/Error.swift index 54620047..2ab5bd29 100644 --- a/Sources/PrivateNearestNeighborSearch/Error.swift +++ b/Sources/PrivateNearestNeighborSearch/Error.swift @@ -56,7 +56,7 @@ extension PnnsError { } @inlinable - static func wrongContext(got: Context, expected: Context) -> Self { + static func wrongContext(got: some HeContext, expected: some HeContext) -> Self { PnnsError.wrongContext(gotDescription: got.description, expectedDescription: expected.description) } diff --git a/Sources/PrivateNearestNeighborSearch/MatrixMultiplication.swift b/Sources/PrivateNearestNeighborSearch/MatrixMultiplication.swift index 7eb711ce..a939e20c 100644 --- a/Sources/PrivateNearestNeighborSearch/MatrixMultiplication.swift +++ b/Sources/PrivateNearestNeighborSearch/MatrixMultiplication.swift @@ -14,6 +14,7 @@ import _HomomorphicEncryptionExtras import Algorithms +import AsyncAlgorithms import Foundation import HomomorphicEncryption import ModularArithmetic @@ -55,13 +56,14 @@ public struct BabyStepGiantStep: Codable, Equatable, Hashable, Sendable { /// Utilities for matrix multiplication. public enum MatrixMultiplication { + // swiftformat:disable unusedArguments /// Computes the evaluation key configuration for matrix multiplication. /// - Parameters: /// - plaintextMatrixDimensions: Dimensions of the plaintext matrix. - /// - maxQueryCount: Maximum number of queries in one batch. The returned`EvaluationKeyConfig` will support all /// - encryptionParameters: Encryption paramterss + /// - maxQueryCount: Maximum number of queries in one batch. The returned`EvaluationKeyConfig` will support + /// all batch sizes up to and including `maxQueryCount`. /// - scheme: The metatype of the generic parameter `Scheme`. - /// batch sizes up to and including `maxQueryCount`. /// - Returns: The evaluation key configuration. /// - Throws: Error upon failure to compute the configuration. @inlinable @@ -122,7 +124,7 @@ extension PlaintextMatrix { @inlinable package func mulTranspose( vector ciphertextVector: CiphertextMatrix, - using evaluationKey: EvaluationKey) throws -> [Scheme.CanonicalCiphertext] + using evaluationKey: EvaluationKey) async throws -> [Scheme.CanonicalCiphertext] { guard case .diagonal = packing else { let expectedBsgs = BabyStepGiantStep(vectorDimension: dimensions.columnCount) @@ -167,47 +169,57 @@ extension PlaintextMatrix { let babyStepGiantStep = BabyStepGiantStep(vectorDimension: dimensions.columnCount) // 1) Compute v_j = theta^j(v) - var rotatedCiphertexts: [Scheme.EvalCiphertext] = [] - rotatedCiphertexts.reserveCapacity(babyStepGiantStep.babyStep) + var rotatedStates: [Scheme.CanonicalCiphertext] = [] + rotatedStates.reserveCapacity(babyStepGiantStep.babyStep) + var state = ciphertextVector.ciphertexts[0] for step in 0.. Scheme.CanonicalCiphertext = { giantStepIndex, resultCiphertextIndex in let plaintextCount = min( rotatedCiphertexts.count, babyStepGiantStep.vectorDimension - babyStepGiantStep.babyStep * giantStepIndex) - let plaintextRows = try (0..] = try await .init(plaintextRowIndices.async.map { index in + try plaintexts[index].convertToEvalFormat() + }) + let ciphertexts = rotatedCiphertexts[0.., - using evaluationKey: EvaluationKey) throws + using evaluationKey: EvaluationKey) async throws -> CiphertextMatrix { guard dimensions.columnCount == ciphertextMatrix.dimensions.columnCount else { @@ -240,34 +252,37 @@ extension PlaintextMatrix { throw PnnsError.incorrectSimdRowsCount(got: simdDimensions.rowCount, expected: 2) } - var innerProducts: [Scheme.CanonicalCiphertext] = try (0.. 0 { let columnsPerCiphertextCount = simdRowCount * columnsPerSimdRowCount - let packedCiphertexts = try innerProducts.chunks(ofCount: columnsPerCiphertextCount) + let packedCiphertexts: [Scheme.CanonicalCiphertext] = try await .init(innerProducts + .chunks(ofCount: columnsPerCiphertextCount).async .map { columnsForCiphertext in - var packedRows: [Scheme.CanonicalCiphertext] = try columnsForCiphertext - .chunks(ofCount: columnsPerSimdRowCount).map { columnsForRow in - guard var packedRow = columnsForRow.last else { - throw PnnsError.emptyCiphertextArray - } - for column in columnsForRow.dropLast().reversed() { - try packedRow.rotateColumnsMultiStep(by: dimensions.rowCount, using: evaluationKey) - try packedRow += column - } - return packedRow - } + let packedRows: [Scheme.CanonicalCiphertext] = try await .init(columnsForCiphertext + .chunks(ofCount: columnsPerSimdRowCount).async.map { columnsForRow in + try await Scheme.rotateColumnsAndSumAsync( + Array(columnsForRow), + by: dimensions.rowCount, + using: evaluationKey) + }) if columnsForCiphertext.count > columnsPerSimdRowCount { - try packedRows[1].swapRows(using: evaluationKey) - return try packedRows[0] + packedRows[1] + return try await Scheme.swapRowsAndAddAsync( + swapping: packedRows[1], + addingTo: packedRows[0], + using: evaluationKey) } return packedRows[0] - } + }) innerProducts = packedCiphertexts } let resultMatrixDimensions = try MatrixDimensions( diff --git a/Sources/PrivateNearestNeighborSearch/PlaintextMatrix.swift b/Sources/PrivateNearestNeighborSearch/PlaintextMatrix.swift index 2b203e38..752c5893 100644 --- a/Sources/PrivateNearestNeighborSearch/PlaintextMatrix.swift +++ b/Sources/PrivateNearestNeighborSearch/PlaintextMatrix.swift @@ -88,7 +88,7 @@ public struct PlaintextMatrix: Equatable, @usableFromInline package let plaintexts: [Plaintext] /// The parameter context. - @usableFromInline package var context: Context { + @usableFromInline package var context: Scheme.Context { precondition(!plaintexts.isEmpty, "Plaintext array cannot be empty") return plaintexts[0].context } @@ -153,7 +153,7 @@ public struct PlaintextMatrix: Equatable, /// - Throws: Error upon failure to create the plaitnext matrix. @inlinable public init( - context: Context, + context: Scheme.Context, dimensions: MatrixDimensions, packing: MatrixPacking, signedValues: [Scheme.SignedScalar], @@ -187,7 +187,7 @@ public struct PlaintextMatrix: Equatable, /// - Throws: Error upon failure to create the plaitnext matrix. @inlinable package init( - context: Context, + context: Scheme.Context, dimensions: MatrixDimensions, packing: MatrixPacking, values: [Scalar], @@ -272,11 +272,11 @@ public struct PlaintextMatrix: Equatable, /// - Returns: The plaintexts for `denseColumn` packing. /// - Throws: Error upon plaintext to compute the plaintexts. @inlinable - static func denseColumnPlaintexts(context: Context, dimensions: MatrixDimensions, + static func denseColumnPlaintexts(context: Scheme.Context, dimensions: MatrixDimensions, values: [Scalar]) throws -> [Scheme.CoeffPlaintext] { let degree = context.degree - guard let simdColumnCount = context.simdDimensions?.columnCount else { + guard let simdColumnCount = Scheme.simdDimensions(for: context.encryptionParameters)?.columnCount else { throw PnnsError.simdEncodingNotSupported(for: context.encryptionParameters) } @@ -329,12 +329,12 @@ public struct PlaintextMatrix: Equatable, /// - Throws: Error upon failure to compute the plaintexts. @inlinable static func denseRowPlaintexts( - context: Context, + context: Scheme.Context, dimensions: MatrixDimensions, values: [Scalar]) throws -> [Plaintext] { let encryptionParameters = context.encryptionParameters - guard let simdDimensions = context.simdDimensions else { + guard let simdDimensions = Scheme.simdDimensions(for: context.encryptionParameters) else { throw PnnsError.simdEncodingNotSupported(for: encryptionParameters) } guard simdDimensions.rowCount == 2 else { @@ -405,13 +405,13 @@ public struct PlaintextMatrix: Equatable, /// - Throws: Error upon failure to compute the plaintexts. @inlinable static func diagonalPlaintexts( - context: Context, + context: Scheme.Context, dimensions: MatrixDimensions, packing: MatrixPacking, values: [Scalar]) throws -> [Scheme.CoeffPlaintext] { let encryptionParameters = context.encryptionParameters - guard let simdDimensions = context.simdDimensions else { + guard let simdDimensions = Scheme.simdDimensions(for: context.encryptionParameters) else { throw PnnsError.simdEncodingNotSupported(for: encryptionParameters) } let simdColumnCount = simdDimensions.columnCount @@ -463,7 +463,7 @@ public struct PlaintextMatrix: Equatable, chunk[chunk.startIndex.. plaintexts.append(plaintext) } } diff --git a/Sources/PrivateNearestNeighborSearch/ProcessedDatabase.swift b/Sources/PrivateNearestNeighborSearch/ProcessedDatabase.swift index af3791f4..70d0eedd 100644 --- a/Sources/PrivateNearestNeighborSearch/ProcessedDatabase.swift +++ b/Sources/PrivateNearestNeighborSearch/ProcessedDatabase.swift @@ -1,4 +1,4 @@ -// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ import HomomorphicEncryption /// A database after processing to prepare for PNNS queries. public struct ProcessedDatabase: Equatable, Sendable { /// One context per plaintext modulus. - public let contexts: [Context] + public let contexts: [Scheme.Context] /// The processed vectors in the database. public let plaintextMatrices: [PlaintextMatrix] @@ -33,7 +33,7 @@ public struct ProcessedDatabase: Equatable, Sendable { @inlinable public init( - contexts: [Context], + contexts: [Scheme.Context], plaintextMatrices: [PlaintextMatrix], entryIds: [UInt64], entryMetadatas: [[UInt8]], @@ -52,11 +52,11 @@ public struct ProcessedDatabase: Equatable, Sendable { /// - serialized: Serialized processed database. /// - contexts: Contexts for HE computation, one per plaintext modulus. /// - Throws: Error upon failure to load the database. - public init(from serialized: SerializedProcessedDatabase, contexts: [Context] = []) throws { + public init(from serialized: SerializedProcessedDatabase, contexts: [Scheme.Context] = []) throws { var contexts = contexts if contexts.isEmpty { contexts = try serialized.serverConfig.encryptionParameters.map { encryptionParameters in - try Context(encryptionParameters: encryptionParameters) + try Scheme.Context(encryptionParameters: encryptionParameters) } } try serialized.serverConfig.validateContexts(contexts: contexts) @@ -86,7 +86,7 @@ public struct ProcessedDatabase: Equatable, Sendable { } @inlinable - public func validate(query vector: Array2d, trials: Int) throws -> ValidationResult { + public func validate(query vector: Array2d, trials: Int) async throws -> ValidationResult { guard trials > 0 else { throw PnnsError.validationError("Invalid trialsPerShard: \(trials)") } @@ -104,12 +104,13 @@ public struct ProcessedDatabase: Equatable, Sendable { var databaseDistances = DatabaseDistances() let clock = ContinuousClock() var minNoiseBudget = Double.infinity - let computeTimes = try (0..= Scheme.minNoiseBudget else { @@ -123,7 +124,7 @@ public struct ProcessedDatabase: Equatable, Sendable { query = trialQuery databaseDistances = trialDatabaseDistances } - return computeTime + computeTimes[trial] = computeTime } guard let evaluationKey, let query else { throw PnnsError.validationError("Empty evaluation key or query") @@ -188,7 +189,7 @@ extension Database { /// - Throws: Error upon failure to process the database. @inlinable public func process(config: ServerConfig, - contexts: [Context] = []) throws -> ProcessedDatabase + contexts: [Scheme.Context] = []) async throws -> ProcessedDatabase { guard config.distanceMetric == .cosineSimilarity else { throw PnnsError.wrongDistanceMetric(got: config.distanceMetric, expected: .cosineSimilarity) @@ -196,7 +197,7 @@ extension Database { var contexts = contexts if contexts.isEmpty { contexts = try config.encryptionParameters.map { encryptionParameters in - try Context(encryptionParameters: encryptionParameters) + try Scheme.Context(encryptionParameters: encryptionParameters) } } try config.validateContexts(contexts: contexts) diff --git a/Sources/PrivateNearestNeighborSearch/SerializedCiphertextMatrix.swift b/Sources/PrivateNearestNeighborSearch/SerializedCiphertextMatrix.swift index 793198b5..2c7cf9f2 100644 --- a/Sources/PrivateNearestNeighborSearch/SerializedCiphertextMatrix.swift +++ b/Sources/PrivateNearestNeighborSearch/SerializedCiphertextMatrix.swift @@ -54,7 +54,7 @@ extension CiphertextMatrix { @inlinable public init( deserialize serialized: SerializedCiphertextMatrix, - context: Context, + context: Scheme.Context, moduliCount: Int? = nil) throws { let ciphertexts: [Ciphertext] = try serialized.ciphertexts.map { serializedCiphertext in diff --git a/Sources/PrivateNearestNeighborSearch/SerializedPlaintextMatrix.swift b/Sources/PrivateNearestNeighborSearch/SerializedPlaintextMatrix.swift index ab6624b1..ff5e4342 100644 --- a/Sources/PrivateNearestNeighborSearch/SerializedPlaintextMatrix.swift +++ b/Sources/PrivateNearestNeighborSearch/SerializedPlaintextMatrix.swift @@ -1,4 +1,4 @@ -// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -60,9 +60,9 @@ extension PlaintextMatrix where Format == Coeff { /// - serialized: Serialized plaintext matrix. /// - context: Context to associate with the plaintext matrix. /// - Throws: Error upon failure to deserialize. - init(deserialize serialized: SerializedPlaintextMatrix, context: Context) throws { + init(deserialize serialized: SerializedPlaintextMatrix, context: Scheme.Context) throws { let plaintexts = try serialized.plaintexts.map { serializedPlaintext in - try Plaintext(deserialize: serializedPlaintext, context: context) + try Plaintext(deserialize: serializedPlaintext, context: context) } try self.init(dimensions: serialized.dimensions, packing: serialized.packing, plaintexts: plaintexts) } @@ -78,11 +78,11 @@ extension PlaintextMatrix where Format == Eval { /// - Throws: Error upon failure to deserialize. init( deserialize serialized: SerializedPlaintextMatrix, - context: Context, + context: Scheme.Context, moduliCount: Int? = nil) throws { let plaintexts = try serialized.plaintexts.map { serializedPlaintext in - try Plaintext(deserialize: serializedPlaintext, context: context, moduliCount: moduliCount) + try Plaintext(deserialize: serializedPlaintext, context: context, moduliCount: moduliCount) } try self.init(dimensions: serialized.dimensions, packing: serialized.packing, plaintexts: plaintexts) } diff --git a/Sources/PrivateNearestNeighborSearch/Server.swift b/Sources/PrivateNearestNeighborSearch/Server.swift index 6f75071d..f10bd041 100644 --- a/Sources/PrivateNearestNeighborSearch/Server.swift +++ b/Sources/PrivateNearestNeighborSearch/Server.swift @@ -1,4 +1,4 @@ -// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +import AsyncAlgorithms import HomomorphicEncryption /// Private nearest neighbor server. @@ -35,7 +36,7 @@ public struct Server: Sendable { } /// One context per plaintext modulus. - public var contexts: [Context] { + public var contexts: [Scheme.Context] { database.contexts } @@ -58,23 +59,27 @@ public struct Server: Sendable { /// - Throws: Error upon failure to compute a response. @inlinable public func computeResponse(to query: Query, - using evaluationKey: EvaluationKey) throws -> Response + using evaluationKey: EvaluationKey) async throws -> Response { guard query.ciphertextMatrices.count == database.plaintextMatrices.count else { throw PnnsError.invalidQuery(reason: InvalidQueryReason.wrongCiphertextMatrixCount( got: query.ciphertextMatrices.count, expected: database.plaintextMatrices.count)) } - - let responseMatrices = try zip(query.ciphertextMatrices, database.plaintextMatrices) - .map { ciphertextMatrix, plaintextMatrix in - var responseMatrix = try plaintextMatrix.mulTranspose( - matrix: ciphertextMatrix.convertToCanonicalFormat(), + let asyncCiphertextMatrices: [CiphertextMatrix] = + try await .init(query.ciphertextMatrices.async.map { try $0.convertToCanonicalFormat() }) + let asyncPlaintextMatrices: [PlaintextMatrix] = database.plaintextMatrices + let responseMatrices: [CiphertextMatrix] = try await .init(zip( + asyncCiphertextMatrices, + asyncPlaintextMatrices) + .async.map { ciphertextMatrix, plaintextMatrix in + var responseMatrix = try await plaintextMatrix.mulTranspose( + matrix: ciphertextMatrix, using: evaluationKey) // Reduce response size by mod-switching to a single modulus. - try responseMatrix.modSwitchDownToSingle() - return try responseMatrix.convertToCoeffFormat() - } + try await responseMatrix.modSwitchDownToSingle() + return try await responseMatrix.convertToCoeffFormat() + }) return Response( ciphertextMatrices: responseMatrices, diff --git a/Sources/_BenchmarkUtilities/PirBenchmarkUtilities.swift b/Sources/_BenchmarkUtilities/PirBenchmarkUtilities.swift index 903299ab..381158a9 100644 --- a/Sources/_BenchmarkUtilities/PirBenchmarkUtilities.swift +++ b/Sources/_BenchmarkUtilities/PirBenchmarkUtilities.swift @@ -123,7 +123,7 @@ extension PrivateInformationRetrieval.Response { struct ProcessBenchmarkContext { let database: [[UInt8]] - let context: Context + let context: Server.Scheme.Context let parameter: IndexPirParameter init(server _: Server.Type, pirConfig: IndexPirConfig, encryptionConfig: EncryptionParametersConfig) throws @@ -133,22 +133,22 @@ struct ProcessBenchmarkContext { self.database = getDatabaseForTesting( numberOfEntries: pirConfig.entryCount, entrySizeInBytes: pirConfig.entrySizeInBytes) - self.context = try Context(encryptionParameters: encryptParameter) + self.context = try Server.Scheme.Context(encryptionParameters: encryptParameter) self.parameter = Server.generateParameter(config: pirConfig, with: context) } } /// Pre-processing database benchmark. -public func pirProcessBenchmark( - _: Scheme.Type, +public func pirProcessBenchmark( + _: PirUtil.Type, // swiftlint:disable:next force_try - config: PirBenchmarkConfig = try! .init()) -> () -> Void + config: PirBenchmarkConfig = try! .init()) -> () -> Void { { let databaseConfig = config.databaseConfig let benchmarkName = [ "Process", - String(describing: Scheme.self), + String(describing: PirUtil.Scheme.self), config.encryptionConfig.description, "entryCount=\(databaseConfig.entryCount)", "entrySize=\(databaseConfig.entrySizeInBytes)", @@ -157,17 +157,17 @@ public func pirProcessBenchmark( // swiftlint:disable closure_parameter_position Benchmark(benchmarkName, configuration: config.benchmarkConfig) { ( benchmark, - benchmarkContext: ProcessBenchmarkContext>) in + benchmarkContext: ProcessBenchmarkContext>) in for _ in benchmark.scaledIterations { - try blackHole( - MulPirServer.process( + try await blackHole( + MulPirServer.process( database: benchmarkContext.database, with: benchmarkContext.context, using: benchmarkContext.parameter)) } } setup: { try ProcessBenchmarkContext( - server: MulPirServer.self, + server: MulPirServer.self, pirConfig: config.indexPirConfig, encryptionConfig: config.encryptionConfig) } @@ -196,16 +196,16 @@ struct IndexPirBenchmarkContext server _: Server.Type, client _: Client.Type, pirConfig: IndexPirConfig, - encryptionConfig: EncryptionParametersConfig) throws + encryptionConfig: EncryptionParametersConfig) async throws { let encryptParameter: EncryptionParameters = try EncryptionParameters(from: encryptionConfig) - let context = try Context(encryptionParameters: encryptParameter) + let context = try Server.Scheme.Context(encryptionParameters: encryptParameter) let indexPirParameters = Server.generateParameter(config: pirConfig, with: context) let database = getDatabaseForTesting( numberOfEntries: pirConfig.entryCount, entrySizeInBytes: pirConfig.entrySizeInBytes) - self.processedDatabase = try Server.process(database: database, with: context, using: indexPirParameters) + self.processedDatabase = try await Server.process(database: database, with: context, using: indexPirParameters) self.server = try Server(parameter: indexPirParameters, context: context, database: processedDatabase) self.client = Client(parameter: indexPirParameters, context: context) @@ -216,7 +216,7 @@ struct IndexPirBenchmarkContext // Validate correctness let queryIndex = Int.random(in: 0.. } /// IndexPIR benchmark. -public func indexPirBenchmark( - _: Scheme.Type, +public func indexPirBenchmark( + _: PirUtil.Type, // swiftlint:disable:next force_try - config: PirBenchmarkConfig = try! .init()) -> () -> Void + config: PirBenchmarkConfig = try! .init()) -> () -> Void { { let benchmarkName = [ "IndexPir", - String(describing: Scheme.self), + String(describing: PirUtil.Scheme.self), config.encryptionConfig.description, "entryCount=\(config.databaseConfig.entryCount)", "entrySize=\(config.databaseConfig.entrySizeInBytes)", @@ -250,11 +250,11 @@ public func indexPirBenchmark( // swiftlint:disable closure_parameter_position Benchmark(benchmarkName, configuration: config.benchmarkConfig) { ( benchmark, - benchmarkContext: IndexPirBenchmarkContext, MulPirClient>) in + benchmarkContext: IndexPirBenchmarkContext, MulPirClient>) in for _ in benchmark.scaledIterations { - try blackHole(benchmarkContext.server.computeResponse(to: benchmarkContext.query, - using: benchmarkContext - .evaluationKey)) + try await blackHole(benchmarkContext.server.computeResponse(to: benchmarkContext.query, + using: benchmarkContext + .evaluationKey)) } benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize) benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount) @@ -266,9 +266,9 @@ public func indexPirBenchmark( } // swiftlint:enable closure_parameter_position setup: { - try IndexPirBenchmarkContext( - server: MulPirServer.self, - client: MulPirClient.self, + try await IndexPirBenchmarkContext( + server: MulPirServer.self, + client: MulPirClient.self, pirConfig: config.indexPirConfig, encryptionConfig: config.encryptionConfig) } @@ -296,7 +296,7 @@ struct KeywordPirBenchmarkContext) async throws { let encryptParameter: EncryptionParameters = try EncryptionParameters(from: config.encryptionConfig) - let context = try Context(encryptionParameters: encryptParameter) + let context = try Server.Scheme.Context(encryptionParameters: encryptParameter) let rows = (0..( - _: Scheme.Type, +public func keywordPirBenchmark( + _: PirUtil.Type, // swiftlint:disable:next force_try - config: PirBenchmarkConfig = try! .init()) -> () -> Void + config: PirBenchmarkConfig = try! .init()) -> () -> Void { { let benchmarkName = [ "KeywordPir", - String(describing: Scheme.self), + String(describing: PirUtil.Scheme.self), config.encryptionConfig.description, "entryCount=\(config.databaseConfig.entryCount)", "entrySize=\(config.databaseConfig.entrySizeInBytes)", @@ -382,8 +382,8 @@ public func keywordPirBenchmark( ].joined(separator: "/") Benchmark(benchmarkName, configuration: config.benchmarkConfig) { benchmark, benchmarkContext in for _ in benchmark.scaledIterations { - try blackHole(benchmarkContext.server.computeResponse(to: benchmarkContext.query, - using: benchmarkContext.evaluationKey)) + try await blackHole(benchmarkContext.server.computeResponse(to: benchmarkContext.query, + using: benchmarkContext.evaluationKey)) } benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize) benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount) @@ -393,7 +393,7 @@ public func keywordPirBenchmark( benchmark.measurement(.responseCiphertextCount, benchmarkContext.responseCiphertextCount) benchmark.measurement(.noiseBudget, benchmarkContext.noiseBudget) } setup: { - try await KeywordPirBenchmarkContext, MulPirClient>( + try await KeywordPirBenchmarkContext, MulPirClient>( config: config) } } diff --git a/Sources/_BenchmarkUtilities/PnnsBenchmarkUtilities.swift b/Sources/_BenchmarkUtilities/PnnsBenchmarkUtilities.swift index 85cc2e44..6735468e 100644 --- a/Sources/_BenchmarkUtilities/PnnsBenchmarkUtilities.swift +++ b/Sources/_BenchmarkUtilities/PnnsBenchmarkUtilities.swift @@ -111,7 +111,7 @@ public func pnnsProcessBenchmark( benchmark, benchmarkContext: PnnsProcessBenchmarkContext) in for _ in benchmark.scaledIterations { - try blackHole(benchmarkContext.database + try await blackHole(benchmarkContext.database .process( config: benchmarkContext.serverConfig, contexts: benchmarkContext.contexts)) @@ -146,7 +146,7 @@ public func cosineSimilarityBenchmark(_: Scheme.Type, benchmark, benchmarkContext: PnnsBenchmarkContext) in for _ in benchmark.scaledIterations { - try blackHole( + try await blackHole( benchmarkContext.server.computeResponse( to: benchmarkContext.query, using: benchmarkContext.evaluationKey)) @@ -177,7 +177,7 @@ extension PrivateNearestNeighborSearch.Response { struct PnnsProcessBenchmarkContext { let database: Database - let contexts: [Context] + let contexts: [Scheme.Context] let serverConfig: ServerConfig init(databaseConfig: PnnsDatabaseConfig, @@ -191,6 +191,7 @@ struct PnnsProcessBenchmarkContext { significantBitCounts: encryptionConfig.coefficientModulusBits, preferringSmall: false, nttDegree: encryptionConfig.polyDegree) + let encryptionParameters = try EncryptionParameters( polyDegree: encryptionConfig.polyDegree, plaintextModulus: plaintextModuli[0], @@ -226,7 +227,7 @@ struct PnnsProcessBenchmarkContext { self.database = getDatabaseForTesting(config: databaseConfig) self.contexts = try serverConfig.encryptionParameters.map { encryptionParameters in - try Context(encryptionParameters: encryptionParameters) + try Scheme.Context(encryptionParameters: encryptionParameters) } } } @@ -293,8 +294,8 @@ struct PnnsBenchmarkContext { let database = getDatabaseForTesting(config: databaseConfig) let contexts = try clientConfig.encryptionParameters - .map { encryptionParameters in try Context(encryptionParameters: encryptionParameters) } - self.processedDatabase = try database.process(config: serverConfig, contexts: contexts) + .map { encryptionParameters in try Scheme.Context(encryptionParameters: encryptionParameters) } + self.processedDatabase = try await database.process(config: serverConfig, contexts: contexts) self.client = try Client(config: clientConfig, contexts: contexts) self.server = try Server(database: processedDatabase) self.secretKey = try client.generateSecretKey() @@ -305,7 +306,7 @@ struct PnnsBenchmarkContext { let queryVectors = Array2d(data: database.rows.prefix(queryCount).map { row in row.vector }) self.query = try client.generateQuery(for: queryVectors, using: secretKey) - let response = try server.computeResponse(to: query, using: evaluationKey) + let response = try await server.computeResponse(to: query, using: evaluationKey) let decrypted = try client.decrypt(response: response, using: secretKey) // Validate correctness diff --git a/Sources/_HomomorphicEncryptionExtras/HeScheme.swift b/Sources/_HomomorphicEncryptionExtras/HeScheme.swift index 5b87165f..f9f88ebd 100644 --- a/Sources/_HomomorphicEncryptionExtras/HeScheme.swift +++ b/Sources/_HomomorphicEncryptionExtras/HeScheme.swift @@ -27,19 +27,19 @@ extension HeScheme { return } - guard let galoisKey = evaluationKey.galoisKey else { + guard let galoisKey = evaluationKey._galoisKey else { throw HeError.missingGaloisKey } // Short-circuit to single rotation if possible. let degree = ciphertext.degree let galoisElement = try GaloisElement.rotatingColumns(by: step, degree: degree) - if galoisKey.keys.keys.contains(galoisElement) { + if galoisKey._keys.keys.contains(galoisElement) { try await rotateColumnsAsync(of: &ciphertext, by: step, using: evaluationKey) return } - let galoisElements = Array(galoisKey.keys.keys) + let galoisElements = Array(galoisKey._keys.keys) let steps = GaloisElement.stepsFor(elements: galoisElements, degree: degree).values.compactMap(\.self) let positiveStep = if step < 0 { @@ -48,7 +48,7 @@ extension HeScheme { step } - let plan = try GaloisElement.planMultiStep(supportedSteps: steps, step: positiveStep, degree: degree) + let plan = try GaloisElement._planMultiStep(supportedSteps: steps, step: positiveStep, degree: degree) guard let plan else { throw HeError.invalidRotationStep(step: step, degree: degree) } @@ -71,19 +71,19 @@ extension HeScheme { return } - guard let galoisKey = evaluationKey.galoisKey else { + guard let galoisKey = evaluationKey._galoisKey else { throw HeError.missingGaloisKey } // Short-circuit to single rotation if possible. let degree = ciphertext.degree let galoisElement = try GaloisElement.rotatingColumns(by: step, degree: degree) - if galoisKey.keys.keys.contains(galoisElement) { + if galoisKey._keys.keys.contains(galoisElement) { try ciphertext.rotateColumns(by: step, using: evaluationKey) return } - let galoisElements = Array(galoisKey.keys.keys) + let galoisElements = Array(galoisKey._keys.keys) let steps = GaloisElement.stepsFor(elements: galoisElements, degree: degree).values.compactMap(\.self) let positiveStep = if step < 0 { @@ -92,7 +92,7 @@ extension HeScheme { step } - let plan = try GaloisElement.planMultiStep(supportedSteps: steps, step: positiveStep, degree: degree) + let plan = try GaloisElement._planMultiStep(supportedSteps: steps, step: positiveStep, degree: degree) guard let plan else { throw HeError.invalidRotationStep(step: step, degree: degree) } diff --git a/Sources/_TestUtilities/HeApiTestUtils.swift b/Sources/_TestUtilities/HeApiTestUtils.swift index d8eff79b..6a4916f1 100644 --- a/Sources/_TestUtilities/HeApiTestUtils.swift +++ b/Sources/_TestUtilities/HeApiTestUtils.swift @@ -22,7 +22,7 @@ public enum HeAPITestHelpers { /// Test environment with plaintexts and ciphertexts ready for use public struct TestEnv { /// Context for testing. - public let context: Context + public let context: Scheme.Context /// Raw data for first plaintext/ciphertext public let data1: [Scheme.Scalar] /// Raw data fro second plaintext/ciphertext @@ -48,7 +48,7 @@ public enum HeAPITestHelpers { /// Create the test environment. public init( - context: Context, + context: Scheme.Context, format: EncodeFormat, galoisElements: [Int] = [], relinearizationKey: Bool = false) throws @@ -77,6 +77,7 @@ public enum HeAPITestHelpers { } /// Check if the ciphertext decrypts to the expected result. + @inlinable public func checkDecryptsDecodes( ciphertext: Ciphertext, format: EncodeFormat, @@ -100,7 +101,8 @@ public enum HeAPITestHelpers { } } - /// generate the coefficient moduli for test + /// Generate the coefficient moduli for test + @inlinable public static func testCoefficientModuli() throws -> [T] { // Avoid assumptions on ordering of moduli // Also test `T.bitWidth - 2 @@ -119,17 +121,19 @@ public enum HeAPITestHelpers { preconditionFailure("Unsupported scalar type \(T.self)") } - /// generate the context for test - public static func getTestContext() throws -> Context { - try Context(encryptionParameters: EncryptionParameters( + /// Generate the context for test + @inlinable + public static func getTestContext() throws -> Context { + try Context(encryptionParameters: EncryptionParameters( polyDegree: TestUtils.testPolyDegree, - plaintextModulus: Scheme.Scalar(TestUtils.testPlaintextModulus), + plaintextModulus: Scalar(TestUtils.testPlaintextModulus), coefficientModuli: testCoefficientModuli(), errorStdDev: ErrorStdDev.stdDev32, securityLevel: SecurityLevel.unchecked)) } - /// test the evaluation key configuration + /// Test the evaluation key configuration + @inlinable public static func schemeEvaluationKeyTest(context _: Context) throws { do { let config = EvaluationKeyConfig() @@ -153,10 +157,11 @@ public enum HeAPITestHelpers { @inlinable static func encodingTest( - context: Context, + context: Scheme.Context, encodeFormat: EncodeFormat, polyFormat: (some PolyFormat).Type, - valueCount: Int) throws + valueCount: Int, + scheme _: Scheme.Type) throws { guard context.supportsSimdEncoding || encodeFormat != .simd else { return @@ -206,17 +211,19 @@ public enum HeAPITestHelpers { let bounds = -(signedModulus >> 1)...((signedModulus - 1) >> 1) signedData[0] = (Scheme.SignedScalar(context.plaintextModulus) - 1) / 2 + 1 #expect(throws: HeError.encodingDataOutOfBounds(bounds).self) { - try context.encode(signedValues: signedData, format: encodeFormat) + try context.encode(signedValues: signedData, + format: encodeFormat) as Plaintext } signedData[0] = -Scheme.SignedScalar(context.plaintextModulus) / 2 - 1 #expect(throws: HeError.encodingDataOutOfBounds(bounds).self) { - try context.encode(signedValues: signedData, format: encodeFormat) + try context.encode(signedValues: signedData, + format: encodeFormat) as Plaintext } } /// Testing the encoding/decoding functions of the scheme. @inlinable - public static func schemeEncodeDecodeTest(context: Context) throws { + public static func schemeEncodeDecodeTest(context: Scheme.Context, scheme: Scheme.Type) throws { for encodeFormat in EncodeFormat.allCases { for polyFormat: PolyFormat.Type in [Coeff.self, Eval.self] { for valueCount in [context.degree / 2, context.degree] { @@ -224,7 +231,8 @@ public enum HeAPITestHelpers { context: context, encodeFormat: encodeFormat, polyFormat: polyFormat, - valueCount: valueCount) + valueCount: valueCount, + scheme: scheme) } } } @@ -232,8 +240,11 @@ public enum HeAPITestHelpers { /// Testing the encryption and decryption of the scheme. @inlinable - public static func schemeEncryptDecryptTest(context: Context) throws { - let testEnv = try TestEnv(context: context, format: .coefficient) + public static func schemeEncryptDecryptTest( + context: Scheme.Context, + scheme _: Scheme.Type) throws + { + let testEnv = try TestEnv(context: context, format: .coefficient) var ciphertext1 = testEnv.ciphertext1 let evalCiphertext: Ciphertext = try ciphertext1.convertToEvalFormat() @@ -250,8 +261,11 @@ public enum HeAPITestHelpers { /// Testing zero-ciphertext generation of the scheme. @inlinable - public static func schemeEncryptZeroDecryptTest(context: Context) throws { - let testEnv = try TestEnv(context: context, format: .coefficient) + public static func schemeEncryptZeroDecryptTest( + context: Scheme.Context, + scheme _: Scheme.Type) throws + { + let testEnv = try TestEnv(context: context, format: .coefficient) let zeros = [Scheme.Scalar](repeating: 0, count: context.degree) let coeffCiphertext = try Ciphertext.zero(context: context) @@ -276,8 +290,11 @@ public enum HeAPITestHelpers { /// Testing addition with zero-ciphertext of the scheme. @inlinable - public static func schemeEncryptZeroAddDecryptTest(context: Context) throws { - let testEnv = try TestEnv(context: context, format: .coefficient) + public static func schemeEncryptZeroAddDecryptTest( + context: Scheme.Context, + scheme _: Scheme.Type) throws + { + let testEnv = try TestEnv(context: context, format: .coefficient) let expected = [Scheme.Scalar](repeating: 0, count: context.degree) let zeroCoeffCiphertext = try Ciphertext.zero(context: context) @@ -300,8 +317,11 @@ public enum HeAPITestHelpers { /// Testing multiplication with zero-ciphertext of the scheme. @inlinable - public static func schemeEncryptZeroMultiplyDecryptTest(context: Context) throws { - let testEnv = try TestEnv(context: context, format: .coefficient) + public static func schemeEncryptZeroMultiplyDecryptTest( + context: Scheme.Context, + scheme _: Scheme.Type) throws + { + let testEnv = try TestEnv(context: context, format: .coefficient) let expected = [Scheme.Scalar](repeating: 0, count: context.degree) let zeroCiphertext = try Ciphertext.zero(context: context) @@ -313,18 +333,21 @@ public enum HeAPITestHelpers { /// Testing ciphertext addition of the scheme. @inlinable - public static func schemeCiphertextAddTest(context: Context) async throws { - let testEnv = try TestEnv(context: context, format: .coefficient) + public static func schemeCiphertextAdditionTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws + { + let testEnv = try TestEnv(context: context, format: .coefficient) let data1 = testEnv.data1 let data2 = testEnv.data2 let sumData = zip(data1, data2).map { x, y in x.addMod(y, modulus: context.plaintextModulus) } let canonicalCipher1 = testEnv.ciphertext1 let canonicalCipher2 = testEnv.ciphertext2 - let evalCipher1 = try canonicalCipher1.convertToEvalFormat() - let evalCipher2 = try canonicalCipher2.convertToEvalFormat() - let coeffCipher1 = try canonicalCipher1.convertToCoeffFormat() - let coeffCipher2 = try canonicalCipher2.convertToCoeffFormat() + let evalCipher1 = try await canonicalCipher1.convertToEvalFormat() + let evalCipher2 = try await canonicalCipher2.convertToEvalFormat() + let coeffCipher1 = try await canonicalCipher1.convertToCoeffFormat() + let coeffCipher2 = try await canonicalCipher2.convertToCoeffFormat() // canonicalCiphertext do { @@ -417,18 +440,21 @@ public enum HeAPITestHelpers { /// Testing ciphertext subtraction of the scheme. @inlinable - public static func schemeCiphertextSubtractTest(context: Context) async throws { - let testEnv = try TestEnv(context: context, format: .coefficient) + public static func schemeCiphertextSubtractionTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws + { + let testEnv = try TestEnv(context: context, format: .coefficient) let data1 = testEnv.data1 let data2 = testEnv.data2 let diffData = zip(data1, data2).map { x, y in x.subtractMod(y, modulus: context.plaintextModulus) } let canonicalCipher1 = testEnv.ciphertext1 let canonicalCipher2 = testEnv.ciphertext2 - let evalCipher1 = try canonicalCipher1.convertToEvalFormat() - let evalCipher2 = try canonicalCipher2.convertToEvalFormat() - let coeffCipher1 = try canonicalCipher1.convertToCoeffFormat() - let coeffCipher2 = try canonicalCipher2.convertToCoeffFormat() + let evalCipher1 = try await canonicalCipher1.convertToEvalFormat() + let evalCipher2 = try await canonicalCipher2.convertToEvalFormat() + let coeffCipher1 = try await canonicalCipher1.convertToCoeffFormat() + let coeffCipher2 = try await canonicalCipher2.convertToCoeffFormat() // canonicalCiphertext do { @@ -520,15 +546,16 @@ public enum HeAPITestHelpers { } } - /// testing ciphertext multiplication of the scheme. + /// Testing ciphertext multiplication of the scheme. @inlinable - public static func schemeCiphertextCiphertextMultiplyTest( - context: Context) async throws + public static func schemeCiphertextCiphertextMultiplicationTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws { guard context.supportsSimdEncoding, context.supportsEvaluationKey else { return } - let testEnv = try TestEnv(context: context, format: .simd, relinearizationKey: true) + let testEnv = try TestEnv(context: context, format: .simd, relinearizationKey: true) let data1 = testEnv.data1 let data2 = testEnv.data2 let productData = zip(data1, data2) @@ -546,9 +573,9 @@ public enum HeAPITestHelpers { #expect(relinearizedProd.polys.count == Scheme.freshCiphertextPolyCount) #expect(relinearizedProdAsync.polys.count == Scheme.freshCiphertextPolyCount) - let evalCiphertext: Ciphertext = try ciphertextProduct.convertToEvalFormat() + let evalCiphertext: Ciphertext = try await ciphertextProduct.convertToEvalFormat() let coeffCiphertext: Ciphertext = try evalCiphertext.inverseNtt() - let evalRelinearizedCiphertext: Ciphertext = try relinearizedProd.convertToEvalFormat() + let evalRelinearizedCiphertext: Ciphertext = try await relinearizedProd.convertToEvalFormat() let coeffRelinearizedCiphertext: Ciphertext = try evalRelinearizedCiphertext.inverseNtt() try testEnv.checkDecryptsDecodes(ciphertext: coeffCiphertext, format: .simd, expected: productData) @@ -562,9 +589,10 @@ public enum HeAPITestHelpers { /// Testing CT-PT inner product of the scheme. @inlinable public static func schemeCiphertextPlaintextInnerProductTest( - context: Context) async throws + context: Scheme.Context, + scheme _: Scheme.Type) async throws { - let testEnv = try TestEnv(context: context, format: .simd) + let testEnv = try TestEnv(context: context, format: .simd) let data1 = testEnv.data1 let data2 = testEnv.data2 for count in [4, 1257] { @@ -633,9 +661,10 @@ public enum HeAPITestHelpers { /// Testing CT-CT inner product of the scheme. @inlinable public static func schemeCiphertextCiphertextInnerProductTest( - context: Context) async throws + context: Scheme.Context, + scheme _: Scheme.Type) async throws { - let testEnv = try TestEnv(context: context, format: .simd) + let testEnv = try TestEnv(context: context, format: .simd) let data1 = testEnv.data1 let data2 = testEnv.data2 for count in [4, 257] { @@ -655,11 +684,15 @@ public enum HeAPITestHelpers { } /// Testing CT-CT multiplication followed by CT-CT addition of the scheme. - public static func schemeCiphertextMultiplyAddTest(context: Context) async throws { + @inlinable + public static func schemeCiphertextMultiplyAddTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws + { guard context.supportsSimdEncoding else { return } - let testEnv = try TestEnv(context: context, format: .simd) + let testEnv = try TestEnv(context: context, format: .simd) let data1 = testEnv.data1 let data2 = testEnv.data2 let multiplyAddData = zip(data1, data2).map { data1, data2 in @@ -675,7 +708,7 @@ public enum HeAPITestHelpers { try await Scheme.mulAssignAsync(&ciphertextResultAsync, ciphertext2) try await Scheme.addAssignAsync(&ciphertextResultAsync, ciphertext1) - let evalCiphertext: Ciphertext = try ciphertextResult.convertToEvalFormat() + let evalCiphertext: Ciphertext = try await ciphertextResult.convertToEvalFormat() let coeffCiphertext: Ciphertext = try evalCiphertext.inverseNtt() try testEnv.checkDecryptsDecodes(ciphertext: coeffCiphertext, format: .simd, expected: multiplyAddData) @@ -686,11 +719,14 @@ public enum HeAPITestHelpers { /// Testing CT-CT multiplication followed by CT-PT addition of the scheme. @inlinable - public static func schemeCiphertextMultiplyAddPlainTest(context: Context) throws { + public static func schemeCiphertextMultiplyAddPlainTest( + context: Scheme.Context, + scheme _: Scheme.Type) throws + { guard context.supportsSimdEncoding else { return } - let testEnv = try TestEnv(context: context, format: .simd) + let testEnv = try TestEnv(context: context, format: .simd) let data1 = testEnv.data1 let data2 = testEnv.data2 let multiplyAddData = zip(data1, data2).map { data1, data2 in @@ -711,13 +747,16 @@ public enum HeAPITestHelpers { } /// Testing CT-CT multiplication followed by CT-PT subtraction of the scheme. + @inlinable public static func schemeCiphertextMultiplySubtractPlainTest( - context: Context) async throws + context: Scheme.Context, + scheme _: Scheme.Type) async throws + { guard context.supportsSimdEncoding else { return } - let testEnv = try TestEnv(context: context, format: .simd) + let testEnv = try TestEnv(context: context, format: .simd) let data1 = testEnv.data1 let data2 = testEnv.data2 let multiplySubtractData = zip(data1, data2).map { data1, data2 in @@ -733,7 +772,7 @@ public enum HeAPITestHelpers { try await Scheme.mulAssignAsync(&ciphertextResultAsync, ciphertext2) try await Scheme.subAssignAsync(&ciphertextResultAsync, testEnv.coeffPlaintext1) - let evalCiphertext: Ciphertext = try ciphertextResult.convertToEvalFormat() + let evalCiphertext: Ciphertext = try await ciphertextResult.convertToEvalFormat() let coeffCiphertext: Ciphertext = try evalCiphertext.inverseNtt() try testEnv.checkDecryptsDecodes(ciphertext: coeffCiphertext, format: .simd, expected: multiplySubtractData) @@ -748,12 +787,13 @@ public enum HeAPITestHelpers { /// Testing CT-PT multiplication followed by CT-PT addition of the scheme. @inlinable public static func schemeCiphertextPlaintextMultiplyAddPlainTest( - context: Context) async throws + context: Scheme.Context, + scheme _: Scheme.Type) async throws { guard context.supportsSimdEncoding else { return } - let testEnv = try TestEnv(context: context, format: .simd) + let testEnv = try TestEnv(context: context, format: .simd) let data1 = testEnv.data1 let data2 = testEnv.data2 let multiplyAddData = zip(data1, data2).map { data1, data2 in @@ -768,10 +808,10 @@ public enum HeAPITestHelpers { var ciphertextEvalResultAsync = ciphertext1 try await Scheme.mulAssignAsync(&ciphertextEvalResultAsync, testEnv.evalPlaintext2) - var ciphertextResultAsync = try await Scheme.inverseNttAsync(ciphertextEvalResultAsync) + var ciphertextResultAsync = try await Scheme.inverseNttAsync(&ciphertextEvalResultAsync) try await Scheme.addAssignCoeffAsync(&ciphertextResultAsync, testEnv.coeffPlaintext1) - let evalCiphertext: Ciphertext = try ciphertextResult.convertToEvalFormat() + let evalCiphertext: Ciphertext = try await ciphertextResult.convertToEvalFormat() let coeffCiphertext: Ciphertext = try evalCiphertext.inverseNtt() try testEnv.checkDecryptsDecodes(ciphertext: coeffCiphertext, format: .simd, expected: multiplyAddData) @@ -780,15 +820,16 @@ public enum HeAPITestHelpers { try testEnv.checkDecryptsDecodes(ciphertext: ciphertextResultAsync, format: .simd, expected: multiplyAddData) } - /// Testing CT-PT multiplication followed by CT-PT subtraction of the scheme. + /// Testing CT-PT multiplication followed by CT-PT addition of the scheme. @inlinable public static func schemeCiphertextPlaintextMultiplySubtractPlainTest( - context: Context) async throws + context: Scheme.Context, + scheme _: Scheme.Type) async throws { guard context.supportsSimdEncoding else { return } - let testEnv = try TestEnv(context: context, format: .simd) + let testEnv = try TestEnv(context: context, format: .simd) let data1 = testEnv.data1 let data2 = testEnv.data2 let multiplySubtractData = zip(data1, data2).map { data1, data2 in @@ -803,10 +844,10 @@ public enum HeAPITestHelpers { var ciphertextEvalResultAsync = ciphertext1 try await Scheme.mulAssignAsync(&ciphertextEvalResultAsync, testEnv.evalPlaintext2) - var ciphertextResultAsync = try await Scheme.inverseNttAsync(ciphertextEvalResultAsync) + var ciphertextResultAsync = try await Scheme.inverseNttAsync(&ciphertextEvalResultAsync) try await Scheme.subAssignCoeffAsync(&ciphertextResultAsync, testEnv.coeffPlaintext1) - let evalCiphertext: Ciphertext = try ciphertextResult.convertToEvalFormat() + let evalCiphertext: Ciphertext = try await ciphertextResult.convertToEvalFormat() let coeffCiphertext: Ciphertext = try evalCiphertext.inverseNtt() try testEnv.checkDecryptsDecodes(ciphertext: coeffCiphertext, format: .simd, expected: multiplySubtractData) @@ -820,11 +861,14 @@ public enum HeAPITestHelpers { /// Testing CT-CT multiplication followed by CT-CT subtraction of the scheme. @inlinable - public static func schemeCiphertextMultiplySubtractTest(context: Context) async throws { + public static func schemeCiphertextMultiplySubTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws + { guard context.supportsSimdEncoding else { return } - let testEnv = try TestEnv(context: context, format: .simd) + let testEnv = try TestEnv(context: context, format: .simd) let data1 = testEnv.data1 let data2 = testEnv.data2 let multiplySubtractData = zip(data1, data2).map { data1, data2 in @@ -840,7 +884,7 @@ public enum HeAPITestHelpers { try await Scheme.mulAssignAsync(&ciphertextResultAsync, ciphertext2) try await Scheme.subAssignAsync(&ciphertextResultAsync, ciphertext1) - let evalCiphertext: Ciphertext = try ciphertextResult.convertToEvalFormat() + let evalCiphertext: Ciphertext = try await ciphertextResult.convertToEvalFormat() let coeffCiphertext: Ciphertext = try evalCiphertext.inverseNtt() try testEnv.checkDecryptsDecodes(ciphertext: coeffCiphertext, format: .simd, expected: multiplySubtractData) @@ -854,8 +898,11 @@ public enum HeAPITestHelpers { /// Testing ciphertext negation of the scheme. @inlinable - public static func schemeCiphertextNegateTest(context: Context) async throws { - let testEnv = try TestEnv(context: context, format: .coefficient) + public static func schemeCiphertextNegateTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws + { + let testEnv = try TestEnv(context: context, format: .coefficient) let negatedData = testEnv.data1.map { data1 in data1.negateMod(modulus: context.plaintextModulus) } @@ -881,16 +928,19 @@ public enum HeAPITestHelpers { /// Testing CT-PT addition of the scheme. @inlinable - public static func schemeCiphertextPlaintextAddTest(context: Context) async throws { + public static func schemeCiphertextPlaintextAdditionTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws + { guard context.supportsSimdEncoding else { return } - let testEnv = try TestEnv(context: context, format: .simd) + let testEnv = try TestEnv(context: context, format: .simd) let data1 = testEnv.data1 let data2 = testEnv.data2 let sumData = zip(data1, data2).map { x, y in x.addMod(y, modulus: context.plaintextModulus) } let canonicalCiphertext = testEnv.ciphertext1 - let evalCiphertext: Ciphertext = try canonicalCiphertext.convertToEvalFormat() + let evalCiphertext: Ciphertext = try await canonicalCiphertext.convertToEvalFormat() let coeffCiphertext: Ciphertext = try evalCiphertext.inverseNtt() let coeffPlaintext = testEnv.coeffPlaintext2 let evalPlaintext = try coeffPlaintext.forwardNtt() @@ -999,19 +1049,20 @@ public enum HeAPITestHelpers { /// Testing CT-PT subtraction of the scheme. @inlinable - public static func schemeCiphertextPlaintextSubtractTest( - context: Context) async throws + public static func schemeCiphertextPlaintextSubtractionTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws { guard context.supportsSimdEncoding else { return } - let testEnv = try TestEnv(context: context, format: .simd) + let testEnv = try TestEnv(context: context, format: .simd) let data1 = testEnv.data1 let data2 = testEnv.data2 let diff1Minus2Data = zip(data1, data2).map { x, y in x.subtractMod(y, modulus: context.plaintextModulus) } let diff2Minus1Data = zip(data2, data1).map { x, y in x.subtractMod(y, modulus: context.plaintextModulus) } let canonicalCiphertext = testEnv.ciphertext1 - let evalCiphertext: Ciphertext = try canonicalCiphertext.convertToEvalFormat() + let evalCiphertext: Ciphertext = try await canonicalCiphertext.convertToEvalFormat() let coeffCiphertext: Ciphertext = try evalCiphertext.inverseNtt() let coeffPlaintext = testEnv.coeffPlaintext2 let evalPlaintext = try coeffPlaintext.forwardNtt() @@ -1114,13 +1165,14 @@ public enum HeAPITestHelpers { /// Testing CT-PT multiplication of the scheme. @inlinable - public static func schemeCiphertextPlaintextMultiplyTest( - context: Context) async throws + public static func schemeCiphertextPlaintextMultiplicationTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws { guard context.supportsSimdEncoding else { return } - let testEnv = try TestEnv(context: context, format: .simd) + let testEnv = try TestEnv(context: context, format: .simd) let data1 = testEnv.data1 let data2 = testEnv.data2 var productData = [Scheme.Scalar](repeating: 0, count: context.degree) @@ -1144,11 +1196,11 @@ public enum HeAPITestHelpers { if context.coefficientModuli.count > 2 { var ciphertext = testEnv.ciphertext1 try ciphertext.modSwitchDown() - let evalCiphertext = try ciphertext.convertToEvalFormat() + let evalCiphertext = try await ciphertext.convertToEvalFormat() let evalPlaintext = try testEnv.context.encode( values: testEnv.data2, format: .simd, - moduliCount: evalCiphertext.moduli.count) + moduliCount: evalCiphertext.moduli.count) as Plaintext try testEnv.checkDecryptsDecodes( ciphertext: evalCiphertext * evalPlaintext, format: .simd, @@ -1156,7 +1208,7 @@ public enum HeAPITestHelpers { var ciphertextAsync = testEnv.ciphertext1 try await Scheme.modSwitchDownAsync(&ciphertextAsync) - var evalCiphertextAsync = try ciphertextAsync.convertToEvalFormat() + var evalCiphertextAsync = try await ciphertextAsync.convertToEvalFormat() try await Scheme.mulAssignAsync(&evalCiphertextAsync, evalPlaintext) try testEnv.checkDecryptsDecodes( ciphertext: evalCiphertextAsync, @@ -1172,10 +1224,13 @@ public enum HeAPITestHelpers { /// Testing ciphertext rotation of the scheme. @inlinable - public static func schemeRotationTest(context: Context) async throws { - func runRotationTest(context: Context, galoisElements: [Int], multiStep: Bool) async throws { + public static func schemeRotationTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws + { + func runRotationTest(context: Scheme.Context, galoisElements: [Int], multiStep: Bool) async throws { let degree = context.degree - let testEnv = try TestEnv(context: context, format: .simd, galoisElements: galoisElements) + let testEnv = try TestEnv(context: context, format: .simd, galoisElements: galoisElements) let evaluationKey = try #require(testEnv.evaluationKey) for step in 1..(context: context, format: .simd, galoisElements: galoisElementsSwap) let evaluationKey = try #require(testEnv.evaluationKey) let expectedData = Array(testEnv.data1[degree / 2..(context: Context) async throws { + public static func schemeApplyGaloisTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws + { guard context.supportsSimdEncoding, context.supportsEvaluationKey else { return } + func rotate(original: [Scheme.Scalar], halfDataCount: Int, step: Int) -> [Scheme.Scalar] { + let slice0 = Array(original[step..> 1)).map { step in try GaloisElement.rotatingColumns(by: -step, degree: context.degree) } - let testEnv = try TestEnv(context: context, format: .simd, galoisElements: elements) + + let testEnv = try TestEnv(context: context, format: .simd, galoisElements: elements) let evaluationKey = try #require(testEnv.evaluationKey) let dataCount = testEnv.data1.count let halfDataCount = dataCount / 2 - let rotate = { (original: [Scheme.Scalar], step: Int) -> [Scheme.Scalar] in - Array(original[step..(context: Context) throws { - let testEnv = try TestEnv(context: context, format: .coefficient) + /// Testing noise budget estimation. + public static func noiseBudgetTest(context: Scheme.Context, scheme _: Scheme.Type) throws { + let testEnv = try TestEnv(context: context, format: .coefficient) let zeroCoeffCiphertext = try Scheme.CoeffCiphertext.zero(context: context, moduliCount: 1) #expect(try zeroCoeffCiphertext.noiseBudget(using: testEnv.secretKey, variableTime: true) == Double.infinity) @@ -1338,14 +1399,17 @@ public enum HeAPITestHelpers { #expect(decrypted != expected) } - /// testing repeated addition. + /// Testing repeated addition. @inlinable - public static func repeatedAdditionTest(context: Context) async throws { + public static func repeatAdditionTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws + { let testEnv = try HeAPITestHelpers.TestEnv(context: context, format: .coefficient) var coeffCiphertext = testEnv.ciphertext1 - var coeffCiphertextAsync = try coeffCiphertext.convertToCoeffFormat() - let coeffCifertextToAdd = try coeffCiphertext.convertToCoeffFormat() + var coeffCiphertextAsync = try await coeffCiphertext.convertToCoeffFormat() + let coeffCifertextToAdd = try await coeffCiphertext.convertToCoeffFormat() try coeffCiphertext += testEnv.coeffPlaintext1 try coeffCiphertext += testEnv.ciphertext1 try coeffCiphertext += testEnv.coeffPlaintext1 @@ -1363,12 +1427,15 @@ public enum HeAPITestHelpers { try testEnv.checkDecryptsDecodes(ciphertext: coeffCiphertextAsync, format: .coefficient, expected: expected) } - /// testing multiply inverse power of x. + /// Testing multiply inverse power of x. @inlinable - public static func multiplyInverseTest(context: Context) async throws { + public static func multiplyInverseTest( + context: Scheme.Context, + scheme _: Scheme.Type) async throws + { let testEnv = try HeAPITestHelpers.TestEnv(context: context, format: .coefficient) - var coeffCiphertext1 = try testEnv.ciphertext1.convertToCoeffFormat() + var coeffCiphertext1 = try await testEnv.ciphertext1.convertToCoeffFormat() var coeffCiphertext2 = coeffCiphertext1 var coeffCiphertext3 = coeffCiphertext1 var coeffCiphertext4 = coeffCiphertext1 @@ -1391,4 +1458,69 @@ public enum HeAPITestHelpers { try testEnv.checkDecryptsDecodes(ciphertext: coeffCiphertext3, format: .coefficient, expected: expectedData1) try testEnv.checkDecryptsDecodes(ciphertext: coeffCiphertext4, format: .coefficient, expected: expectedData2) } + + /// Testing ntt and intt on ciphertexts. + @inlinable + public static func schemeTestNtt(context: Scheme.Context, scheme _: Scheme.Type) async throws { + let testEnv = try HeAPITestHelpers.TestEnv(context: context, format: .coefficient) + var ciphertext = try await testEnv.ciphertext1.convertToCoeffFormat() + + var evalCiphertext = try Scheme.forwardNtt(&ciphertext) + let evalCiphertextAsync = try await Scheme.forwardNttAsync(&ciphertext) + let decrypted1: [Scheme.Scalar] = try evalCiphertext.decrypt(using: testEnv.secretKey) + .decode(format: .coefficient) + let decrypted2: [Scheme.Scalar] = try evalCiphertextAsync.decrypt(using: testEnv.secretKey) + .decode(format: .coefficient) + #expect(decrypted1 == decrypted2) + + let coeffCiphertext = try Scheme.inverseNtt(&evalCiphertext) + let coeffCiphertextAsync = try await Scheme.inverseNttAsync(&evalCiphertext) + let decrypted3: [Scheme.Scalar] = try coeffCiphertext.decrypt(using: testEnv.secretKey) + .decode(format: .coefficient) + let decrypted4: [Scheme.Scalar] = try coeffCiphertextAsync.decrypt(using: testEnv.secretKey) + .decode(format: .coefficient) + #expect(decrypted3 == decrypted4) + } + + /// Testing ciphertext format conversion. + @inlinable + public static func schemeTestFormats(context: Scheme.Context, + scheme _: Scheme.Type) async throws + { + let testEnv = try HeAPITestHelpers.TestEnv(context: context, format: .coefficient) + let ciphertext = testEnv.ciphertext1 + // swiftlint:disable large_tuple + func synchronousConversions(_ ciphertext: Ciphertext) throws + -> ([Scheme.Scalar], [Scheme.Scalar], [Scheme.Scalar]) + { + let evalCiphertext = try ciphertext.convertToEvalFormat() + let canonicalCiphertext = try evalCiphertext.convertToCanonicalFormat() + let coeffCiphertext = try evalCiphertext.convertToCoeffFormat() + let decryptedEval: [Scheme.Scalar] = try evalCiphertext.decrypt(using: testEnv.secretKey) + .decode(format: .coefficient) + let decryptedCanonical: [Scheme.Scalar] = try canonicalCiphertext.decrypt(using: testEnv.secretKey) + .decode(format: .coefficient) + let decryptedCoeff: [Scheme.Scalar] = try coeffCiphertext.decrypt(using: testEnv.secretKey) + .decode(format: .coefficient) + return (decryptedEval, decryptedCanonical, decryptedCoeff) + } + // swiftlint:enable large_tuple + + let syncDecrypted = try synchronousConversions(ciphertext) + + let evalCiphertext = try await ciphertext.convertToEvalFormat() + let decryptedEval: [Scheme.Scalar] = try evalCiphertext.decrypt(using: testEnv.secretKey) + .decode(format: .coefficient) + #expect(decryptedEval == syncDecrypted.0) + + let canonicalCiphertext = try await evalCiphertext.convertToCanonicalFormat() + let decryptedCanonical: [Scheme.Scalar] = try canonicalCiphertext.decrypt(using: testEnv.secretKey) + .decode(format: .coefficient) + #expect(decryptedCanonical == syncDecrypted.1) + + let coeffCiphertext = try await evalCiphertext.convertToCoeffFormat() + let decryptedCoeff: [Scheme.Scalar] = try coeffCiphertext.decrypt(using: testEnv.secretKey) + .decode(format: .coefficient) + #expect(decryptedCoeff == syncDecrypted.2) + } } diff --git a/Sources/_TestUtilities/PirUtilities/ExpansionTests.swift b/Sources/_TestUtilities/PirUtilities/ExpansionTests.swift index 58594a4d..438d4d49 100644 --- a/Sources/_TestUtilities/PirUtilities/ExpansionTests.swift +++ b/Sources/_TestUtilities/PirUtilities/ExpansionTests.swift @@ -24,7 +24,7 @@ extension PirTestUtils { @inlinable public static func expandCiphertextForOneStep( scheme _: Scheme.Type, - _ keyCompression: PirKeyCompressionStrategy) throws + _ keyCompression: PirKeyCompressionStrategy) async throws { let degree = 32 let significantBitCounts = Array(repeating: Scheme.Scalar.bitWidth - 4, count: 4) @@ -39,7 +39,7 @@ extension PirTestUtils { errorStdDev: ErrorStdDev.stdDev32, securityLevel: SecurityLevel.unchecked) - let context: Context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let plaintextModulus = context.plaintextModulus let logDegree = degree.log2 for logStep in 1...logDegree { @@ -59,7 +59,7 @@ extension PirTestUtils { let evaluationKey = try context.generateEvaluationKey( config: EvaluationKeyConfig, using: secretKey) let ciphertext = try plaintext.encrypt(using: secretKey) - let expandedCiphertexts = try PirUtil.expandCiphertextForOneStep( + let expandedCiphertexts = try await PirUtil.expandCiphertextForOneStep( ciphertext, logStep: logStep, using: evaluationKey) @@ -78,22 +78,22 @@ extension PirTestUtils { /// Tests compressInputsForOneCiphertext and expandCiphertexts roundtrip. @inlinable - public static func oneCiphertextRoundtrip(scheme _: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + public static func oneCiphertextRoundtrip(scheme _: Scheme.Type) async throws { + let context: Scheme.Context = try TestUtils.getTestContext() let degree = context.degree let logDegree = degree.log2 for inputCount in 1...degree { let data: [Scheme.Scalar] = (0.. = try PirUtil.compressInputsForOneCiphertext( totalInputCount: inputCount, - nonZeroInputs: nonZeroInputs, + oneIndices: oneIndices, context: context) let secretKey = try context.generateSecretKey() let evaluationKeyConfig = EvaluationKeyConfig(galoisElements: (1...logDegree).map { (1 << $0) + 1 }) let evaluationKey = try context.generateEvaluationKey(config: evaluationKeyConfig, using: secretKey) let ciphertext = try plaintext.encrypt(using: secretKey) - let expandedCiphertexts = try PirUtil.expandCiphertext( + let expandedCiphertexts = try await PirUtil.expandCiphertext( ciphertext, outputCount: inputCount, logStep: 1, @@ -111,32 +111,34 @@ extension PirTestUtils { } } - /// Tests compressInputs and expandCiphertexts roundtrip with multiple ciphertexts. + /// Tests compressBinaryInputs and expandCiphertexts roundtrip with multiple ciphertexts. @inlinable - public static func multipleCiphertextsRoundtrip(scheme _: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + public static func multipleCiphertextsRoundtrip(pirUtil _: PirUtil + .Type) async throws + { + let context: PirUtil.Scheme.Context = try TestUtils.getTestContext() let degree = TestUtils.testPolyDegree let logDegree = degree.log2 for inputCount in 1...degree * 2 { let data: [Int] = (0..) throws + with context: Server.Scheme.Context) async throws where Server.IndexPir == Client.IndexPir { let database = PirTestUtils.randomIndexPirDatabase( entryCount: parameter.entryCount, entrySizeInBytes: parameter.entrySizeInBytes) - let processedDb = try Server.process(database: database, with: context, using: parameter) + let processedDb = try await Server.process(database: database, with: context, using: parameter) let server = try Server(parameter: parameter, context: context, database: processedDb) let client = Client(parameter: parameter, context: context) @@ -44,7 +44,7 @@ extension PirTestUtils { let batchSize = Int.random(in: 1...parameter.batchSize) let queryIndices = Array(indices.prefix(batchSize)) let query = try client.generateQuery(at: queryIndices, using: secretKey) - let response = try server.computeResponse(to: query, using: evaluationKey) + let response = try await server.computeResponse(to: query, using: evaluationKey) if Server.Scheme.self != NoOpScheme.self { #expect(!response.isTransparent()) } @@ -57,7 +57,7 @@ extension PirTestUtils { @inlinable static func indexPirTest(server: Server.Type, - client: Client.Type) throws + client: Client.Type) async throws where Server.IndexPir == Client.IndexPir { let configs = try [ @@ -99,117 +99,119 @@ extension PirTestUtils { keyCompression: .maxCompression), ] - let context: Context = try TestUtils.getTestContext() + let context: Server.Scheme.Context = try TestUtils.getTestContext() for config in configs { let parameter = Server.generateParameter(config: config, with: context) - try indexPirTestForParameter(server: server, client: client, for: parameter, with: context) + try await indexPirTestForParameter(server: server, client: client, for: parameter, with: context) } } /// Testing indexPir. @inlinable - public static func indexPir(scheme _: Scheme.Type) throws { - try indexPirTest(server: MulPirServer.self, client: MulPirClient.self) + public static func indexPir(scheme _: Scheme.Type) async throws { + try await indexPirTest( + server: MulPirServer>.self, + client: MulPirClient>.self) } + } - /// Testing client configuration. - @inlinable - func generateParameter() throws { - let context: Context> = try TestUtils.getTestContext() - // unevenDimensions: false - do { - let config = try IndexPirConfig(entryCount: 16, - entrySizeInBytes: context.bytesPerPlaintext, - dimensionCount: 2, - batchSize: 1, - unevenDimensions: false, - keyCompression: .noCompression) - let parameter = MulPir.generateParameter(config: config, with: context) - #expect(parameter.dimensions == [4, 4]) - } - do { - let config = try IndexPirConfig(entryCount: 10, - entrySizeInBytes: context.bytesPerPlaintext, - dimensionCount: 2, - batchSize: 2, - unevenDimensions: false, - keyCompression: .noCompression) - let parameter = MulPir.generateParameter(config: config, with: context) - #expect(parameter.dimensions == [4, 3]) - } - // unevenDimensions: true - do { - let config = try IndexPirConfig(entryCount: 15, - entrySizeInBytes: context.bytesPerPlaintext, - dimensionCount: 2, - batchSize: 1, - unevenDimensions: true, - keyCompression: .noCompression) - let parameter = MulPir.generateParameter(config: config, with: context) - #expect(parameter.dimensions == [5, 3]) - } - do { - let config = try IndexPirConfig(entryCount: 15, - entrySizeInBytes: context.bytesPerPlaintext, - dimensionCount: 2, - batchSize: 2, - unevenDimensions: true, - keyCompression: .noCompression) - let parameter = MulPir.generateParameter(config: config, with: context) - #expect(parameter.dimensions == [5, 3]) - } - do { - let config = try IndexPirConfig(entryCount: 17, - entrySizeInBytes: context.bytesPerPlaintext, - dimensionCount: 2, - batchSize: 2, - unevenDimensions: true, - keyCompression: .noCompression) - let parameter = MulPir.generateParameter(config: config, with: context) - #expect(parameter.dimensions == [9, 2]) - } - // no key compression - do { - let config = try IndexPirConfig(entryCount: 100, - entrySizeInBytes: context.bytesPerPlaintext, - dimensionCount: 2, - batchSize: 2, - unevenDimensions: true, - keyCompression: .noCompression) - let parameter = MulPir.generateParameter(config: config, with: context) - let evalKeyConfig = EvaluationKeyConfig( - galoisElements: [3, 5, 9, 17], - hasRelinearizationKey: true) - #expect(parameter.evaluationKeyConfig == evalKeyConfig) - } - // hybrid key compression - do { - let config = try IndexPirConfig(entryCount: 100, - entrySizeInBytes: context.bytesPerPlaintext, - dimensionCount: 2, - batchSize: 2, - unevenDimensions: true, - keyCompression: .hybridCompression) - let parameter = MulPir.generateParameter(config: config, with: context) - let evalKeyConfig = EvaluationKeyConfig( - galoisElements: [3, 5, 9, 17], - hasRelinearizationKey: true) - #expect(parameter.evaluationKeyConfig == evalKeyConfig) - } - // max key compression - do { - let config = try IndexPirConfig(entryCount: 100, - entrySizeInBytes: context.bytesPerPlaintext, - dimensionCount: 2, - batchSize: 2, - unevenDimensions: true, - keyCompression: .maxCompression) - let parameter = MulPir.generateParameter(config: config, with: context) - let evalKeyConfig = EvaluationKeyConfig( - galoisElements: [3, 5, 9], - hasRelinearizationKey: true) - #expect(parameter.evaluationKeyConfig == evalKeyConfig) - } + /// Testing client configuration. + @inlinable + func generateParameter() throws { + let context: Context> = try TestUtils.getTestContext() + // unevenDimensions: false + do { + let config = try IndexPirConfig(entryCount: 16, + entrySizeInBytes: context.bytesPerPlaintext, + dimensionCount: 2, + batchSize: 1, + unevenDimensions: false, + keyCompression: .noCompression) + let parameter = MulPir>.generateParameter(config: config, with: context) + #expect(parameter.dimensions == [4, 4]) + } + do { + let config = try IndexPirConfig(entryCount: 10, + entrySizeInBytes: context.bytesPerPlaintext, + dimensionCount: 2, + batchSize: 2, + unevenDimensions: false, + keyCompression: .noCompression) + let parameter = MulPir>.generateParameter(config: config, with: context) + #expect(parameter.dimensions == [4, 3]) + } + // unevenDimensions: true + do { + let config = try IndexPirConfig(entryCount: 15, + entrySizeInBytes: context.bytesPerPlaintext, + dimensionCount: 2, + batchSize: 1, + unevenDimensions: true, + keyCompression: .noCompression) + let parameter = MulPir>.generateParameter(config: config, with: context) + #expect(parameter.dimensions == [5, 3]) + } + do { + let config = try IndexPirConfig(entryCount: 15, + entrySizeInBytes: context.bytesPerPlaintext, + dimensionCount: 2, + batchSize: 2, + unevenDimensions: true, + keyCompression: .noCompression) + let parameter = MulPir>.generateParameter(config: config, with: context) + #expect(parameter.dimensions == [5, 3]) + } + do { + let config = try IndexPirConfig(entryCount: 17, + entrySizeInBytes: context.bytesPerPlaintext, + dimensionCount: 2, + batchSize: 2, + unevenDimensions: true, + keyCompression: .noCompression) + let parameter = MulPir>.generateParameter(config: config, with: context) + #expect(parameter.dimensions == [9, 2]) + } + // no key compression + do { + let config = try IndexPirConfig(entryCount: 100, + entrySizeInBytes: context.bytesPerPlaintext, + dimensionCount: 2, + batchSize: 2, + unevenDimensions: true, + keyCompression: .noCompression) + let parameter = MulPir>.generateParameter(config: config, with: context) + let evalKeyConfig = EvaluationKeyConfig( + galoisElements: [3, 5, 9, 17], + hasRelinearizationKey: true) + #expect(parameter.evaluationKeyConfig == evalKeyConfig) + } + // hybrid key compression + do { + let config = try IndexPirConfig(entryCount: 100, + entrySizeInBytes: context.bytesPerPlaintext, + dimensionCount: 2, + batchSize: 2, + unevenDimensions: true, + keyCompression: .hybridCompression) + let parameter = MulPir>.generateParameter(config: config, with: context) + let evalKeyConfig = EvaluationKeyConfig( + galoisElements: [3, 5, 9, 17], + hasRelinearizationKey: true) + #expect(parameter.evaluationKeyConfig == evalKeyConfig) + } + // max key compression + do { + let config = try IndexPirConfig(entryCount: 100, + entrySizeInBytes: context.bytesPerPlaintext, + dimensionCount: 2, + batchSize: 2, + unevenDimensions: true, + keyCompression: .maxCompression) + let parameter = MulPir>.generateParameter(config: config, with: context) + let evalKeyConfig = EvaluationKeyConfig( + galoisElements: [3, 5, 9], + hasRelinearizationKey: true) + #expect(parameter.evaluationKeyConfig == evalKeyConfig) } } } diff --git a/Sources/_TestUtilities/PirUtilities/KeywordPirTests.swift b/Sources/_TestUtilities/PirUtilities/KeywordPirTests.swift index b666f1e6..4c422bad 100644 --- a/Sources/_TestUtilities/PirUtilities/KeywordPirTests.swift +++ b/Sources/_TestUtilities/PirUtilities/KeywordPirTests.swift @@ -21,21 +21,21 @@ extension PirTestUtils { public enum KeywordPirTests { /// Tests database serialization. @inlinable - public static func processedDatabaseSerialization(_: Scheme.Type) throws { + public static func processedDatabaseSerialization(_: Scheme.Type) async throws { let rowCount = 100 let valueSize = 10 let testDatabase = PirTestUtils.randomKeywordPirDatabase(rowCount: rowCount, valueSize: valueSize) let encryptionParameters: EncryptionParameters = try TestUtils.getTestEncryptionParameters() - let testContext: Context = try Context(encryptionParameters: encryptionParameters) + let testContext = try Scheme.Context(encryptionParameters: encryptionParameters) let keywordConfig = try KeywordPirConfig( dimensionCount: 2, cuckooTableConfig: PirTestUtils.testCuckooTableConfig(maxSerializedBucketSize: 5 * valueSize), unevenDimensions: true, keyCompression: .noCompression) - let processed = try KeywordPirServer>.process(database: testDatabase, - config: keywordConfig, - with: testContext) + let processed = try await KeywordPirServer>>.process(database: testDatabase, + config: keywordConfig, + with: testContext) // Ensure we're testing nil plaintexts #expect(processed.database.plaintexts.contains { plaintext in plaintext == nil }) let serialized = try processed.database.serialize() @@ -50,15 +50,17 @@ extension PirTestUtils { encryptionParameters: EncryptionParameters, keywordConfig: KeywordPirConfig, server _: PirServer.Type, - client _: PirClient.Type) throws where PirServer.IndexPir == PirClient.IndexPir + client _: PirClient.Type) async throws where PirServer.IndexPir == PirClient.IndexPir { - let testContext: Context = try Context(encryptionParameters: encryptionParameters) + // swiftlint:disable:next nesting + typealias Scheme = PirServer.Scheme + let testContext = try Scheme.Context(encryptionParameters: encryptionParameters) let valueSize = testContext.bytesPerPlaintext / 2 let testDatabase = PirTestUtils.randomKeywordPirDatabase(rowCount: 100, valueSize: valueSize) - let processed = try KeywordPirServer.process(database: testDatabase, - config: keywordConfig, - with: testContext) + let processed = try await KeywordPirServer.process(database: testDatabase, + config: keywordConfig, + with: testContext) #expect(processed.pirParameter.dimensions.product() > 1, "trivial PIR") let server = try KeywordPirServer( @@ -72,8 +74,8 @@ extension PirTestUtils { let shuffledValues = Array(testDatabase.indices).shuffled() for index in shuffledValues.prefix(10) { let query = try client.generateQuery(at: testDatabase[index].keyword, using: secretKey) - let response = try server.computeResponse(to: query, using: evaluationKey) - if PirServer.Scheme.self != NoOpScheme.self { + let response = try await server.computeResponse(to: query, using: evaluationKey) + if Scheme.self != NoOpScheme.self { #expect(!response.isTransparent()) } let result = try client.decrypt(response: response, at: testDatabase[index].keyword, using: secretKey) @@ -81,8 +83,8 @@ extension PirTestUtils { } let noKey = PirTestUtils.generateRandomBytes(size: 5) let query = try client.generateQuery(at: noKey, using: secretKey) - let response = try server.computeResponse(to: query, using: evaluationKey) - if PirServer.Scheme.self != NoOpScheme.self { + let response = try await server.computeResponse(to: query, using: evaluationKey) + if Scheme.self != NoOpScheme.self { #expect(!response.isTransparent()) } let result = try client.decrypt(response: response, at: noKey, using: secretKey) @@ -91,7 +93,7 @@ extension PirTestUtils { /// Testing Keyword MulPir with 1 hash function. @inlinable - public static func keywordPirMulPir1HashFunction(_: Scheme.Type) throws { + public static func keywordPirMulPir1HashFunction(_: Scheme.Type) async throws { let cuckooTableConfig = try CuckooTableConfig( hashFunctionCount: 1, maxEvictionCount: 100, @@ -102,16 +104,16 @@ extension PirTestUtils { cuckooTableConfig: cuckooTableConfig, unevenDimensions: true, keyCompression: .noCompression) - try Self.keywordPirTest( + try await Self.keywordPirTest( encryptionParameters: TestUtils.getTestEncryptionParameters(), keywordConfig: keywordConfig, - server: MulPirServer.self, - client: MulPirClient.self) + server: MulPirServer>.self, + client: MulPirClient>.self) } /// Testing Keyword MulPir with 3 hash functions. @inlinable - public static func keywordPirMulPir3HashFunctions(_: Scheme.Type) throws { + public static func keywordPirMulPir3HashFunctions(_: Scheme.Type) async throws { let cuckooTableConfig = try CuckooTableConfig( hashFunctionCount: 3, maxEvictionCount: 100, @@ -121,76 +123,76 @@ extension PirTestUtils { dimensionCount: 2, cuckooTableConfig: cuckooTableConfig, unevenDimensions: true, keyCompression: .noCompression) - try keywordPirTest( + try await keywordPirTest( encryptionParameters: TestUtils.getTestEncryptionParameters(), keywordConfig: keywordConfig, - server: MulPirServer.self, - client: MulPirClient.self) + server: MulPirServer>.self, + client: MulPirClient>.self) } /// Testing Keyword MulPir with 1 dimension. @inlinable - public static func keywordPirMulPir1Dimension(_: Scheme.Type) throws { + public static func keywordPirMulPir1Dimension(_: Scheme.Type) async throws { let keywordConfig = try KeywordPirConfig( dimensionCount: 1, cuckooTableConfig: PirTestUtils.testCuckooTableConfig( maxSerializedBucketSize: 100), unevenDimensions: true, keyCompression: .noCompression) - try Self.keywordPirTest( + try await Self.keywordPirTest( encryptionParameters: TestUtils.getTestEncryptionParameters(), keywordConfig: keywordConfig, - server: MulPirServer.self, - client: MulPirClient.self) + server: MulPirServer>.self, + client: MulPirClient>.self) } /// Testing Keyword MulPir with 2 dimensions. @inlinable - public static func keywordPirMulPir2Dimensions(_: Scheme.Type) throws { + public static func keywordPirMulPir2Dimensions(_: Scheme.Type) async throws { let keywordConfig = try KeywordPirConfig( dimensionCount: 2, cuckooTableConfig: PirTestUtils.testCuckooTableConfig( maxSerializedBucketSize: 100), unevenDimensions: true, keyCompression: .noCompression) - try keywordPirTest( + try await keywordPirTest( encryptionParameters: TestUtils.getTestEncryptionParameters(), keywordConfig: keywordConfig, - server: MulPirServer.self, - client: MulPirClient.self) + server: MulPirServer>.self, + client: MulPirClient>.self) } /// Testing Keyword MulPir with hybrid key compression. @inlinable - public static func keywordPirMulPirHybridKeyCompression(_: Scheme.Type) throws { + public static func keywordPirMulPirHybridKeyCompression(_: Scheme.Type) async throws { let keywordConfig = try KeywordPirConfig( dimensionCount: 2, cuckooTableConfig: PirTestUtils.testCuckooTableConfig( maxSerializedBucketSize: 100), unevenDimensions: true, keyCompression: .hybridCompression) - try Self.keywordPirTest( + try await Self.keywordPirTest( encryptionParameters: TestUtils.getTestEncryptionParameters(), keywordConfig: keywordConfig, - server: MulPirServer.self, - client: MulPirClient.self) + server: MulPirServer>.self, + client: MulPirClient>.self) } /// Testing Keyword MulPir with max key compression. @inlinable - public static func keywordPirMulPirMaxKeyCompression(_: Scheme.Type) throws { + public static func keywordPirMulPirMaxKeyCompression(_: Scheme.Type) async throws { let keywordConfig = try KeywordPirConfig( dimensionCount: 2, cuckooTableConfig: PirTestUtils.testCuckooTableConfig( maxSerializedBucketSize: 100), unevenDimensions: true, keyCompression: .maxCompression) - try Self.keywordPirTest( + try await Self.keywordPirTest( encryptionParameters: TestUtils.getTestEncryptionParameters(), keywordConfig: keywordConfig, - server: MulPirServer.self, - client: MulPirClient.self) + server: MulPirServer>.self, + client: MulPirClient>.self) } /// Testing Keyword MulPir with larger parameters. @inlinable - public static func keywordPirMulPirLargeParameters(_: Scheme.Type) throws { + public static func keywordPirMulPirLargeParameters(_: Scheme.Type) async throws { if Scheme.Scalar.self == UInt32.self { let parameters = try EncryptionParameters(from: PredefinedRlweParameters .n_4096_logq_27_28_28_logt_5) @@ -199,11 +201,11 @@ extension PirTestUtils { cuckooTableConfig: PirTestUtils.testCuckooTableConfig( maxSerializedBucketSize: 3 * parameters.bytesPerPlaintext), unevenDimensions: true, keyCompression: .noCompression) - try Self.keywordPirTest( + try await Self.keywordPirTest( encryptionParameters: parameters, keywordConfig: keywordConfig, - server: MulPirServer.self, - client: MulPirClient.self) + server: MulPirServer>.self, + client: MulPirClient>.self) } else if Scheme.Scalar.self == UInt64.self, Scheme.self != NoOpScheme.self { let parameters = try EncryptionParameters(from: PredefinedRlweParameters .insecure_n_512_logq_4x60_logt_20) @@ -212,11 +214,11 @@ extension PirTestUtils { cuckooTableConfig: PirTestUtils.testCuckooTableConfig( maxSerializedBucketSize: 3 * parameters.bytesPerPlaintext), unevenDimensions: true, keyCompression: .noCompression) - try Self.keywordPirTest( + try await Self.keywordPirTest( encryptionParameters: parameters, keywordConfig: keywordConfig, - server: MulPirServer.self, - client: MulPirClient.self) + server: MulPirServer>.self, + client: MulPirClient>.self) } if Scheme.self == NoOpScheme.self { let noOpParameters = try EncryptionParameters(from: PredefinedRlweParameters @@ -226,24 +228,24 @@ extension PirTestUtils { cuckooTableConfig: PirTestUtils.testCuckooTableConfig( maxSerializedBucketSize: 3 * noOpParameters.bytesPerPlaintext), unevenDimensions: true, keyCompression: .noCompression) - try Self.keywordPirTest( + try await Self.keywordPirTest( encryptionParameters: noOpParameters, keywordConfig: keywordConfig, - server: MulPirServer.self, - client: MulPirClient.self) + server: MulPirServer>.self, + client: MulPirClient>.self) } } /// Testing Keyword Pir fixed configuration. @inlinable - public static func keywordPirFixedConfig(_: Scheme.Type) throws { + public static func keywordPirFixedConfig(_: Scheme.Type) async throws { let rowCount = 100 let valueSize = 9 let encryptionParams: EncryptionParameters = try TestUtils.getTestEncryptionParameters() - let testContext: Context = try Context(encryptionParameters: encryptionParams) + let testContext = try Scheme.Context(encryptionParameters: encryptionParams) var rng = TestRng() - let (pirParameter, keywordConfig): (IndexPirParameter, KeywordPirConfig) = try { + let (pirParameter, keywordConfig): (IndexPirParameter, KeywordPirConfig) = try await { let cuckooConfig = try CuckooTableConfig( hashFunctionCount: 2, maxEvictionCount: 100, @@ -257,9 +259,10 @@ extension PirTestUtils { rowCount: rowCount, valueSize: valueSize, using: &rng) - let processed = try KeywordPirServer>.process(database: testDatabase, - config: keywordConfig, - with: testContext) + let processed = try await KeywordPirServer>>.process( + database: testDatabase, + config: keywordConfig, + with: testContext) let newConfig = try KeywordPirConfig( dimensionCount: 2, cuckooTableConfig: cuckooConfig.freezingTableSize( @@ -274,14 +277,14 @@ extension PirTestUtils { rowCount: rowCount + 1, valueSize: valueSize - 1, using: &rng) - let processed = try KeywordPirServer>.process(database: testDatabase, - config: keywordConfig, - with: testContext) + let processed = try await KeywordPirServer>>.process(database: testDatabase, + config: keywordConfig, + with: testContext) #expect(processed.pirParameter == pirParameter) - let server = try KeywordPirServer>( + let server = try KeywordPirServer>>( context: testContext, processed: processed) - let client = KeywordPirClient>( + let client = KeywordPirClient>>( keywordParameter: keywordConfig.parameter, pirParameter: processed.pirParameter, context: testContext) @@ -290,7 +293,7 @@ extension PirTestUtils { let shuffledValues = Array(testDatabase.indices).shuffled() for index in shuffledValues.prefix(1) { let query = try client.generateQuery(at: testDatabase[index].keyword, using: secretKey) - let response = try server.computeResponse(to: query, using: evaluationKey) + let response = try await server.computeResponse(to: query, using: evaluationKey) if Scheme.self != NoOpScheme.self { #expect(!response.isTransparent()) } @@ -302,7 +305,7 @@ extension PirTestUtils { } let noKey = PirTestUtils.generateRandomBytes(size: 5) let query = try client.generateQuery(at: noKey, using: secretKey) - let response = try server.computeResponse(to: query, using: evaluationKey) + let response = try await server.computeResponse(to: query, using: evaluationKey) if Scheme.self != NoOpScheme.self { #expect(!response.isTransparent()) } @@ -312,18 +315,17 @@ extension PirTestUtils { /// Test sharding. @inlinable - public static func sharding(_: Scheme.Type) throws { + public static func sharding(_: PirUtil.Type) async throws { // swiftlint:disable nesting - typealias PirClient = MulPirClient - typealias PirServer = MulPirServer + typealias PirClient = MulPirClient + typealias PirServer = MulPirServer // swiftlint:enable nesting let rowCount = 1000 let valueSize = 10 let rlweParameters = PredefinedRlweParameters.n_4096_logq_27_28_28_logt_5 - let encryptionParameters = try EncryptionParameters(from: rlweParameters) - let testContext: Context = try Context( - encryptionParameters: encryptionParameters) + let encryptionParameters = try EncryptionParameters(from: rlweParameters) + let testContext = try PirUtil.Scheme.Context(encryptionParameters: encryptionParameters) let shardCount = 2 let cuckooConfig = try CuckooTableConfig( @@ -340,14 +342,14 @@ extension PirTestUtils { keywordPirConfig: keywordConfig) let testDatabase = PirTestUtils.randomKeywordPirDatabase(rowCount: rowCount, valueSize: valueSize) - let args = try ProcessKeywordDatabase.Arguments( + let args = try ProcessKeywordDatabase.Arguments( databaseConfig: databaseConfig, encryptionParameters: encryptionParameters, algorithm: PirAlgorithm.mulPir, keyCompression: .noCompression, trialsPerShard: 1) - let processed: ProcessKeywordDatabase.Processed = try ProcessKeywordDatabase.process( + let processed: ProcessKeywordDatabase.Processed = try await ProcessKeywordDatabase.process( rows: testDatabase, - with: args) + with: args, using: PirUtil.self) #expect(processed.shards.count == shardCount) let servers = try [String: KeywordPirServer](uniqueKeysWithValues: processed.shards @@ -374,8 +376,8 @@ extension PirTestUtils { let shardID = keyword.shardID(shardCount: shardCount) let client = try #require(clients[shardID]) let query = try client.generateQuery(at: testDatabase[index].keyword, using: secretKey) - let response = try #require(servers[shardID]).computeResponse(to: query, using: evaluationKey) - if Scheme.self != NoOpScheme.self { + let response = try await #require(servers[shardID]).computeResponse(to: query, using: evaluationKey) + if PirUtil.Scheme.self != NoOpScheme.self { #expect(!response.isTransparent()) } let result = try client.decrypt( @@ -388,8 +390,8 @@ extension PirTestUtils { let shardID = noKey.shardID(shardCount: shardCount) let client = try #require(clients[shardID]) let query = try client.generateQuery(at: noKey, using: secretKey) - let response = try #require(servers[shardID]).computeResponse(to: query, using: evaluationKey) - if Scheme.self != NoOpScheme.self { + let response = try await #require(servers[shardID]).computeResponse(to: query, using: evaluationKey) + if PirUtil.Scheme.self != NoOpScheme.self { #expect(!response.isTransparent()) } let result = try client.decrypt(response: response, at: noKey, using: secretKey) @@ -398,14 +400,14 @@ extension PirTestUtils { /// Test limiting entries per response. @inlinable - public static func limitEntriesPerResponse(_: Scheme.Type) throws { + public static func limitEntriesPerResponse(_: Scheme.Type) async throws { // swiftlint:disable nesting - typealias PirClient = MulPirClient - typealias PirServer = MulPirServer + typealias PirClient = MulPirClient> + typealias PirServer = MulPirServer> // swiftlint:enable nesting let rlweParams = PredefinedRlweParameters.n_4096_logq_27_28_28_logt_5 - let context: Context = try Context(encryptionParameters: .init(from: rlweParams)) + let context = try PirServer.Scheme.Context(encryptionParameters: .init(from: rlweParams)) let numberOfEntriesPerResponse = 8 let hashFunctionCount = 2 var testRng = TestRng() @@ -421,7 +423,7 @@ extension PirTestUtils { unevenDimensions: true, keyCompression: .noCompression, useMaxSerializedBucketSize: true) - let processed = try KeywordPirServer.process( + let processed = try await KeywordPirServer.process( database: testDatabase, config: config, with: context) @@ -436,7 +438,7 @@ extension PirTestUtils { let evaluationKey = try client.generateEvaluationKey(using: secretKey) let randomKeyValuePair = try #require(testDatabase.randomElement()) let query = try client.generateQuery(at: randomKeyValuePair.keyword, using: secretKey) - let response = try server.computeResponse(to: query, using: evaluationKey) + let response = try await server.computeResponse(to: query, using: evaluationKey) let result = try client.decrypt(response: response, at: randomKeyValuePair.keyword, using: secretKey) #expect(result == randomKeyValuePair.value) let entriesFound = try client.countEntriesInResponse(response: response, using: secretKey) diff --git a/Sources/_TestUtilities/PirUtilities/MulPirTests.swift b/Sources/_TestUtilities/PirUtilities/MulPirTests.swift index 2006e47a..c4822bca 100644 --- a/Sources/_TestUtilities/PirUtilities/MulPirTests.swift +++ b/Sources/_TestUtilities/PirUtilities/MulPirTests.swift @@ -71,20 +71,20 @@ extension PirTestUtils { /// Tests query generation. @inlinable - public static func queryGenerationTest( - scheme _: Scheme.Type, - _ keyCompression: PirKeyCompressionStrategy) throws + public static func queryGenerationTest( + pirUtil _: PirUtil.Type, + _ keyCompression: PirKeyCompressionStrategy) async throws { let entryCount = 200 let entrySizeInBytes = 16 - let context: Context = try TestUtils.getTestContext() + let context: PirUtil.Scheme.Context = try TestUtils.getTestContext() let secretKey = try context.generateSecretKey() let parameter = try PirTestUtils.getTestParameter( - pir: MulPir.self, + pir: MulPir.self, with: context, entryCount: entryCount, entrySizeInBytes: entrySizeInBytes, keyCompression: keyCompression) - let client = MulPirClient(parameter: parameter, context: context) + let client = MulPirClient(parameter: parameter, context: context) let evaluationKey = try client.generateEvaluationKey(using: secretKey) for _ in 0..<3 { @@ -94,11 +94,11 @@ extension PirTestUtils { let queryIndices = Array(indices.prefix(batchSize)) let query = try client.generateQuery(at: queryIndices, using: secretKey) let outputCount = parameter.expandedQueryCount * batchSize - let expandedQuery: [Scheme.CanonicalCiphertext] = try PirUtil.expandCiphertexts( + let expandedQuery: [PirUtil.Scheme.CanonicalCiphertext] = try await PirUtil.expand(ciphertexts: query.ciphertexts, outputCount: outputCount, using: evaluationKey) - let decodedQuery: [[Scheme.Scalar]] = try expandedQuery.map { ciphertext in + let decodedQuery: [[PirUtil.Scheme.Scalar]] = try expandedQuery.map { ciphertext in try ciphertext.decrypt(using: secretKey).decode(format: .coefficient) } @@ -127,8 +127,8 @@ extension PirTestUtils { /// Tests client computing query coordinates. @inlinable - public static func computeCoordinates(scheme _: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + public static func computeCoordinates(pirUtil _: PirUtil.Type) throws { + let context: PirUtil.Scheme.Context = try TestUtils.getTestContext() let evalKeyConfig = EvaluationKeyConfig() // two dimensional case do { @@ -138,7 +138,7 @@ extension PirTestUtils { dimensions: [10, 10], batchSize: 1, evaluationKeyConfig: evalKeyConfig) - let client = MulPirClient(parameter: parameter, context: context) + let client = MulPirClient(parameter: parameter, context: context) let vectors = [ (0, [0, 0]), @@ -163,7 +163,7 @@ extension PirTestUtils { dimensions: [5, 3, 2], batchSize: 1, evaluationKeyConfig: evalKeyConfig) - let client = MulPirClient(parameter: parameter, context: context) + let client = MulPirClient(parameter: parameter, context: context) let vectors = [ (0, [0, 0, 0]), diff --git a/Sources/_TestUtilities/PirUtilities/PirTestUtils.swift b/Sources/_TestUtilities/PirUtilities/PirTestUtils.swift index 33e65ce2..b1692a8f 100644 --- a/Sources/_TestUtilities/PirUtilities/PirTestUtils.swift +++ b/Sources/_TestUtilities/PirUtilities/PirTestUtils.swift @@ -20,7 +20,7 @@ public enum PirTestUtils { /// Creates test parameters. public static func getTestParameter( pir _: Pir.Type, - with context: Context, + with context: Pir.Scheme.Context, entryCount: Int, entrySizeInBytes: Int, keyCompression: PirKeyCompressionStrategy, diff --git a/Sources/_TestUtilities/PirUtilities/SymmetricPirTests.swift b/Sources/_TestUtilities/PirUtilities/SymmetricPirTests.swift index ea709672..87905caf 100644 --- a/Sources/_TestUtilities/PirUtilities/SymmetricPirTests.swift +++ b/Sources/_TestUtilities/PirUtilities/SymmetricPirTests.swift @@ -30,10 +30,10 @@ extension PirTestUtils { /// Tests symmetric PIR round trip. @inlinable - public static func roundTrip(_: Scheme.Type) throws { + public static func roundTrip(_: Scheme.Type) async throws { // swiftlint:disable nesting - typealias PirClient = MulPirClient - typealias PirServer = MulPirServer + typealias PirClient = MulPirClient> + typealias PirServer = MulPirServer> // swiftlint:enable nesting let symmetricPirConfig = try Self.generateSymmetricPirConfig() @@ -45,16 +45,16 @@ extension PirTestUtils { symmetricPirClientConfig: symmetricPirConfig.clientConfig()) let encryptionParameters: EncryptionParameters = try TestUtils.getTestEncryptionParameters() - let context: Context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let valueSize = context.bytesPerPlaintext / 2 let plainDatabase = PirTestUtils.randomKeywordPirDatabase(rowCount: 100, valueSize: valueSize) let encryptedDatabase = try KeywordDatabase.symmetricPIRProcess( database: plainDatabase, config: symmetricPirConfig) - let processed = try KeywordPirServer.process(database: encryptedDatabase, - config: keywordConfig, - with: context, - symmetricPirConfig: symmetricPirConfig) + let processed = try await KeywordPirServer.process(database: encryptedDatabase, + config: keywordConfig, + with: context, + symmetricPirConfig: symmetricPirConfig) let server = try KeywordPirServer( context: context, processed: processed) @@ -75,7 +75,7 @@ extension PirTestUtils { let parsedOprfOutput = try oprfClient.parse(oprfResponse: oprfResponse, with: oprfQueryContext) // Keyword PIR let query = try client.generateQuery(at: parsedOprfOutput.obliviousKeyword, using: secretKey) - let response = try server.computeResponse(to: query, using: evaluationKey) + let response = try await server.computeResponse(to: query, using: evaluationKey) #expect(!response.isTransparent()) let result = try client.decrypt( response: response, diff --git a/Sources/_TestUtilities/PnnsUtilities/CiphertextMatrixTests.swift b/Sources/_TestUtilities/PnnsUtilities/CiphertextMatrixTests.swift index 4135ada9..da75ac00 100644 --- a/Sources/_TestUtilities/PnnsUtilities/CiphertextMatrixTests.swift +++ b/Sources/_TestUtilities/PnnsUtilities/CiphertextMatrixTests.swift @@ -39,11 +39,11 @@ extension PrivateNearestNeighborSearchUtil { public enum CiphertextMatrixTests { /// Testing encryption/decryption round-trip. @inlinable - public static func encryptDecryptRoundTrip(for _: Scheme.Type) throws { + public static func encryptDecryptRoundTrip(for _: Scheme.Type) async throws { let rlweParams = PredefinedRlweParameters.insecure_n_8_logq_5x18_logt_5 let encryptionParameters = try EncryptionParameters(from: rlweParams) #expect(encryptionParameters.supportsSimdEncoding) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let dimensions = try MatrixDimensions(rowCount: 10, columnCount: 4) let encodeValues: [[Scheme.Scalar]] = increasingData( dimensions: dimensions, @@ -60,7 +60,7 @@ extension PrivateNearestNeighborSearchUtil { // modSwitchDownToSingle do { - try ciphertextMatrix.modSwitchDownToSingle() + try await ciphertextMatrix.modSwitchDownToSingle() let plaintextMatrixRoundTrip = try ciphertextMatrix.decrypt(using: secretKey) #expect(plaintextMatrixRoundTrip == plaintextMatrix) } @@ -72,7 +72,7 @@ extension PrivateNearestNeighborSearchUtil { let rlweParams = PredefinedRlweParameters.insecure_n_8_logq_5x18_logt_5 let encryptionParameters = try EncryptionParameters(from: rlweParams) #expect(encryptionParameters.supportsSimdEncoding) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let dimensions = try MatrixDimensions(rowCount: 10, columnCount: 4) let encodeValues: [[Scheme.Scalar]] = increasingData( dimensions: dimensions, @@ -92,7 +92,7 @@ extension PrivateNearestNeighborSearchUtil { /// Testing `extractDenseRow`. @inlinable - public static func extractDenseRow(for _: Scheme.Type) throws { + public static func extractDenseRow(for _: Scheme.Type) async throws { let degree = 16 let plaintextModulus = try Scheme.Scalar.generatePrimes( significantBitCounts: [9], @@ -109,7 +109,7 @@ extension PrivateNearestNeighborSearchUtil { errorStdDev: .stdDev32, securityLevel: .unchecked) #expect(encryptionParameters.supportsSimdEncoding) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) for rowCount in 1..<(2 * degree) { for columnCount in 1..(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let vectorDimension = 32 let queryDimensions = try MatrixDimensions(rowCount: 1, columnCount: vectorDimension) @@ -169,12 +169,12 @@ extension PrivateNearestNeighborSearchUtil { /// Testing client-server round-trip functionality. @inlinable - public static func clientServer(for _: Scheme.Type) throws { + public static func clientServer(for _: Scheme.Type) async throws { func runSingleTest( encryptionParameters: EncryptionParameters, dimensions: MatrixDimensions, plaintextModuli: [Scheme.Scalar], - queryCount: Int) throws + queryCount: Int) async throws { let vectorDimension = dimensions.columnCount let scalingFactor = ClientConfig.maxScalingFactor( @@ -201,7 +201,7 @@ extension PrivateNearestNeighborSearchUtil { let database = PrivateNearestNeighborSearchUtil.getDatabaseForTesting(config: DatabaseConfig( rowCount: dimensions.rowCount, vectorDimension: dimensions.columnCount)) - let processed = try database.process(config: serverConfig) + let processed = try await database.process(config: serverConfig) let client = try Client(config: clientConfig, contexts: processed.contexts) let server = try Server(database: processed) @@ -212,7 +212,7 @@ extension PrivateNearestNeighborSearchUtil { let query = try client.generateQuery(for: queryVectors, using: secretKey) let evaluationKey = try client.generateEvaluationKey(using: secretKey) - let response = try server.computeResponse(to: query, using: evaluationKey) + let response = try await server.computeResponse(to: query, using: evaluationKey) let noiseBudget = try response.noiseBudget(using: secretKey, variableTime: true) #expect(noiseBudget > 0) let decrypted = try client.decrypt(response: response, using: secretKey) @@ -253,7 +253,7 @@ extension PrivateNearestNeighborSearchUtil { for rowCount in [degree / 2, degree, degree + 1, 3 * degree] { for dimensions in try [MatrixDimensions(rowCount: rowCount, columnCount: 16)] { for plaintextModuliCount in 1...maxPlaintextModuliCount { - try runSingleTest( + try await runSingleTest( encryptionParameters: encryptionParameters, dimensions: dimensions, plaintextModuli: Array(plaintextModuli.prefix(plaintextModuliCount)), diff --git a/Sources/_TestUtilities/PnnsUtilities/DatabaseTests.swift b/Sources/_TestUtilities/PnnsUtilities/DatabaseTests.swift index ed37eee0..dcb72e66 100644 --- a/Sources/_TestUtilities/PnnsUtilities/DatabaseTests.swift +++ b/Sources/_TestUtilities/PnnsUtilities/DatabaseTests.swift @@ -21,7 +21,7 @@ extension PrivateNearestNeighborSearchUtil { public enum DatabaseTests { /// Test serialization. @inlinable - public static func serializedProcessedDatabase(for _: Scheme.Type) throws { + public static func serializedProcessedDatabase(for _: Scheme.Type) async throws { let encryptionParameters = try EncryptionParameters(from: .insecure_n_8_logq_5x18_logt_5) let vectorDimension = 4 @@ -49,7 +49,7 @@ extension PrivateNearestNeighborSearchUtil { .diagonal(babyStepGiantStep: BabyStepGiantStep(vectorDimension: vectorDimension)) let serverConfig = ServerConfig(clientConfig: clientConfig, databasePacking: databasePacking) - let processed: ProcessedDatabase = try database.process(config: serverConfig) + let processed = try await database.process(config: serverConfig) let serialized = try processed.serialize() let deserialized = try ProcessedDatabase(from: serialized, contexts: processed.contexts) #expect(deserialized == processed) diff --git a/Sources/_TestUtilities/PnnsUtilities/MatrixMultiplicationTests.swift b/Sources/_TestUtilities/PnnsUtilities/MatrixMultiplicationTests.swift index aaf846a4..63192700 100644 --- a/Sources/_TestUtilities/PnnsUtilities/MatrixMultiplicationTests.swift +++ b/Sources/_TestUtilities/PnnsUtilities/MatrixMultiplicationTests.swift @@ -47,15 +47,15 @@ extension PrivateNearestNeighborSearchUtil { public enum MatrixMultiplicationTests { /// Testing matrix-vector multiplication. @inlinable - public static func mulVector(for _: Scheme.Type) throws { + public static func mulVector(for _: Scheme.Type) async throws { func checkProduct( _: Scheme.Type, _ plaintextRows: [[Scheme.Scalar]], _ plaintextMatrixDimensions: MatrixDimensions, - _ queryValues: [Scheme.Scalar]) throws + _ queryValues: [Scheme.Scalar]) async throws { let encryptionParameters = try EncryptionParameters(from: .n_4096_logq_27_28_28_logt_16) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let secretKey = try context.generateSecretKey() var expected: [Scheme.Scalar] = try plaintextRows.mul( @@ -68,7 +68,7 @@ extension PrivateNearestNeighborSearchUtil { } let babyStepGiantStep = BabyStepGiantStep(vectorDimension: queryValues.count) - let plaintextMatrix = try PlaintextMatrix( + let plaintextMatrix = try PlaintextMatrix( context: context, dimensions: plaintextMatrixDimensions, packing: .diagonal(babyStepGiantStep: babyStepGiantStep), @@ -95,7 +95,7 @@ extension PrivateNearestNeighborSearchUtil { packing: .denseRow, values: queryValues).encrypt(using: secretKey) - let dotProduct = try plaintextMatrix.mulTranspose(vector: ciphertextVector, using: evaluationKey) + let dotProduct = try await plaintextMatrix.mulTranspose(vector: ciphertextVector, using: evaluationKey) let expectedCiphertextsCount = plaintextMatrixDimensions.rowCount.dividingCeil( encryptionParameters.polyDegree, variableTime: true) @@ -116,33 +116,34 @@ extension PrivateNearestNeighborSearchUtil { } var dimensions = try MatrixDimensions(rowCount: 6, columnCount: 6) var queryValues: [Scheme.Scalar] = Array(repeating: 2, count: 6) - try checkProduct(Scheme.self, values, dimensions, queryValues) + try await checkProduct(Scheme.self, values, dimensions, queryValues) // Tall - 64x16 dimensions = try MatrixDimensions(rowCount: 64, columnCount: 16) values = increasingData(dimensions: dimensions, modulus: Scheme.Scalar(17)) queryValues = Array(1...16) - try checkProduct(Scheme.self, values, dimensions, queryValues) + try await checkProduct(Scheme.self, values, dimensions, queryValues) // Broad - 16x64 dimensions = try MatrixDimensions(rowCount: 16, columnCount: 64) values = increasingData(dimensions: dimensions, modulus: Scheme.Scalar(70)) queryValues = Array(1...64) queryValues.reverse() - try checkProduct(Scheme.self, values, dimensions, queryValues) + try await checkProduct(Scheme.self, values, dimensions, queryValues) // Multiple result ciphertexts. 10240x4 dimensions = try MatrixDimensions(rowCount: 10240, columnCount: 4) values = increasingData(dimensions: dimensions, modulus: Scheme.Scalar(17)) queryValues = Array(1...4) - try checkProduct(Scheme.self, values, dimensions, queryValues) + try await checkProduct(Scheme.self, values, dimensions, queryValues) } @inlinable package static func matrixMulRunner( - context: Context, + scheme _: Scheme.Type, + context: Scheme.Context, plaintextValues: [[Scheme.Scalar]], - queryValues: [[Scheme.Scalar]]) throws + queryValues: [[Scheme.Scalar]]) async throws { let encryptionParameters = context.encryptionParameters let secretKey = try context.generateSecretKey() @@ -159,7 +160,7 @@ extension PrivateNearestNeighborSearchUtil { let plaintextDimensions = try MatrixDimensions( rowCount: plaintextValues.count, columnCount: plaintextValues[0].count) - let plaintextMatrix = try PlaintextMatrix( + let plaintextMatrix = try PlaintextMatrix( context: context, dimensions: plaintextDimensions, packing: .diagonal(babyStepGiantStep: babyStepGiantStep), @@ -171,7 +172,7 @@ extension PrivateNearestNeighborSearchUtil { encryptionParameters: encryptionParameters, scheme: Scheme.self) let evaluationKey = try context.generateEvaluationKey(config: evaluationKeyConfig, using: secretKey) - let decryptedValues: [Scheme.Scalar] = try plaintextMatrix.mulTranspose( + let decryptedValues: [Scheme.Scalar] = try await plaintextMatrix.mulTranspose( matrix: ciphertextMatrix, using: evaluationKey) .decrypt(using: secretKey).unpack() @@ -181,12 +182,12 @@ extension PrivateNearestNeighborSearchUtil { /// Testing matrix multiplication for large dimensions. @inlinable - public static func matrixMulLargeDimensions(for _: Scheme.Type) throws { + public static func matrixMulLargeDimensions(for _: Scheme.Type) async throws { func testOnRandomData( plaintextRows: Int, plaintextCols: Int, ciphertextRows: Int, - context: Context) throws + context: Scheme.Context) async throws { let plaintextMatrixDimensions = try MatrixDimensions( rowCount: plaintextRows, @@ -200,7 +201,8 @@ extension PrivateNearestNeighborSearchUtil { let queryValues: [[Scheme.Scalar]] = randomData( dimensions: ciphertextMatrixDimensions, modulus: context.encryptionParameters.plaintextModulus) - try Self.matrixMulRunner( + try await Self.matrixMulRunner( + scheme: Scheme.self, context: context, plaintextValues: plaintextValues, queryValues: queryValues) @@ -221,76 +223,96 @@ extension PrivateNearestNeighborSearchUtil { errorStdDev: ErrorStdDev.stdDev32, securityLevel: SecurityLevel.unchecked) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) do { // Tall - try testOnRandomData(plaintextRows: degree / 2, plaintextCols: 128, ciphertextRows: 3, context: context) - try testOnRandomData(plaintextRows: degree / 2, plaintextCols: 384, ciphertextRows: 3, context: context) - try testOnRandomData( + try await testOnRandomData( + plaintextRows: degree / 2, + plaintextCols: 128, + ciphertextRows: 3, + context: context) + try await testOnRandomData( + plaintextRows: degree / 2, + plaintextCols: 384, + ciphertextRows: 3, + context: context) + try await testOnRandomData( plaintextRows: 3 * degree / 4, plaintextCols: 128, ciphertextRows: 3, context: context) - try testOnRandomData(plaintextRows: degree, plaintextCols: 128, ciphertextRows: 1, context: context) - try testOnRandomData(plaintextRows: 2 * degree, plaintextCols: 128, ciphertextRows: 2, context: context) - try testOnRandomData(plaintextRows: 3 * degree, plaintextCols: 128, ciphertextRows: 3, context: context) + try await testOnRandomData( + plaintextRows: degree, + plaintextCols: 128, + ciphertextRows: 1, + context: context) + try await testOnRandomData( + plaintextRows: 2 * degree, + plaintextCols: 128, + ciphertextRows: 2, + context: context) + try await testOnRandomData( + plaintextRows: 3 * degree, + plaintextCols: 128, + ciphertextRows: 3, + context: context) } do { // Short, power-of-two ncols - try testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 1, context: context) - try testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 2, context: context) - try testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 16, context: context) - try testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 32, context: context) + try await testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 1, context: context) + try await testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 2, context: context) + try await testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 16, context: context) + try await testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 32, context: context) } do { // Short, non-power-of-two ncols - try testOnRandomData(plaintextRows: 160, plaintextCols: 384, ciphertextRows: 1, context: context) - try testOnRandomData(plaintextRows: 160, plaintextCols: 384, ciphertextRows: 2, context: context) - try testOnRandomData(plaintextRows: 160, plaintextCols: 384, ciphertextRows: 16, context: context) - try testOnRandomData(plaintextRows: 160, plaintextCols: 384, ciphertextRows: 32, context: context) + try await testOnRandomData(plaintextRows: 160, plaintextCols: 384, ciphertextRows: 1, context: context) + try await testOnRandomData(plaintextRows: 160, plaintextCols: 384, ciphertextRows: 2, context: context) + try await testOnRandomData(plaintextRows: 160, plaintextCols: 384, ciphertextRows: 16, context: context) + try await testOnRandomData(plaintextRows: 160, plaintextCols: 384, ciphertextRows: 32, context: context) } do { // Short, power-of-two ncols - try testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 1, context: context) - try testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 2, context: context) - try testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 16, context: context) - try testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 32, context: context) + try await testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 1, context: context) + try await testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 2, context: context) + try await testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 16, context: context) + try await testOnRandomData(plaintextRows: 160, plaintextCols: 128, ciphertextRows: 32, context: context) } do { // Wide columns var columnCount = degree / 4 - try testOnRandomData( + try await testOnRandomData( plaintextRows: 512, plaintextCols: columnCount, ciphertextRows: 1, context: context) - try testOnRandomData( + try await testOnRandomData( plaintextRows: 512, plaintextCols: columnCount, ciphertextRows: 2, context: context) - try testOnRandomData( + try await testOnRandomData( plaintextRows: 512, plaintextCols: columnCount, ciphertextRows: 5, context: context) columnCount = degree / 2 - try testOnRandomData( + try await testOnRandomData( plaintextRows: 512, plaintextCols: columnCount, ciphertextRows: 1, context: context) - try testOnRandomData( + try await testOnRandomData( plaintextRows: 512, plaintextCols: columnCount, ciphertextRows: 2, context: context) - try testOnRandomData( + try await testOnRandomData( plaintextRows: 512, plaintextCols: columnCount, ciphertextRows: 5, @@ -300,11 +322,11 @@ extension PrivateNearestNeighborSearchUtil { /// Testing matrix multiplication for small dimensions @inlinable - public static func matrixMulSmallDimensions(for _: Scheme.Type) throws { + public static func matrixMulSmallDimensions(for _: Scheme.Type) async throws { func testOnIncreasingData( plaintextDimensions: MatrixDimensions, queryDimensions: MatrixDimensions, - context: Context) throws + context: Scheme.Context) async throws { let plaintextModulus = context.encryptionParameters.plaintextModulus let plaintextValues: [[Scheme.Scalar]] = increasingData( @@ -313,19 +335,20 @@ extension PrivateNearestNeighborSearchUtil { let queryValues: [[Scheme.Scalar]] = increasingData( dimensions: queryDimensions, modulus: plaintextModulus) - try Self.matrixMulRunner( + try await Self.matrixMulRunner( + scheme: Scheme.self, context: context, plaintextValues: plaintextValues, queryValues: queryValues) } let encryptionParameters = try EncryptionParameters(from: .insecure_n_8_logq_5x18_logt_5) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) do { // 8x4x2 let plaintextDimensions = try MatrixDimensions(rowCount: 8, columnCount: 4) let queryDimensions = try MatrixDimensions(rowCount: 2, columnCount: 4) - try testOnIncreasingData( + try await testOnIncreasingData( plaintextDimensions: plaintextDimensions, queryDimensions: queryDimensions, context: context) @@ -334,7 +357,7 @@ extension PrivateNearestNeighborSearchUtil { // 7x2x4 let plaintextDimensions = try MatrixDimensions(rowCount: 7, columnCount: 2) let queryDimensions = try MatrixDimensions(rowCount: 4, columnCount: 2) - try testOnIncreasingData( + try await testOnIncreasingData( plaintextDimensions: plaintextDimensions, queryDimensions: queryDimensions, context: context) @@ -343,7 +366,7 @@ extension PrivateNearestNeighborSearchUtil { // 6x1x2 let plaintextDimensions = try MatrixDimensions(rowCount: 6, columnCount: 1) let queryDimensions = try MatrixDimensions(rowCount: 2, columnCount: 1) - try testOnIncreasingData( + try await testOnIncreasingData( plaintextDimensions: plaintextDimensions, queryDimensions: queryDimensions, context: context) @@ -353,7 +376,7 @@ extension PrivateNearestNeighborSearchUtil { // Non-power of 2 ncols let plaintextDimensions = try MatrixDimensions(rowCount: 5, columnCount: 3) let queryDimensions = try MatrixDimensions(rowCount: 2, columnCount: 3) - try testOnIncreasingData( + try await testOnIncreasingData( plaintextDimensions: plaintextDimensions, queryDimensions: queryDimensions, context: context) @@ -362,7 +385,7 @@ extension PrivateNearestNeighborSearchUtil { // Tall, plaintext rows in [N/4, N/2] let plaintextDimensions = try MatrixDimensions(rowCount: 200, columnCount: 4) let queryDimensions = try MatrixDimensions(rowCount: 5, columnCount: 4) - try testOnIncreasingData( + try await testOnIncreasingData( plaintextDimensions: plaintextDimensions, queryDimensions: queryDimensions, context: context) @@ -371,7 +394,7 @@ extension PrivateNearestNeighborSearchUtil { // Tall, plaintext rows > N let plaintextDimensions = try MatrixDimensions(rowCount: 10, columnCount: 4) let queryDimensions = try MatrixDimensions(rowCount: 5, columnCount: 4) - try testOnIncreasingData( + try await testOnIncreasingData( plaintextDimensions: plaintextDimensions, queryDimensions: queryDimensions, context: context) diff --git a/Sources/_TestUtilities/PnnsUtilities/PlaintextMatrixTests.swift b/Sources/_TestUtilities/PnnsUtilities/PlaintextMatrixTests.swift index 4d1df8f3..33cc5ef6 100644 --- a/Sources/_TestUtilities/PnnsUtilities/PlaintextMatrixTests.swift +++ b/Sources/_TestUtilities/PnnsUtilities/PlaintextMatrixTests.swift @@ -40,7 +40,7 @@ extension PrivateNearestNeighborSearchUtil { } let dims = try MatrixDimensions(rowCount: encryptionParameters.polyDegree, columnCount: 2) let packing = MatrixPacking.denseRow - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let values = TestUtils.getRandomPlaintextData( count: encryptionParameters.polyDegree, in: 0..(from: diffRlweParams) - let diffContext = try Context(encryptionParameters: diffEncryptionParams) + let diffContext = try Scheme.Context(encryptionParameters: diffEncryptionParams) let diffValues = TestUtils.getRandomPlaintextData( count: diffEncryptionParams.polyDegree, in: 0..(for _: Scheme.Type) throws { let rlweParams = PredefinedRlweParameters.insecure_n_8_logq_5x18_logt_5 let encryptionParameters = try EncryptionParameters(from: rlweParams) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let rowCount = encryptionParameters.polyDegree let columnCount = 2 let values = TestUtils.getRandomPlaintextData( @@ -121,7 +121,8 @@ extension PrivateNearestNeighborSearchUtil { @inlinable static func runPlaintextMatrixInitTest( - context: Context, + scheme _: Scheme.Type, + context: Scheme.Context, dimensions: MatrixDimensions, packing: MatrixPacking, expected: [[Int]]) throws @@ -285,10 +286,11 @@ extension PrivateNearestNeighborSearchUtil { nttDegree: 8), errorStdDev: .stdDev32, securityLevel: .unchecked) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) for ((rowCount, columnCount), expected) in kats { let dimensions = try MatrixDimensions((rowCount, columnCount)) try Self.runPlaintextMatrixInitTest( + scheme: Scheme.self, context: context, dimensions: dimensions, packing: .denseColumn, expected: expected) @@ -364,10 +366,11 @@ extension PrivateNearestNeighborSearchUtil { let rlweParams = PredefinedRlweParameters.insecure_n_8_logq_5x18_logt_5 let encryptionParameters = try EncryptionParameters(from: rlweParams) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) for ((rowCount, columnCount), expected) in kats { let dimensions = try MatrixDimensions((rowCount, columnCount)) try Self.runPlaintextMatrixInitTest( + scheme: Scheme.self, context: context, dimensions: dimensions, packing: .denseRow, @@ -462,11 +465,12 @@ extension PrivateNearestNeighborSearchUtil { nttDegree: 8), errorStdDev: ErrorStdDev.stdDev32, securityLevel: SecurityLevel.unchecked) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) for ((rowCount, columnCount), expected) in kats { let dimensions = try MatrixDimensions((rowCount, columnCount)) let bsgs = BabyStepGiantStep(vectorDimension: dimensions.columnCount.nextPowerOfTwo) try Self.runPlaintextMatrixInitTest( + scheme: Scheme.self, context: context, dimensions: dimensions, packing: .diagonal(babyStepGiantStep: bsgs), @@ -486,7 +490,7 @@ extension PrivateNearestNeighborSearchUtil { nttDegree: 16), errorStdDev: ErrorStdDev.stdDev32, securityLevel: SecurityLevel.unchecked) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let dimensions = try MatrixDimensions(rowCount: 4, columnCount: 5) let bsgs = BabyStepGiantStep(vectorDimension: dimensions.columnCount) @@ -525,7 +529,7 @@ extension PrivateNearestNeighborSearchUtil { let rlweParams = PredefinedRlweParameters.insecure_n_8_logq_5x18_logt_5 let encryptionParameters = try EncryptionParameters(from: rlweParams) #expect(encryptionParameters.supportsSimdEncoding) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let dimensions = try MatrixDimensions(rowCount: 10, columnCount: 4) let encodeValues: [[Scheme.Scalar]] = increasingData( dimensions: dimensions, diff --git a/Sources/_TestUtilities/TestUtilities.swift b/Sources/_TestUtilities/TestUtilities.swift index c301754e..7f694d6d 100644 --- a/Sources/_TestUtilities/TestUtilities.swift +++ b/Sources/_TestUtilities/TestUtilities.swift @@ -26,7 +26,6 @@ import Testing /// - absoluteTolerance: An optional absolute tolerance to enforce. /// - Returns: true if the expressions are close to each other extension BinaryFloatingPoint { - @inlinable package func isClose(to value: Self, relativeTolerance: Self = Self(1e-5), absoluteTolerance: Self = Self(1e-8)) -> Bool @@ -92,12 +91,12 @@ extension [UInt8] { } } -@usableFromInline -package enum TestUtils { +/// A collection of constants used in tests. +public enum TestUtils { /// A polynomial degree suitable for testing. - @usableFromInline package static let testPolyDegree = 16 + public static let testPolyDegree = 16 /// A plaintext modulus suitable for testing. - @usableFromInline package static let testPlaintextModulus = 1153 + public static let testPlaintextModulus = 1153 } extension TestUtils { @@ -154,8 +153,10 @@ extension TestUtils { return Double(binCount) * probabilityOfExactlyCountBallsInFirstBin } - @inlinable - package static func getRandomPlaintextData(count: Int, in range: Range) -> [T] { + /// Generates random array for plaintext encoding. + public static func getRandomPlaintextData(count: Int, + in range: Range) -> [T] + { (0..() throws -> Context { - try Context(encryptionParameters: getTestEncryptionParameters()) + /// Returns a `HeContext` initialized with the parameters used for testing. + public static func getTestContext() throws -> Context { + try Context(encryptionParameters: getTestEncryptionParameters()) } } diff --git a/Tests/ApplicationProtobufTests/PirConversionTests.swift b/Tests/ApplicationProtobufTests/PirConversionTests.swift index cca145f3..bf6f18ec 100644 --- a/Tests/ApplicationProtobufTests/PirConversionTests.swift +++ b/Tests/ApplicationProtobufTests/PirConversionTests.swift @@ -41,7 +41,7 @@ struct PirConversionTests { } @Test - func processedDatabaseWithParameters() throws { + func processedDatabaseWithParameters() async throws { let rows = (0..<10).map { KeywordValuePair(keyword: Array(String($0).utf8), value: Array(String($0).utf8)) } let context: Context> = try .init(encryptionParameters: .init(from: .n_4096_logq_27_28_28_logt_13)) let config = try KeywordPirConfig( @@ -49,7 +49,7 @@ struct PirConversionTests { cuckooTableConfig: .defaultKeywordPir(maxSerializedBucketSize: context.bytesPerPlaintext), unevenDimensions: true, keyCompression: .noCompression) - let processedDatabaseWithParameters = try KeywordPirServer>>.process( + let processedDatabaseWithParameters = try await KeywordPirServer>>>.process( database: rows, config: config, with: context) diff --git a/Tests/ApplicationProtobufTests/PnnsConversionTests.swift b/Tests/ApplicationProtobufTests/PnnsConversionTests.swift index 008d89ac..718cb7d8 100644 --- a/Tests/ApplicationProtobufTests/PnnsConversionTests.swift +++ b/Tests/ApplicationProtobufTests/PnnsConversionTests.swift @@ -103,27 +103,27 @@ struct PnnsConversionTests { func serializedPlaintextMatrix() throws { func runTest(_: Scheme.Type) throws { let encryptionParameters = try EncryptionParameters(from: .insecure_n_8_logq_5x18_logt_5) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let dimensions = try MatrixDimensions(rowCount: 5, columnCount: 4) let scalars: [[Scheme.Scalar]] = increasingData( dimensions: dimensions, modulus: encryptionParameters.plaintextModulus) - let plaintextMatrix = try PlaintextMatrix( + let plaintextMatrix = try PlaintextMatrix( context: context, dimensions: dimensions, packing: .denseColumn, values: scalars.flatMap(\.self)) let serialized = try plaintextMatrix.serialize() #expect(try serialized.proto().native() == serialized) - let deserialized = try PlaintextMatrix(deserialize: serialized, context: context) + let deserialized = try PlaintextMatrix(deserialize: serialized, context: context) #expect(deserialized == plaintextMatrix) for moduliCount in 1..( deserialize: serialized, context: context, moduliCount: moduliCount) @@ -136,17 +136,17 @@ struct PnnsConversionTests { } @Test - func serializedCiphertextMatrix() throws { - func runTest(_: Scheme.Type) throws { + func serializedCiphertextMatrix() async throws { + func runTest(_: Scheme.Type) async throws { let encryptionParameters = try EncryptionParameters(from: .insecure_n_8_logq_5x18_logt_5) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let secretKey = try context.generateSecretKey() let dimensions = try MatrixDimensions(rowCount: 5, columnCount: 4) let scalars: [[Scheme.Scalar]] = increasingData( dimensions: dimensions, modulus: encryptionParameters.plaintextModulus) - let plaintextMatrix = try PlaintextMatrix( + let plaintextMatrix = try PlaintextMatrix( context: context, dimensions: dimensions, packing: .denseColumn, @@ -161,7 +161,7 @@ struct PnnsConversionTests { // Check Evaluation format do { let ciphertextMatrix = try plaintextMatrix.encrypt(using: secretKey) - let evalCiphertextMatrix = try ciphertextMatrix.convertToEvalFormat() + let evalCiphertextMatrix = try await ciphertextMatrix.convertToEvalFormat() let serialized = try evalCiphertextMatrix.serialize() #expect(try serialized.proto().native() == serialized) let deserialized = try CiphertextMatrix( @@ -172,7 +172,7 @@ struct PnnsConversionTests { // Check serializeForDecryption do { var ciphertextMatrix = try plaintextMatrix.encrypt(using: secretKey) - try ciphertextMatrix.modSwitchDownToSingle() + try await ciphertextMatrix.modSwitchDownToSingle() let serializedForDecryption = try ciphertextMatrix.serialize(forDecryption: true) let serializedForDecryptionSize = try serializedForDecryption.proto().serializedData().count @@ -189,22 +189,22 @@ struct PnnsConversionTests { } } - try runTest(Bfv.self) - try runTest(Bfv.self) + try await runTest(Bfv.self) + try await runTest(Bfv.self) } @Test func query() throws { func runTest(_: Scheme.Type) throws { let encryptionParameters = try EncryptionParameters(from: .insecure_n_8_logq_5x18_logt_5) - let context = try Context(encryptionParameters: encryptionParameters) + let context = try Scheme.Context(encryptionParameters: encryptionParameters) let secretKey = try context.generateSecretKey() let dimensions = try MatrixDimensions(rowCount: 5, columnCount: 4) let scalars: [[Scheme.Scalar]] = increasingData( dimensions: dimensions, modulus: encryptionParameters.plaintextModulus) - let plaintextMatrix = try PlaintextMatrix( + let plaintextMatrix = try PlaintextMatrix( context: context, dimensions: dimensions, packing: .denseColumn, @@ -214,7 +214,7 @@ struct PnnsConversionTests { } let query = Query(ciphertextMatrices: ciphertextMatrices) - let roundtrip = try query.proto().native(context: context) + let roundtrip = try query.proto().native(context: context) as Query #expect(roundtrip == query) } try runTest(Bfv.self) @@ -222,8 +222,8 @@ struct PnnsConversionTests { } @Test - func serializedProcessedDatabase() throws { - func runTest(_: Scheme.Type) throws { + func serializedProcessedDatabase() async throws { + func runTest(_: Scheme.Type) async throws { let encryptionParameters = try EncryptionParameters(from: .insecure_n_8_logq_5x18_logt_5) let vectorDimension = 4 @@ -256,11 +256,11 @@ struct PnnsConversionTests { .diagonal( babyStepGiantStep: BabyStepGiantStep(vectorDimension: vectorDimension))) - let processed = try database.process(config: serverConfig) + let processed = try await database.process(config: serverConfig) let serialized = try processed.serialize() #expect(try serialized.proto().native() == serialized) } - try runTest(Bfv.self) - try runTest(Bfv.self) + try await runTest(Bfv.self) + try await runTest(Bfv.self) } } diff --git a/Tests/HomomorphicEncryptionProtobufTests/ConversionTests.swift b/Tests/HomomorphicEncryptionProtobufTests/ConversionTests.swift index ee73bba8..04695100 100644 --- a/Tests/HomomorphicEncryptionProtobufTests/ConversionTests.swift +++ b/Tests/HomomorphicEncryptionProtobufTests/ConversionTests.swift @@ -56,7 +56,7 @@ struct ConversionTests { } func runTest(_: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0..(_: Scheme.Type, format: EncodeFormat) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0..(_: Scheme.Type, format: EncodeFormat) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0..(_: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let secretKey = try context.generateSecretKey() let proto = secretKey.serialize().proto() - let deserialized = try SecretKey(deserialize: proto.native(), context: context) + + let deserialized = try SecretKey(deserialize: proto.native(), context: context) #expect(deserialized == secretKey) } @@ -199,13 +200,14 @@ struct ConversionTests { @Test func galoisKey() throws { func runTest(_: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let secretKey = try context.generateSecretKey() let evaluationKey = try context.generateEvaluationKey( - config: EvaluationKeyConfig(galoisElements: [3, 5, 7]), using: secretKey) + config: EvaluationKeyConfig(galoisElements: [3, 5, 7]), + using: secretKey) let galoisKey = try #require(evaluationKey.galoisKey) let proto = galoisKey.serialize().proto() - let deserialized = try GaloisKey(deserialize: proto.native(), context: context) + let deserialized = try _GaloisKey(deserialize: proto.native(), context: context) #expect(deserialized == galoisKey) } @@ -217,14 +219,14 @@ struct ConversionTests { @Test func relinearizationKey() throws { func runTest(_: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let secretKey = try context.generateSecretKey() let evaluationKey = try context.generateEvaluationKey( - config: EvaluationKeyConfig( - hasRelinearizationKey: true), using: secretKey) + config: EvaluationKeyConfig(hasRelinearizationKey: true), + using: secretKey) let relinearizationKey = try #require(evaluationKey.relinearizationKey) let proto = relinearizationKey.serialize().proto() - let deserialized = try RelinearizationKey(deserialize: proto.native(), context: context) + let deserialized = try _RelinearizationKey(deserialize: proto.native(), context: context) #expect(deserialized == relinearizationKey) } @@ -236,14 +238,13 @@ struct ConversionTests { @Test func evaluationKey() throws { func runTest(_: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let secretKey = try context.generateSecretKey() let evaluationKey = try context.generateEvaluationKey( - config: EvaluationKeyConfig( - galoisElements: [3, 5, 7], - hasRelinearizationKey: true), using: secretKey) + config: EvaluationKeyConfig(galoisElements: [3, 5, 7], hasRelinearizationKey: true), + using: secretKey) let proto = evaluationKey.serialize().proto() - let deserialized = try EvaluationKey(deserialize: proto.native(), context: context) + let deserialized = try EvaluationKey(deserialize: proto.native(), context: context) #expect(deserialized == evaluationKey) } diff --git a/Tests/HomomorphicEncryptionTests/Array2dTests.swift b/Tests/HomomorphicEncryptionTests/Array2dTests.swift index 9cddd033..2d4c135f 100644 --- a/Tests/HomomorphicEncryptionTests/Array2dTests.swift +++ b/Tests/HomomorphicEncryptionTests/Array2dTests.swift @@ -156,4 +156,31 @@ struct Array2dTests { let roundtripArray = arrayPlus1.map { Int($0 - 1) } #expect(roundtripArray == array) } + + @Test + func withUnsafeData() { + let data = [Int](0..<32) + let array = Array2d(data: data, rowCount: 4, columnCount: 8) + array.withUnsafeData { dataPointer in + array.data.withUnsafeBufferPointer { expectedDataPointer in + #expect(dataPointer.baseAddress == expectedDataPointer.baseAddress) + } + } + } + + @Test + func withUnsafeMutableData() throws { + let data = [Int](0..<32) + var array = Array2d(data: data, rowCount: 4, columnCount: 8) + // For the comparison we need 'mutable' pointers of the same type. + // But, `withUnsafe*` methods need exclusive ownership of the pointer. + let expectedBaseAddress = try #require( + array.data.withUnsafeMutableBufferPointer { buffer in + buffer.baseAddress + }, + "Expected a valid base address") + array.withUnsafeMutableData { dataPointer in + #expect(dataPointer.baseAddress == expectedBaseAddress) + } + } } diff --git a/Tests/HomomorphicEncryptionTests/HeAPITests.swift b/Tests/HomomorphicEncryptionTests/HeAPITests.swift index 4d3b869e..79d643cc 100644 --- a/Tests/HomomorphicEncryptionTests/HeAPITests.swift +++ b/Tests/HomomorphicEncryptionTests/HeAPITests.swift @@ -19,7 +19,7 @@ import Testing @Suite struct HeAPITests { private struct TestEnv { - let context: Context + let context: Scheme.Context let data1: [Scheme.Scalar] let data2: [Scheme.Scalar] let coeffPlaintext1: Plaintext @@ -33,7 +33,7 @@ struct HeAPITests { let evaluationKey: EvaluationKey? init( - context: Context, + context: Scheme.Context, format: EncodeFormat, galoisElements: [Int] = [], relinearizationKey: Bool = false) throws @@ -87,27 +87,28 @@ struct HeAPITests { @Test func noOpScheme() async throws { let context: Context = try TestUtils.getTestContext() - try HeAPITestHelpers.schemeEncodeDecodeTest(context: context) - try HeAPITestHelpers.schemeEncryptDecryptTest(context: context) - try HeAPITestHelpers.schemeEncryptZeroDecryptTest(context: context) - try HeAPITestHelpers.schemeEncryptZeroAddDecryptTest(context: context) - try HeAPITestHelpers.schemeEncryptZeroMultiplyDecryptTest(context: context) - try await HeAPITestHelpers.schemeCiphertextAddTest(context: context) - try await HeAPITestHelpers.schemeCiphertextSubtractTest(context: context) - try await HeAPITestHelpers.schemeCiphertextCiphertextMultiplyTest(context: context) - try await HeAPITestHelpers.schemeCiphertextPlaintextAddTest(context: context) - try await HeAPITestHelpers.schemeCiphertextPlaintextSubtractTest(context: context) - try await HeAPITestHelpers.schemeCiphertextPlaintextMultiplyTest(context: context) - try await HeAPITestHelpers.schemeCiphertextMultiplyAddTest(context: context) - try HeAPITestHelpers.schemeCiphertextMultiplyAddPlainTest(context: context) - try await HeAPITestHelpers.schemeCiphertextMultiplySubtractTest(context: context) - try await HeAPITestHelpers.schemeCiphertextMultiplySubtractPlainTest(context: context) - try await HeAPITestHelpers.schemeCiphertextNegateTest(context: context) - try await HeAPITestHelpers.schemeCiphertextPlaintextInnerProductTest(context: context) - try await HeAPITestHelpers.schemeCiphertextCiphertextInnerProductTest(context: context) + try HeAPITestHelpers.schemeEncodeDecodeTest(context: context, scheme: NoOpScheme.self) + try HeAPITestHelpers.schemeEncryptDecryptTest(context: context, scheme: NoOpScheme.self) + try HeAPITestHelpers.schemeEncryptZeroDecryptTest(context: context, scheme: NoOpScheme.self) + try HeAPITestHelpers.schemeEncryptZeroAddDecryptTest(context: context, scheme: NoOpScheme.self) + try HeAPITestHelpers.schemeEncryptZeroMultiplyDecryptTest(context: context, scheme: NoOpScheme.self) + try await HeAPITestHelpers.schemeCiphertextAdditionTest(context: context, scheme: NoOpScheme.self) + try await HeAPITestHelpers.schemeCiphertextSubtractionTest(context: context, scheme: NoOpScheme.self) + try await HeAPITestHelpers.schemeCiphertextCiphertextMultiplicationTest( + context: context, + scheme: NoOpScheme.self) + try await HeAPITestHelpers.schemeCiphertextPlaintextAdditionTest(context: context, scheme: NoOpScheme.self) + try await HeAPITestHelpers.schemeCiphertextPlaintextSubtractionTest(context: context, scheme: NoOpScheme.self) + try await HeAPITestHelpers.schemeCiphertextPlaintextMultiplicationTest( + context: context, + scheme: NoOpScheme.self) + try await HeAPITestHelpers.schemeCiphertextMultiplyAddTest(context: context, scheme: NoOpScheme.self) + try HeAPITestHelpers.schemeCiphertextMultiplyAddPlainTest(context: context, scheme: NoOpScheme.self) + try await HeAPITestHelpers.schemeCiphertextMultiplySubTest(context: context, scheme: NoOpScheme.self) + try await HeAPITestHelpers.schemeCiphertextNegateTest(context: context, scheme: NoOpScheme.self) try HeAPITestHelpers.schemeEvaluationKeyTest(context: context) - try await HeAPITestHelpers.schemeRotationTest(context: context) - try await HeAPITestHelpers.schemeApplyGaloisTest(context: context) + try await HeAPITestHelpers.schemeRotationTest(context: context, scheme: NoOpScheme.self) + try await HeAPITestHelpers.schemeApplyGaloisTest(context: context, scheme: NoOpScheme.self) } private func bfvTestKeySwitching(context: Context>) throws { @@ -115,18 +116,21 @@ struct HeAPITests { return } - let testEnv = try TestEnv(context: context, format: .coefficient) + let testEnv = try HeAPITestHelpers.TestEnv>(context: context, format: .coefficient) let newSecretKey = try context.generateSecretKey() - let keySwitchKey = try Bfv.generateKeySwitchKey(context: context, - currentKey: testEnv.secretKey.poly, - targetKey: newSecretKey) - var switchedPolys = try Bfv.computeKeySwitchingUpdate( + let keySwitchKey = try Bfv._generateKeySwitchKey(context: context, + currentKey: testEnv.secretKey.poly, + targetKey: newSecretKey) + var switchedPolys = try Bfv._computeKeySwitchingUpdate( context: context, target: testEnv.ciphertext1.polys[1], keySwitchingKey: keySwitchKey) switchedPolys[0] += testEnv.ciphertext1.polys[0] - let switchedCiphertext = Ciphertext(context: context, polys: switchedPolys, correctionFactor: 1) + let switchedCiphertext = try Ciphertext, Coeff>( + context: context, + polys: switchedPolys, + correctionFactor: 1) let plaintext = try switchedCiphertext.decrypt(using: newSecretKey) let decrypted: [T] = try plaintext.decode(format: .coefficient) @@ -149,6 +153,7 @@ struct HeAPITests { coefficientModuli: HeAPITestHelpers.testCoefficientModuli(), errorStdDev: ErrorStdDev.stdDev32, securityLevel: SecurityLevel.unchecked) + let manyModuli = try EncryptionParameters( polyDegree: TestUtils.testPolyDegree, plaintextModulus: T.generatePrimes( @@ -164,31 +169,41 @@ struct HeAPITests { for encryptionParameters in predefined + [custom, manyModuli] { let context = try Context>(encryptionParameters: encryptionParameters) - try HeAPITestHelpers.schemeEncodeDecodeTest(context: context) - try HeAPITestHelpers.schemeEncryptDecryptTest(context: context) - try HeAPITestHelpers.schemeEncryptZeroDecryptTest(context: context) - try HeAPITestHelpers.schemeEncryptZeroAddDecryptTest(context: context) - try HeAPITestHelpers.schemeEncryptZeroMultiplyDecryptTest(context: context) - try await HeAPITestHelpers.schemeCiphertextAddTest(context: context) - try await HeAPITestHelpers.schemeCiphertextSubtractTest(context: context) - try await HeAPITestHelpers.schemeCiphertextCiphertextMultiplyTest(context: context) - try await HeAPITestHelpers.schemeCiphertextPlaintextAddTest(context: context) - try await HeAPITestHelpers.schemeCiphertextPlaintextSubtractTest(context: context) - try await HeAPITestHelpers.schemeCiphertextPlaintextMultiplyTest(context: context) - try await HeAPITestHelpers.schemeCiphertextMultiplyAddTest(context: context) - try HeAPITestHelpers.schemeCiphertextMultiplyAddPlainTest(context: context) - try await HeAPITestHelpers.schemeCiphertextMultiplySubtractTest(context: context) - try await HeAPITestHelpers.schemeCiphertextMultiplySubtractPlainTest(context: context) - try await HeAPITestHelpers.schemeCiphertextNegateTest(context: context) - try await HeAPITestHelpers.schemeCiphertextPlaintextInnerProductTest(context: context) - try await HeAPITestHelpers.schemeCiphertextCiphertextInnerProductTest(context: context) + try HeAPITestHelpers.schemeEncodeDecodeTest(context: context, scheme: Bfv.self) + try HeAPITestHelpers.schemeEncryptDecryptTest(context: context, scheme: Bfv.self) + try HeAPITestHelpers.schemeEncryptZeroDecryptTest(context: context, scheme: Bfv.self) + try HeAPITestHelpers.schemeEncryptZeroAddDecryptTest(context: context, scheme: Bfv.self) + try HeAPITestHelpers.schemeEncryptZeroMultiplyDecryptTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeCiphertextAdditionTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeCiphertextSubtractionTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeCiphertextPlaintextAdditionTest(context: context, scheme: Bfv.self) + // swiftformat:disable wrap wrapArguments + try await HeAPITestHelpers.schemeCiphertextPlaintextSubtractionTest(context: context, scheme: Bfv.self) + // swiftlint:disable line_length + try await HeAPITestHelpers.schemeCiphertextPlaintextMultiplicationTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeCiphertextMultiplySubtractPlainTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeCiphertextPlaintextMultiplyAddPlainTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeCiphertextPlaintextMultiplySubtractPlainTest(context: context, scheme: Bfv.self) + try HeAPITestHelpers.schemeCiphertextMultiplyAddPlainTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeCiphertextMultiplySubtractPlainTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeCiphertextPlaintextMultiplyAddPlainTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeCiphertextPlaintextMultiplySubtractPlainTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeCiphertextCiphertextMultiplicationTest(context: context, scheme: Bfv.self) + // swiftlint:enable line_length + try await HeAPITestHelpers.schemeCiphertextPlaintextInnerProductTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeCiphertextCiphertextInnerProductTest(context: context, scheme: Bfv.self) + // swiftformat:enable wrap wrapArguments + try await HeAPITestHelpers.schemeCiphertextMultiplyAddTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeCiphertextNegateTest(context: context, scheme: Bfv.self) try HeAPITestHelpers.schemeEvaluationKeyTest(context: context) - try await HeAPITestHelpers.schemeRotationTest(context: context) - try await HeAPITestHelpers.schemeApplyGaloisTest(context: context) + try await HeAPITestHelpers.schemeApplyGaloisTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeRotationTest(context: context, scheme: Bfv.self) try bfvTestKeySwitching(context: context) - try HeAPITestHelpers.noiseBudgetTest(context: context) - try await HeAPITestHelpers.repeatedAdditionTest(context: context) - try await HeAPITestHelpers.multiplyInverseTest(context: context) + try HeAPITestHelpers.noiseBudgetTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.repeatAdditionTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.multiplyInverseTest(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeTestNtt(context: context, scheme: Bfv.self) + try await HeAPITestHelpers.schemeTestFormats(context: context, scheme: Bfv.self) } } diff --git a/Tests/HomomorphicEncryptionTests/PolyRqTests/GaloisTests.swift b/Tests/HomomorphicEncryptionTests/PolyRqTests/GaloisTests.swift index 10e6a77d..447b5519 100644 --- a/Tests/HomomorphicEncryptionTests/PolyRqTests/GaloisTests.swift +++ b/Tests/HomomorphicEncryptionTests/PolyRqTests/GaloisTests.swift @@ -127,7 +127,7 @@ struct GaloisTests { // No plan found. let supportedSteps = [2, 4, 8] for step in [1, 3, 5, 7, 9, 11, 13, 15] { - let plan = try GaloisElement.planMultiStep(supportedSteps: supportedSteps, step: step, degree: degree) + let plan = try GaloisElement._planMultiStep(supportedSteps: supportedSteps, step: step, degree: degree) #expect(plan == nil) } } @@ -162,7 +162,10 @@ struct GaloisTests { ] for (step, counts) in knownAnswers { - var result = try GaloisElement.planMultiStep(supportedSteps: supportedSteps, step: step, degree: degree) + var result = try GaloisElement._planMultiStep( + supportedSteps: supportedSteps, + step: step, + degree: degree) #expect(result == counts) // Negative steps yields same plan, just with negative rotations. @@ -171,7 +174,7 @@ struct GaloisTests { (transformNegative(step), count) }) - result = try GaloisElement.planMultiStep( + result = try GaloisElement._planMultiStep( supportedSteps: negativeSteps, step: negativeStep, degree: degree) @@ -188,7 +191,10 @@ struct GaloisTests { ] for (step, counts) in knownAnswers { - let result = try GaloisElement.planMultiStep(supportedSteps: supportedSteps, step: step, degree: degree) + let result = try GaloisElement._planMultiStep( + supportedSteps: supportedSteps, + step: step, + degree: degree) #expect(result == counts) } } @@ -212,7 +218,7 @@ struct GaloisTests { (transformPositive(-192), [transformPositive(-16): 12]), ] for (step, counts) in knownAnswers { - let result = try GaloisElement.planMultiStep(supportedSteps: steps, step: step, degree: degree) + let result = try GaloisElement._planMultiStep(supportedSteps: steps, step: step, degree: degree) #expect(result == counts) } } diff --git a/Tests/HomomorphicEncryptionTests/RnsBaseConverterTests.swift b/Tests/HomomorphicEncryptionTests/RnsBaseConverterTests.swift index 41054c0c..1469256a 100644 --- a/Tests/HomomorphicEncryptionTests/RnsBaseConverterTests.swift +++ b/Tests/HomomorphicEncryptionTests/RnsBaseConverterTests.swift @@ -17,7 +17,7 @@ import _TestUtilities import Testing @Suite -struct RnsBaseConverterTests { +struct _RnsBaseConverterTests { @Test func convertApproximate() throws { func runTestConvertApproximate( @@ -35,7 +35,7 @@ struct RnsBaseConverterTests { let inputContext = try PolyContext(degree: degree, moduli: inputModuli) let outputContext = try PolyContext(degree: degree, moduli: [t]) let referenceX = (0..(from: inputContext, to: outputContext) + let rnsBaseConverter = try _RnsBaseConverter(from: inputContext, to: outputContext) let data = referenceX.map { bigInt in TestUtils.crtDecompose(value: bigInt, moduli: inputContext.moduli) } let inputData = Array2d(data: data).transposed() @@ -81,7 +81,7 @@ struct RnsBaseConverterTests { let outputContext = try PolyContext(degree: degree, moduli: [T(2)]) // Arbitrary let poly: PolyRq = PolyRq.random(context: inputContext) - let rnsBaseConverter = try RnsBaseConverter(from: inputContext, to: outputContext) + let rnsBaseConverter = try _RnsBaseConverter(from: inputContext, to: outputContext) let composed: [QuadWidth] = try rnsBaseConverter.crtCompose(poly: poly) for (coeffIndex, composed) in composed.enumerated() { diff --git a/Tests/HomomorphicEncryptionTests/RnsToolTests.swift b/Tests/HomomorphicEncryptionTests/RnsToolTests.swift index 6c98e0a0..c706a4a2 100644 --- a/Tests/HomomorphicEncryptionTests/RnsToolTests.swift +++ b/Tests/HomomorphicEncryptionTests/RnsToolTests.swift @@ -27,7 +27,7 @@ struct RnsToolTests { { let inputContext = try PolyContext(degree: degree, moduli: inputModuli) let outputContext = try PolyContext(degree: degree, moduli: [outputModulus]) - let rnsTool = try RnsTool(from: inputContext, to: outputContext) + let rnsTool = try _RnsTool(from: inputContext, to: outputContext) let q: T = inputModuli.product() let k = inputModuli.count @@ -81,7 +81,7 @@ struct RnsToolTests { let inputContext = try PolyContext(degree: degree, moduli: inputModuli) let outputContext = try PolyContext(degree: degree, moduli: [t]) - let rnsTool = try RnsTool(from: inputContext, to: outputContext) + let rnsTool = try _RnsTool(from: inputContext, to: outputContext) let referenceX = (0...random(in: 0.. = PolyRq(context: bSkMtildeContext, data: inputData) @@ -176,7 +176,7 @@ struct RnsToolTests { let inputModuli = try T.generatePrimes(significantBitCounts: significantBitCounts, preferringSmall: true) let inputContext = try PolyContext(degree: degree, moduli: inputModuli) let outputContext = try PolyContext(degree: degree, moduli: [T(2)]) // arbitrary - let rnsTool = try RnsTool(from: inputContext, to: outputContext) + let rnsTool = try _RnsTool(from: inputContext, to: outputContext) let q: OctoWidth = inputModuli.product() let referenceX = (0...random(in: 0.. = inputModuli.product() let qBsk: OctoWidth = rnsTool.qBskContext.moduli.product() let bSk: OctoWidth = rnsTool.rnsConvertQToBSk.outputContext.moduli.product() @@ -274,7 +274,7 @@ struct RnsToolTests { let inputContext = try PolyContext(degree: degree, moduli: inputModuli) let outputContext = try PolyContext(degree: degree, moduli: [2]) // Arbitrary - let rnsTool = try RnsTool(from: inputContext, to: outputContext) + let rnsTool = try _RnsTool(from: inputContext, to: outputContext) let bskContext = rnsTool.rnsConvertQToBSk.outputContext let bskModuli = bskContext.moduli let bskProd: OctoWidth = bskModuli.product() diff --git a/Tests/HomomorphicEncryptionTests/SerializationTests.swift b/Tests/HomomorphicEncryptionTests/SerializationTests.swift index c8f2cd1d..3aa69fef 100644 --- a/Tests/HomomorphicEncryptionTests/SerializationTests.swift +++ b/Tests/HomomorphicEncryptionTests/SerializationTests.swift @@ -26,7 +26,7 @@ struct SerializationTests { } func runTest(_: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0..(_: Scheme.Type, format: EncodeFormat) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0..(_: Scheme.Type, format: EncodeFormat) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0..(_: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let secretKey = try context.generateSecretKey() let serialized = secretKey.serialize() - let deserialized = try SecretKey(deserialize: serialized, context: context) + let deserialized = try SecretKey(deserialize: serialized, context: context) #expect(deserialized == secretKey) } @@ -166,13 +166,13 @@ struct SerializationTests { @Test func galoisKey() throws { func runTest(_: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let secretKey = try context.generateSecretKey() let evaluationKeyConfig = EvaluationKeyConfig(galoisElements: [3, 5, 7]) let evaluationKey = try context.generateEvaluationKey(config: evaluationKeyConfig, using: secretKey) let galoisKey = try #require(evaluationKey.galoisKey) let serialized = galoisKey.serialize() - let deserialized = try GaloisKey(deserialize: serialized, context: context) + let deserialized = try _GaloisKey(deserialize: serialized, context: context) #expect(deserialized == galoisKey) } @@ -184,13 +184,14 @@ struct SerializationTests { @Test func relinearizationKey() throws { func runTest(_: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let secretKey = try context.generateSecretKey() let evaluationKeyConfig = EvaluationKeyConfig(hasRelinearizationKey: true) - let evaluationKey = try context.generateEvaluationKey(config: evaluationKeyConfig, using: secretKey) + let evaluationKey = try context.generateEvaluationKey(config: evaluationKeyConfig, + using: secretKey) let relinearizationKey = try #require(evaluationKey.relinearizationKey) let serialized = relinearizationKey.serialize() - let deserialized = try RelinearizationKey(deserialize: serialized, context: context) + let deserialized = try _RelinearizationKey(deserialize: serialized, context: context) #expect(deserialized == relinearizationKey) } @@ -202,14 +203,15 @@ struct SerializationTests { @Test func evaluationKey() throws { func runTest(_: Scheme.Type) throws { - let context: Context = try TestUtils.getTestContext() + let context: Scheme.Context = try TestUtils.getTestContext() let secretKey = try context.generateSecretKey() let evaluationKeyConfig = EvaluationKeyConfig( galoisElements: [3, 5, 7], hasRelinearizationKey: true) - let evaluationKey = try context.generateEvaluationKey(config: evaluationKeyConfig, using: secretKey) + let evaluationKey = try context.generateEvaluationKey(config: evaluationKeyConfig, + using: secretKey) let serialized = evaluationKey.serialize() - let deserialized = try EvaluationKey(deserialize: serialized, context: context) + let deserialized = try EvaluationKey(deserialize: serialized, context: context) #expect(deserialized == evaluationKey) func checkSeededCiphertext(_ ciphertexts: [SerializedCiphertext]) { diff --git a/Tests/PrivateInformationRetrievalTests/ExpansionTests.swift b/Tests/PrivateInformationRetrievalTests/ExpansionTests.swift index a3943448..799ae23d 100644 --- a/Tests/PrivateInformationRetrievalTests/ExpansionTests.swift +++ b/Tests/PrivateInformationRetrievalTests/ExpansionTests.swift @@ -14,29 +14,29 @@ import _TestUtilities import HomomorphicEncryption -import PrivateInformationRetrieval +@testable import PrivateInformationRetrieval import Testing @Suite struct ExpansionTests { @Test(arguments: PirKeyCompressionStrategy.allCases) - func expandCiphertextForOneStepTest(keyCompression: PirKeyCompressionStrategy) throws { - try PirTestUtils.ExpansionTests.expandCiphertextForOneStep(scheme: NoOpScheme.self, keyCompression) - try PirTestUtils.ExpansionTests.expandCiphertextForOneStep(scheme: Bfv.self, keyCompression) - try PirTestUtils.ExpansionTests.expandCiphertextForOneStep(scheme: Bfv.self, keyCompression) + func expandCiphertextForOneStepTest(keyCompression: PirKeyCompressionStrategy) async throws { + try await PirTestUtils.ExpansionTests.expandCiphertextForOneStep(scheme: NoOpScheme.self, keyCompression) + try await PirTestUtils.ExpansionTests.expandCiphertextForOneStep(scheme: Bfv.self, keyCompression) + try await PirTestUtils.ExpansionTests.expandCiphertextForOneStep(scheme: Bfv.self, keyCompression) } @Test - func oneCiphertextRoundtrip() throws { - try PirTestUtils.ExpansionTests.oneCiphertextRoundtrip(scheme: NoOpScheme.self) - try PirTestUtils.ExpansionTests.oneCiphertextRoundtrip(scheme: Bfv.self) - try PirTestUtils.ExpansionTests.oneCiphertextRoundtrip(scheme: Bfv.self) + func oneCiphertextRoundtrip() async throws { + try await PirTestUtils.ExpansionTests.oneCiphertextRoundtrip(scheme: NoOpScheme.self) + try await PirTestUtils.ExpansionTests.oneCiphertextRoundtrip(scheme: Bfv.self) + try await PirTestUtils.ExpansionTests.oneCiphertextRoundtrip(scheme: Bfv.self) } @Test - func multipleCiphertextsRoundtrip() throws { - try PirTestUtils.ExpansionTests.multipleCiphertextsRoundtrip(scheme: NoOpScheme.self) - try PirTestUtils.ExpansionTests.multipleCiphertextsRoundtrip(scheme: Bfv.self) - try PirTestUtils.ExpansionTests.multipleCiphertextsRoundtrip(scheme: Bfv.self) + func multipleCiphertextsRoundtrip() async throws { + try await PirTestUtils.ExpansionTests.multipleCiphertextsRoundtrip(pirUtil: PirUtil.self) + try await PirTestUtils.ExpansionTests.multipleCiphertextsRoundtrip(pirUtil: PirUtil>.self) + try await PirTestUtils.ExpansionTests.multipleCiphertextsRoundtrip(pirUtil: PirUtil>.self) } } diff --git a/Tests/PrivateInformationRetrievalTests/IndexPirTests.swift b/Tests/PrivateInformationRetrievalTests/IndexPirTests.swift index f6c7e8db..15b2fd54 100644 --- a/Tests/PrivateInformationRetrievalTests/IndexPirTests.swift +++ b/Tests/PrivateInformationRetrievalTests/IndexPirTests.swift @@ -14,14 +14,15 @@ import _TestUtilities import HomomorphicEncryption -import PrivateInformationRetrieval +@testable import PrivateInformationRetrieval import Testing @Suite struct IndexPirTests { @Test func generateParameter() throws { - let context: Context> = try TestUtils.getTestContext() + typealias Scheme = Bfv + let context: Context = try TestUtils.getTestContext() // unevenDimensions: false do { let config = try IndexPirConfig(entryCount: 16, @@ -30,7 +31,7 @@ struct IndexPirTests { batchSize: 1, unevenDimensions: false, keyCompression: .noCompression) - let parameter = MulPir>.generateParameter(config: config, with: context) + let parameter = MulPir.generateParameter(config: config, with: context) #expect(parameter.dimensions == [4, 4]) } do { @@ -40,7 +41,7 @@ struct IndexPirTests { batchSize: 2, unevenDimensions: false, keyCompression: .noCompression) - let parameter = MulPir>.generateParameter(config: config, with: context) + let parameter = MulPir.generateParameter(config: config, with: context) #expect(parameter.dimensions == [4, 3]) } // unevenDimensions: true @@ -51,7 +52,7 @@ struct IndexPirTests { batchSize: 1, unevenDimensions: true, keyCompression: .noCompression) - let parameter = MulPir>.generateParameter(config: config, with: context) + let parameter = MulPir.generateParameter(config: config, with: context) #expect(parameter.dimensions == [5, 3]) } do { @@ -61,7 +62,7 @@ struct IndexPirTests { batchSize: 2, unevenDimensions: true, keyCompression: .noCompression) - let parameter = MulPir>.generateParameter(config: config, with: context) + let parameter = MulPir.generateParameter(config: config, with: context) #expect(parameter.dimensions == [5, 3]) } do { @@ -71,7 +72,7 @@ struct IndexPirTests { batchSize: 2, unevenDimensions: true, keyCompression: .noCompression) - let parameter = MulPir>.generateParameter(config: config, with: context) + let parameter = MulPir.generateParameter(config: config, with: context) #expect(parameter.dimensions == [9, 2]) } // no key compression @@ -82,7 +83,7 @@ struct IndexPirTests { batchSize: 2, unevenDimensions: true, keyCompression: .noCompression) - let parameter = MulPir>.generateParameter(config: config, with: context) + let parameter = MulPir.generateParameter(config: config, with: context) let evalKeyConfig = EvaluationKeyConfig( galoisElements: [3, 5, 9, 17], hasRelinearizationKey: true) @@ -96,7 +97,7 @@ struct IndexPirTests { batchSize: 2, unevenDimensions: true, keyCompression: .hybridCompression) - let parameter = MulPir>.generateParameter(config: config, with: context) + let parameter = MulPir.generateParameter(config: config, with: context) let evalKeyConfig = EvaluationKeyConfig( galoisElements: [3, 5, 9, 17], hasRelinearizationKey: true) @@ -110,7 +111,7 @@ struct IndexPirTests { batchSize: 2, unevenDimensions: true, keyCompression: .maxCompression) - let parameter = MulPir>.generateParameter(config: config, with: context) + let parameter = MulPir.generateParameter(config: config, with: context) let evalKeyConfig = EvaluationKeyConfig( galoisElements: [3, 5, 9], hasRelinearizationKey: true) @@ -119,9 +120,9 @@ struct IndexPirTests { } @Test - func indexPir() throws { - try PirTestUtils.IndexPirTests.indexPir(scheme: NoOpScheme.self) - try PirTestUtils.IndexPirTests.indexPir(scheme: Bfv.self) - try PirTestUtils.IndexPirTests.indexPir(scheme: Bfv.self) + func indexPir() async throws { + try await PirTestUtils.IndexPirTests.indexPir(scheme: NoOpScheme.self) + try await PirTestUtils.IndexPirTests.indexPir(scheme: Bfv.self) + try await PirTestUtils.IndexPirTests.indexPir(scheme: Bfv.self) } } diff --git a/Tests/PrivateInformationRetrievalTests/KeywordPirTests.swift b/Tests/PrivateInformationRetrievalTests/KeywordPirTests.swift index fc049f46..53f0ae0c 100644 --- a/Tests/PrivateInformationRetrievalTests/KeywordPirTests.swift +++ b/Tests/PrivateInformationRetrievalTests/KeywordPirTests.swift @@ -20,79 +20,79 @@ import Testing @Suite struct KeywordPirTests { @Test - func processedDatabaseSerialization() throws { - try PirTestUtils.KeywordPirTests.processedDatabaseSerialization(Bfv.self) - try PirTestUtils.KeywordPirTests.processedDatabaseSerialization(Bfv.self) + func processedDatabaseSerialization() async throws { + try await PirTestUtils.KeywordPirTests.processedDatabaseSerialization(Bfv.self) + try await PirTestUtils.KeywordPirTests.processedDatabaseSerialization(Bfv.self) } @Test - func keywordPirMulPir1HashFunction() throws { - try PirTestUtils.KeywordPirTests.keywordPirMulPir1HashFunction(NoOpScheme.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPir1HashFunction(Bfv.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPir1HashFunction(Bfv.self) + func keywordPirMulPir1HashFunction() async throws { + try await PirTestUtils.KeywordPirTests.keywordPirMulPir1HashFunction(NoOpScheme.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPir1HashFunction(Bfv.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPir1HashFunction(Bfv.self) } @Test - func keywordPirMulPir3HashFunctions() throws { - try PirTestUtils.KeywordPirTests.keywordPirMulPir3HashFunctions(NoOpScheme.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPir3HashFunctions(Bfv.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPir3HashFunctions(Bfv.self) + func keywordPirMulPir3HashFunctions() async throws { + try await PirTestUtils.KeywordPirTests.keywordPirMulPir3HashFunctions(NoOpScheme.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPir3HashFunctions(Bfv.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPir3HashFunctions(Bfv.self) } @Test - func keywordPirMulPir1Dimension() throws { - try PirTestUtils.KeywordPirTests.keywordPirMulPir1Dimension(NoOpScheme.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPir1Dimension(Bfv.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPir1Dimension(Bfv.self) + func keywordPirMulPir1Dimension() async throws { + try await PirTestUtils.KeywordPirTests.keywordPirMulPir1Dimension(NoOpScheme.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPir1Dimension(Bfv.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPir1Dimension(Bfv.self) } @Test - func keywordPirMulPir2Dimensions() throws { - try PirTestUtils.KeywordPirTests.keywordPirMulPir2Dimensions(NoOpScheme.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPir2Dimensions(Bfv.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPir2Dimensions(Bfv.self) + func keywordPirMulPir2Dimensions() async throws { + try await PirTestUtils.KeywordPirTests.keywordPirMulPir2Dimensions(NoOpScheme.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPir2Dimensions(Bfv.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPir2Dimensions(Bfv.self) } @Test - func keywordPirMulPirHybridKeyCompression() throws { - try PirTestUtils.KeywordPirTests.keywordPirMulPirHybridKeyCompression(NoOpScheme.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPirHybridKeyCompression(Bfv.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPirHybridKeyCompression(Bfv.self) + func keywordPirMulPirHybridKeyCompression() async throws { + try await PirTestUtils.KeywordPirTests.keywordPirMulPirHybridKeyCompression(NoOpScheme.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPirHybridKeyCompression(Bfv.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPirHybridKeyCompression(Bfv.self) } @Test - func keywordPirMulPirMaxKeyCompression() throws { - try PirTestUtils.KeywordPirTests.keywordPirMulPirMaxKeyCompression(NoOpScheme.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPirMaxKeyCompression(Bfv.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPirMaxKeyCompression(Bfv.self) + func keywordPirMulPirMaxKeyCompression() async throws { + try await PirTestUtils.KeywordPirTests.keywordPirMulPirMaxKeyCompression(NoOpScheme.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPirMaxKeyCompression(Bfv.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPirMaxKeyCompression(Bfv.self) } @Test - func keywordPirMulPirLargeParameters() throws { - try PirTestUtils.KeywordPirTests.keywordPirMulPirLargeParameters(NoOpScheme.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPirLargeParameters(Bfv.self) - try PirTestUtils.KeywordPirTests.keywordPirMulPirLargeParameters(Bfv.self) + func keywordPirMulPirLargeParameters() async throws { + try await PirTestUtils.KeywordPirTests.keywordPirMulPirLargeParameters(NoOpScheme.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPirLargeParameters(Bfv.self) + try await PirTestUtils.KeywordPirTests.keywordPirMulPirLargeParameters(Bfv.self) } @Test - func keywordPirFixedConfig() throws { - try PirTestUtils.KeywordPirTests.keywordPirFixedConfig(NoOpScheme.self) - try PirTestUtils.KeywordPirTests.keywordPirFixedConfig(Bfv.self) - try PirTestUtils.KeywordPirTests.keywordPirFixedConfig(Bfv.self) + func keywordPirFixedConfig() async throws { + try await PirTestUtils.KeywordPirTests.keywordPirFixedConfig(NoOpScheme.self) + try await PirTestUtils.KeywordPirTests.keywordPirFixedConfig(Bfv.self) + try await PirTestUtils.KeywordPirTests.keywordPirFixedConfig(Bfv.self) } @Test - func sharding() throws { + func sharding() async throws { // TODO: make compatible with NoOpScheme - try PirTestUtils.KeywordPirTests.sharding(Bfv.self) - try PirTestUtils.KeywordPirTests.sharding(Bfv.self) + try await PirTestUtils.KeywordPirTests.sharding(PirUtil>.self) + try await PirTestUtils.KeywordPirTests.sharding(PirUtil>.self) } @Test - func limitEntriesPerResponse() throws { + func limitEntriesPerResponse() async throws { // TODO: make compatible with NoOpScheme. - try PirTestUtils.KeywordPirTests.limitEntriesPerResponse(Bfv.self) - try PirTestUtils.KeywordPirTests.limitEntriesPerResponse(Bfv.self) + try await PirTestUtils.KeywordPirTests.limitEntriesPerResponse(Bfv.self) + try await PirTestUtils.KeywordPirTests.limitEntriesPerResponse(Bfv.self) } @Test diff --git a/Tests/PrivateInformationRetrievalTests/MulPirTests.swift b/Tests/PrivateInformationRetrievalTests/MulPirTests.swift index b96e7fa8..9257b9a4 100644 --- a/Tests/PrivateInformationRetrievalTests/MulPirTests.swift +++ b/Tests/PrivateInformationRetrievalTests/MulPirTests.swift @@ -14,7 +14,7 @@ import _TestUtilities import HomomorphicEncryption -import PrivateInformationRetrieval +@testable import PrivateInformationRetrieval import Testing @Suite @@ -27,16 +27,16 @@ struct MulPirTests { } @Test(arguments: PirKeyCompressionStrategy.allCases) - func queryGeneration(keyCompression: PirKeyCompressionStrategy) throws { - try PirTestUtils.MulPirTests.queryGenerationTest(scheme: NoOpScheme.self, keyCompression) - try PirTestUtils.MulPirTests.queryGenerationTest(scheme: Bfv.self, keyCompression) - try PirTestUtils.MulPirTests.queryGenerationTest(scheme: Bfv.self, keyCompression) + func queryGeneration(keyCompression: PirKeyCompressionStrategy) async throws { + try await PirTestUtils.MulPirTests.queryGenerationTest(pirUtil: PirUtil.self, keyCompression) + try await PirTestUtils.MulPirTests.queryGenerationTest(pirUtil: PirUtil>.self, keyCompression) + try await PirTestUtils.MulPirTests.queryGenerationTest(pirUtil: PirUtil>.self, keyCompression) } @Test func computeCoordinates() throws { - try PirTestUtils.MulPirTests.computeCoordinates(scheme: NoOpScheme.self) - try PirTestUtils.MulPirTests.computeCoordinates(scheme: Bfv.self) - try PirTestUtils.MulPirTests.computeCoordinates(scheme: Bfv.self) + try PirTestUtils.MulPirTests.computeCoordinates(pirUtil: PirUtil.self) + try PirTestUtils.MulPirTests.computeCoordinates(pirUtil: PirUtil>.self) + try PirTestUtils.MulPirTests.computeCoordinates(pirUtil: PirUtil>.self) } } diff --git a/Tests/PrivateInformationRetrievalTests/SymmetricPIRTests.swift b/Tests/PrivateInformationRetrievalTests/SymmetricPIRTests.swift index a17c6215..b15ac61f 100644 --- a/Tests/PrivateInformationRetrievalTests/SymmetricPIRTests.swift +++ b/Tests/PrivateInformationRetrievalTests/SymmetricPIRTests.swift @@ -17,7 +17,7 @@ import _TestUtilities import Crypto import Foundation import HomomorphicEncryption -@testable import PrivateInformationRetrieval +import PrivateInformationRetrieval import Testing @Suite @@ -97,8 +97,8 @@ struct SymmetricPirTests { } @Test - func roundTrip() throws { - try PirTestUtils.SymmetricPirTests.roundTrip(Bfv.self) - try PirTestUtils.SymmetricPirTests.roundTrip(Bfv.self) + func roundTrip() async throws { + try await PirTestUtils.SymmetricPirTests.roundTrip(Bfv.self) + try await PirTestUtils.SymmetricPirTests.roundTrip(Bfv.self) } } diff --git a/Tests/PrivateNearestNeighborSearchTests/CiphertextMatrixTests.swift b/Tests/PrivateNearestNeighborSearchTests/CiphertextMatrixTests.swift index 58fce12b..a235fef4 100644 --- a/Tests/PrivateNearestNeighborSearchTests/CiphertextMatrixTests.swift +++ b/Tests/PrivateNearestNeighborSearchTests/CiphertextMatrixTests.swift @@ -19,23 +19,23 @@ import Testing @Suite struct CiphertextMatrixTests { @Test - func encryptDecryptRoundTrip() throws { - try PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.encryptDecryptRoundTrip(for: NoOpScheme.self) - try PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.encryptDecryptRoundTrip(for: Bfv.self) - try PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.encryptDecryptRoundTrip(for: Bfv.self) + func encryptDecryptRoundTrip() async throws { + try await PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.encryptDecryptRoundTrip(for: NoOpScheme.self) + try await PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.encryptDecryptRoundTrip(for: Bfv.self) + try await PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.encryptDecryptRoundTrip(for: Bfv.self) } @Test - func convertFormatRoundTrip() throws { + func convertFormatRoundTrip() async throws { try PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.convertFormatRoundTrip(for: NoOpScheme.self) try PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.convertFormatRoundTrip(for: Bfv.self) try PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.convertFormatRoundTrip(for: Bfv.self) } @Test - func extractDenseRow() throws { - try PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.extractDenseRow(for: NoOpScheme.self) - try PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.extractDenseRow(for: Bfv.self) - try PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.extractDenseRow(for: Bfv.self) + func extractDenseRow() async throws { + try await PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.extractDenseRow(for: NoOpScheme.self) + try await PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.extractDenseRow(for: Bfv.self) + try await PrivateNearestNeighborSearchUtil.CiphertextMatrixTests.extractDenseRow(for: Bfv.self) } } diff --git a/Tests/PrivateNearestNeighborSearchTests/ClientTests.swift b/Tests/PrivateNearestNeighborSearchTests/ClientTests.swift index 504683b1..baf6de84 100644 --- a/Tests/PrivateNearestNeighborSearchTests/ClientTests.swift +++ b/Tests/PrivateNearestNeighborSearchTests/ClientTests.swift @@ -38,8 +38,8 @@ struct ClientTests { } @Test - func clientServer() throws { - try PrivateNearestNeighborSearchUtil.ClientTests.clientServer(for: Bfv.self) - try PrivateNearestNeighborSearchUtil.ClientTests.clientServer(for: Bfv.self) + func clientServer() async throws { + try await PrivateNearestNeighborSearchUtil.ClientTests.clientServer(for: Bfv.self) + try await PrivateNearestNeighborSearchUtil.ClientTests.clientServer(for: Bfv.self) } } diff --git a/Tests/PrivateNearestNeighborSearchTests/DatabaseTests.swift b/Tests/PrivateNearestNeighborSearchTests/DatabaseTests.swift index fedec653..7718b7bc 100644 --- a/Tests/PrivateNearestNeighborSearchTests/DatabaseTests.swift +++ b/Tests/PrivateNearestNeighborSearchTests/DatabaseTests.swift @@ -19,9 +19,9 @@ import Testing @Suite struct DatabaseTests { @Test - func serializedProcessedDatabase() throws { - try PrivateNearestNeighborSearchUtil.DatabaseTests.serializedProcessedDatabase(for: NoOpScheme.self) - try PrivateNearestNeighborSearchUtil.DatabaseTests.serializedProcessedDatabase(for: Bfv.self) - try PrivateNearestNeighborSearchUtil.DatabaseTests.serializedProcessedDatabase(for: Bfv.self) + func serializedProcessedDatabase() async throws { + try await PrivateNearestNeighborSearchUtil.DatabaseTests.serializedProcessedDatabase(for: NoOpScheme.self) + try await PrivateNearestNeighborSearchUtil.DatabaseTests.serializedProcessedDatabase(for: Bfv.self) + try await PrivateNearestNeighborSearchUtil.DatabaseTests.serializedProcessedDatabase(for: Bfv.self) } } diff --git a/Tests/PrivateNearestNeighborSearchTests/MatrixMultiplicationTests.swift b/Tests/PrivateNearestNeighborSearchTests/MatrixMultiplicationTests.swift index 9c716c3c..6f422e65 100644 --- a/Tests/PrivateNearestNeighborSearchTests/MatrixMultiplicationTests.swift +++ b/Tests/PrivateNearestNeighborSearchTests/MatrixMultiplicationTests.swift @@ -19,21 +19,25 @@ import Testing @Suite struct MatrixMultiplicationTests { @Test - func mulVector() throws { - try PrivateNearestNeighborSearchUtil.MatrixMultiplicationTests.mulVector(for: Bfv.self) - try PrivateNearestNeighborSearchUtil.MatrixMultiplicationTests.mulVector(for: Bfv.self) + func mulVector() async throws { + try await PrivateNearestNeighborSearchUtil.MatrixMultiplicationTests.mulVector(for: Bfv.self) + try await PrivateNearestNeighborSearchUtil.MatrixMultiplicationTests.mulVector(for: Bfv.self) } @Test - func matrixMulSmallDimensions() throws { - try PrivateNearestNeighborSearchUtil.MatrixMultiplicationTests.matrixMulSmallDimensions(for: Bfv.self) - try PrivateNearestNeighborSearchUtil.MatrixMultiplicationTests.matrixMulSmallDimensions(for: Bfv.self) + func matrixMulSmallDimensions() async throws { + try await PrivateNearestNeighborSearchUtil.MatrixMultiplicationTests + .matrixMulSmallDimensions(for: Bfv.self) + try await PrivateNearestNeighborSearchUtil.MatrixMultiplicationTests + .matrixMulSmallDimensions(for: Bfv.self) } @Test - func matrixMulLargeDimensions() throws { - try PrivateNearestNeighborSearchUtil.MatrixMultiplicationTests.matrixMulLargeDimensions(for: Bfv.self) - try PrivateNearestNeighborSearchUtil.MatrixMultiplicationTests.matrixMulLargeDimensions(for: Bfv.self) + func matrixMulLargeDimensions() async throws { + try await PrivateNearestNeighborSearchUtil.MatrixMultiplicationTests + .matrixMulLargeDimensions(for: Bfv.self) + try await PrivateNearestNeighborSearchUtil.MatrixMultiplicationTests + .matrixMulLargeDimensions(for: Bfv.self) } @Test