diff --git a/src/characters/cbd-recipient.ts b/src/characters/cbd-recipient.ts index 9a977665d..4997ab3fe 100644 --- a/src/characters/cbd-recipient.ts +++ b/src/characters/cbd-recipient.ts @@ -1,26 +1,22 @@ import { - AccessControlPolicy, - AuthenticatedData, - Ciphertext, combineDecryptionSharesSimple, Context, DecryptionShareSimple, - decryptWithSharedSecret, EncryptedThresholdDecryptionRequest, EncryptedThresholdDecryptionResponse, FerveoVariant, SessionSharedSecret, SessionStaticSecret, ThresholdDecryptionRequest, + ThresholdMessageKit, } from '@nucypher/nucypher-core'; import { ethers } from 'ethers'; -import { keccak256 } from 'ethers/lib/utils'; import { DkgCoordinatorAgent, DkgParticipant } from '../agents/coordinator'; import { ConditionExpression } from '../conditions'; -import { DkgClient, DkgRitual } from '../dkg'; +import { DkgRitual } from '../dkg'; import { PorterClient } from '../porter'; -import { fromJSON, toBytes, toJSON } from '../utils'; +import { fromJSON, objectEquals, toJSON } from '../utils'; export type ThresholdDecrypterJSON = { porterUri: string; @@ -47,51 +43,22 @@ export class ThresholdDecrypter { public async retrieveAndDecrypt( provider: ethers.providers.Web3Provider, conditionExpr: ConditionExpression, - ciphertext: Ciphertext + thresholdMessageKit: ThresholdMessageKit ): Promise { - const acp = await this.makeAcp(provider, conditionExpr, ciphertext); - const decryptionShares = await this.retrieve( provider, conditionExpr, - ciphertext, - acp + thresholdMessageKit ); - const sharedSecret = combineDecryptionSharesSimple(decryptionShares); - return decryptWithSharedSecret( - ciphertext, - conditionExpr.asAad(), - sharedSecret - ); - } - - private async makeAcp( - provider: ethers.providers.Web3Provider, - conditionExpr: ConditionExpression, - ciphertext: Ciphertext - ) { - const dkgRitual = await DkgClient.getExistingRitual( - provider, - this.ritualId - ); - const authData = new AuthenticatedData( - dkgRitual.dkgPublicKey, - conditionExpr.toWASMConditions() - ); - - const headerHash = keccak256(ciphertext.header.toBytes()); - const authorization = await provider.getSigner().signMessage(headerHash); - - return new AccessControlPolicy(authData, toBytes(authorization)); + return thresholdMessageKit.decryptWithSharedSecret(sharedSecret); } // Retrieve decryption shares public async retrieve( provider: ethers.providers.Web3Provider, conditionExpr: ConditionExpression, - ciphertext: Ciphertext, - acp: AccessControlPolicy + thresholdMessageKit: ThresholdMessageKit ): Promise { const dkgParticipants = await DkgCoordinatorAgent.getParticipants( provider, @@ -100,10 +67,9 @@ export class ThresholdDecrypter { const contextStr = await conditionExpr.buildContext(provider).toJson(); const { sharedSecrets, encryptedRequests } = this.makeDecryptionRequests( this.ritualId, - ciphertext, - contextStr, + new Context(contextStr), dkgParticipants, - acp + thresholdMessageKit ); const { encryptedResponses, errors } = await this.porter.cbdDecrypt( @@ -148,10 +114,9 @@ export class ThresholdDecrypter { private makeDecryptionRequests( ritualId: number, - ciphertext: Ciphertext, - contextStr: string, + conditionContext: Context, dkgParticipants: Array, - acp: AccessControlPolicy + thresholdMessageKit: ThresholdMessageKit ): { sharedSecrets: Record; encryptedRequests: Record; @@ -159,9 +124,9 @@ export class ThresholdDecrypter { const decryptionRequest = new ThresholdDecryptionRequest( ritualId, FerveoVariant.simple, - ciphertext.header, - acp, - new Context(contextStr) + thresholdMessageKit.ciphertextHeader, + thresholdMessageKit.acp, + conditionContext ); const ephemeralSessionKey = this.makeSessionKey(); @@ -228,8 +193,6 @@ export class ThresholdDecrypter { } public equals(other: ThresholdDecrypter): boolean { - return ( - this.porter.porterUrl.toString() === other.porter.porterUrl.toString() - ); + return objectEquals(this.toObj(), other.toObj()); } } diff --git a/src/characters/enrico.ts b/src/characters/enrico.ts index 17c6ff358..3185ca050 100644 --- a/src/characters/enrico.ts +++ b/src/characters/enrico.ts @@ -1,11 +1,13 @@ import { - Ciphertext, + AccessControlPolicy, DkgPublicKey, - ferveoEncrypt, + encryptForDkg, MessageKit, PublicKey, SecretKey, + ThresholdMessageKit, } from '@nucypher/nucypher-core'; +import { arrayify, keccak256 } from 'ethers/lib/utils'; import { ConditionExpression } from '../conditions'; import { Keyring } from '../keyring'; @@ -51,13 +53,13 @@ export class Enrico { public encryptMessageCbd( plaintext: Uint8Array | string, - withConditions?: ConditionExpression - ): { ciphertext: Ciphertext; aad: Uint8Array } { - if (!withConditions) { - withConditions = this.conditions; + conditions?: ConditionExpression + ): ThresholdMessageKit { + if (!conditions) { + conditions = this.conditions; } - if (!withConditions) { + if (!conditions) { throw new Error('Conditions are required for CBD encryption.'); } @@ -65,12 +67,19 @@ export class Enrico { throw new Error('Wrong key type. Use encryptMessagePre instead.'); } - const aad = withConditions.asAad(); - const ciphertext = ferveoEncrypt( + const [ciphertext, authenticatedData] = encryptForDkg( plaintext instanceof Uint8Array ? plaintext : toBytes(plaintext), - aad, - this.encryptingKey + this.encryptingKey, + conditions.toWASMConditions() + ); + + const headerHash = keccak256(ciphertext.header.toBytes()); + const authorization = this.keyring.signer.sign(arrayify(headerHash)); + const acp = new AccessControlPolicy( + authenticatedData, + authorization.toBEBytes() ); - return { ciphertext, aad }; + + return new ThresholdMessageKit(ciphertext, acp); } } diff --git a/test/integration/dkg-client.test.ts b/test/integration/dkg-client.test.ts index acb8337b6..283dc54f6 100644 --- a/test/integration/dkg-client.test.ts +++ b/test/integration/dkg-client.test.ts @@ -29,7 +29,7 @@ describe('DkgCoordinatorAgent', () => { it('fetches participants from the coordinator', async () => { const provider = fakeWeb3Provider(SecretKey.random().toBEBytes()); - const fakeParticipants = fakeDkgParticipants(fakeRitualId); + const fakeParticipants = await fakeDkgParticipants(fakeRitualId); const getParticipantsSpy = mockGetParticipants( fakeParticipants.participants ); diff --git a/test/unit/cbd-strategy.test.ts b/test/unit/cbd-strategy.test.ts index d38aeeb45..a3222d499 100644 --- a/test/unit/cbd-strategy.test.ts +++ b/test/unit/cbd-strategy.test.ts @@ -39,6 +39,7 @@ const conditionExpr = new ConditionExpression(ownsNFT); const ursulas = fakeUrsulas(); const variant = FerveoVariant.precomputed; const ritualId = 0; +const web3Provider = fakeWeb3Provider(aliceSecretKey.toBEBytes()); const makeCbdStrategy = async () => { const cohort = await makeCohort(ursulas); @@ -52,7 +53,6 @@ async function makeDeployedCbdStrategy() { const mockedDkg = fakeDkgFlow(variant, 0, 4, 4); const mockedDkgRitual = fakeDkgRitual(mockedDkg); - const web3Provider = fakeWeb3Provider(aliceSecretKey.toBEBytes()); const getUrsulasSpy = mockGetUrsulas(ursulas); const getExistingRitualSpy = mockGetExistingRitual(mockedDkgRitual); const deployedStrategy = await strategy.deploy(web3Provider, ritualId); @@ -102,20 +102,20 @@ describe('CbdDeployedStrategy', () => { const { mockedDkg, deployedStrategy } = await makeDeployedCbdStrategy(); const message = 'this is a secret'; - const { ciphertext, aad } = deployedStrategy + const thresholdMessageKit = deployedStrategy .makeEncrypter(conditionExpr) .encryptMessageCbd(message); // Setup mocks for `retrieveAndDecrypt` - const { decryptionShares } = fakeTDecFlow({ + const { decryptionShares } = await fakeTDecFlow({ ...mockedDkg, message: toBytes(message), - aad, - ciphertext, + conditionExpr, + dkgPublicKey: mockedDkg.dkg.publicKey(), + thresholdMessageKit, }); - const { participantSecrets, participants } = fakeDkgParticipants( - mockedDkg.ritualId, - variant + const { participantSecrets, participants } = await fakeDkgParticipants( + mockedDkg.ritualId ); const requesterSessionKey = SessionStaticSecret.random(); const decryptSpy = mockCbdDecrypt( @@ -132,7 +132,7 @@ describe('CbdDeployedStrategy', () => { await deployedStrategy.decrypter.retrieveAndDecrypt( aliceProvider, conditionExpr, - ciphertext + thresholdMessageKit ); expect(getUrsulasSpy).toHaveBeenCalled(); expect(getParticipantsSpy).toHaveBeenCalled(); diff --git a/test/utils.ts b/test/utils.ts index eec92f4cb..0a2c12bed 100644 --- a/test/utils.ts +++ b/test/utils.ts @@ -7,16 +7,15 @@ import { AggregatedTranscript, Capsule, CapsuleFrag, - Ciphertext, combineDecryptionSharesSimple, DecryptionSharePrecomputed, DecryptionShareSimple, - decryptWithSharedSecret, Dkg, + DkgPublicKey, EncryptedThresholdDecryptionResponse, EncryptedTreasureMap, + encryptForDkg, EthereumAddress, - ferveoEncrypt, FerveoVariant, Keypair, PublicKey, @@ -26,6 +25,7 @@ import { SessionStaticKey, SessionStaticSecret, ThresholdDecryptionResponse, + ThresholdMessageKit, Transcript, Validator, ValidatorMessage, @@ -36,13 +36,15 @@ import axios from 'axios'; import { ethers, providers, Wallet } from 'ethers'; import { keccak256 } from 'ethers/lib/utils'; -import { Alice, Bob, Cohort, RemoteBob } from '../src'; +import { Alice, Bob, Cohort, Enrico, RemoteBob } from '../src'; import { DkgCoordinatorAgent, DkgParticipant, DkgRitualState, } from '../src/agents/coordinator'; import { ThresholdDecrypter } from '../src/characters/cbd-recipient'; +import { ConditionExpression } from '../src/conditions'; +import { ERC721Balance } from '../src/conditions/predefined'; import { DkgClient, DkgRitual } from '../src/dkg'; import { BlockchainPolicy, PreEnactedPolicy } from '../src/policies/policy'; import { @@ -55,6 +57,8 @@ import { import { ChecksumAddress } from '../src/types'; import { toBytes, toHexString, zip } from '../src/utils'; +import { TEST_CHAIN_ID, TEST_CONTRACT_ADDR } from './unit/testVariables'; + export const bytesEqual = (first: Uint8Array, second: Uint8Array): boolean => first.length === second.length && first.every((value, index) => value === second[index]); @@ -287,23 +291,31 @@ interface FakeDkgRitualFlow { sharesNum: number; threshold: number; receivedMessages: ValidatorMessage[]; - ciphertext: Ciphertext; - aad: Uint8Array; dkg: Dkg; message: Uint8Array; + dkgPublicKey: DkgPublicKey; + conditionExpr: ConditionExpression; + thresholdMessageKit: ThresholdMessageKit; } -export const fakeTDecFlow = ({ +export const fakeTDecFlow = async ({ validators, validatorKeypairs, ritualId, sharesNum, threshold, receivedMessages, - ciphertext, - aad, message, + conditionExpr, + dkgPublicKey, + thresholdMessageKit, }: FakeDkgRitualFlow) => { + const [_ciphertext, authenticatedData] = encryptForDkg( + message, + dkgPublicKey, + conditionExpr.toWASMConditions() + ); + // Having aggregated the transcripts, the validators can now create decryption shares const decryptionShares: ( | DecryptionSharePrecomputed @@ -319,56 +331,71 @@ export const fakeTDecFlow = ({ const decryptionShare = aggregate.createDecryptionShareSimple( dkg, - ciphertext.header, - aad, + thresholdMessageKit.ciphertextHeader, + authenticatedData.aad(), keypair ); decryptionShares.push(decryptionShare); }); - // Now, the decryption share can be used to decrypt the ciphertext - // This part is in the client API const sharedSecret = combineDecryptionSharesSimple(decryptionShares); - // The client should have access to the public parameters of the DKG - const plaintext = decryptWithSharedSecret(ciphertext, aad, sharedSecret); + const plaintext = thresholdMessageKit.decryptWithSharedSecret(sharedSecret); if (!bytesEqual(plaintext, message)) { throw new Error('Decryption failed'); } - return { decryptionShares, sharedSecret, plaintext }; + return { + authenticatedData, + decryptionShares, + plaintext, + sharedSecret, + thresholdMessageKit, + }; }; -export const fakeDkgTDecFlowE2e = ( - variant: FerveoVariant, - message = toBytes('fake-message'), - aad = toBytes('fake-aad'), +const fakeConditionExpr = () => { + const erc721Balance = new ERC721Balance({ + chain: TEST_CHAIN_ID, + contractAddress: TEST_CONTRACT_ADDR, + }); + return new ConditionExpression(erc721Balance); +}; + +export const fakeDkgTDecFlowE2E = async ( ritualId = 0, + variant: FerveoVariant = FerveoVariant.precomputed, + conditionExpr: ConditionExpression = fakeConditionExpr(), + message = toBytes('fake-message'), sharesNum = 4, threshold = 4 ) => { const ritual = fakeDkgFlow(variant, ritualId, sharesNum, threshold); + const dkgPublicKey = ritual.dkg.publicKey(); + const thresholdMessageKit = new Enrico(dkgPublicKey).encryptMessageCbd( + message, + conditionExpr + ); - // In the meantime, the client creates a ciphertext and decryption request - const ciphertext = ferveoEncrypt(message, aad, ritual.dkg.publicKey()); - const { decryptionShares } = fakeTDecFlow({ + const { decryptionShares, authenticatedData } = await fakeTDecFlow({ ...ritual, - ciphertext, - aad, message, + conditionExpr, + dkgPublicKey, + thresholdMessageKit, }); return { ...ritual, message, - aad, - ciphertext, decryptionShares, + authenticatedData, + thresholdMessageKit, }; }; -export const fakeCoordinatorRitual = ( +export const fakeCoordinatorRitual = async ( ritualId: number -): { +): Promise<{ aggregationMismatch: boolean; initTimestamp: number; aggregatedTranscriptHash: string; @@ -380,8 +407,8 @@ export const fakeCoordinatorRitual = ( aggregatedTranscript: string; publicKeyHash: string; totalAggregations: number; -} => { - const ritual = fakeDkgTDecFlowE2e(FerveoVariant.precomputed); +}> => { + const ritual = await fakeDkgTDecFlowE2E(); const dkgPkBytes = ritual.dkg.publicKey().toBytes(); return { id: ritualId, @@ -401,14 +428,13 @@ export const fakeCoordinatorRitual = ( }; }; -export const fakeDkgParticipants = ( - ritualId: number, - variant = FerveoVariant.precomputed -): { +export const fakeDkgParticipants = async ( + ritualId: number +): Promise<{ participants: DkgParticipant[]; participantSecrets: Record; -} => { - const ritual = fakeDkgTDecFlowE2e(variant); +}> => { + const ritual = await fakeDkgTDecFlowE2E(ritualId); const label = toBytes(`${ritualId}`); const participantSecrets: Record = @@ -497,12 +523,6 @@ export const fakeDkgRitual = (ritual: { ); }; -export const mockInitializeRitual = (ritualId: number) => { - return jest.spyOn(DkgClient, 'initializeRitual').mockImplementation(() => { - return Promise.resolve(ritualId); - }); -}; - export const mockGetExistingRitual = (dkgRitual: DkgRitual) => { return jest.spyOn(DkgClient, 'getExistingRitual').mockImplementation(() => { return Promise.resolve(dkgRitual); @@ -517,9 +537,3 @@ export const makeCohort = async (ursulas: Ursula[]) => { expect(getUrsulasSpy).toHaveBeenCalled(); return cohort; }; - -export const mockGetRitualState = (state = DkgRitualState.FINALIZED) => { - return jest - .spyOn(DkgCoordinatorAgent, 'getRitualState') - .mockImplementation((_provider, _ritualId) => Promise.resolve(state)); -};