Skip to content

Commit

Permalink
Merge pull request #678 from magiclabs/split-key-device-share
Browse files Browse the repository at this point in the history
Split key device share
  • Loading branch information
Dizigen authored Dec 14, 2023
2 parents 7bc54cc + 48b1ddb commit 82f1080
Show file tree
Hide file tree
Showing 11 changed files with 386 additions and 20 deletions.
25 changes: 24 additions & 1 deletion packages/@magic-sdk/provider/src/core/sdk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,27 @@ function checkExtensionCompat(ext: Extension<string>) {
return true;
}

/**
* Generates a network hash of the SDK instance for persisting network specific
* information on multichain setups
*/
function getNetworkHash(apiKey: string, network?: EthNetworkConfiguration, extConfig?: any) {
if (!network && !extConfig) {
return `${apiKey}_eth_mainnet`;
}
if (extConfig) {
return `${apiKey}_${JSON.stringify(extConfig)}`;
}
if (network) {
if (typeof network === 'string') {
return `${apiKey}_eth_${network}`;
}
// Custom network, not necessarily eth.
return `${apiKey}_${network.rpcUrl}_${network.chainId}_${network.chainType}`;
}
return `${apiKey}_unknown`;
}

/**
* Initializes SDK extensions, checks for platform/version compatiblity issues,
* then consolidates any global configurations provided by those extensions.
Expand Down Expand Up @@ -103,6 +124,7 @@ export class SDKBase {

protected readonly endpoint: string;
protected readonly parameters: string;
protected readonly networkHash: string;
public readonly testMode: boolean;

/**
Expand Down Expand Up @@ -169,6 +191,7 @@ export class SDKBase {
locale: options?.locale || 'en_US',
...(SDKEnvironment.bundleId ? { bundleId: SDKEnvironment.bundleId } : {}),
});
this.networkHash = getNetworkHash(this.apiKey, options?.network, isEmpty(extConfig) ? undefined : extConfig);
if (!options?.deferPreload) this.preload();
}

Expand All @@ -177,7 +200,7 @@ export class SDKBase {
*/
protected get overlay(): ViewController {
if (!SDKBase.__overlays__.has(this.parameters)) {
const controller = new SDKEnvironment.ViewController(this.endpoint, this.parameters);
const controller = new SDKEnvironment.ViewController(this.endpoint, this.parameters, this.networkHash);

// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore - We don't want to expose this method to the user, but we
Expand Down
53 changes: 44 additions & 9 deletions packages/@magic-sdk/provider/src/core/view-controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ import { getItem, setItem } from '../util/storage';
import { createJwt } from '../util/web-crypto';
import { SDKEnvironment } from './sdk-environment';
import { createModalNotReadyError } from './sdk-exceptions';
import {
clearDeviceShares,
encryptAndPersistDeviceShare,
getDecryptedDeviceShare,
} from '../util/device-share-web-crypto';

interface RemoveEventListenerFunction {
(): void;
Expand All @@ -21,6 +26,14 @@ interface StandardizedResponse {
response?: JsonRpcResponse;
}

interface StandardizedMagicRequest {
msgType: string;
payload: JsonRpcRequestPayload<any> | JsonRpcRequestPayload<any>[];
jwt?: string;
rt?: string;
deviceShare?: string;
}

/**
* Get the originating payload from a batch request using the specified `id`.
*/
Expand Down Expand Up @@ -56,7 +69,11 @@ function standardizeResponse(
return {};
}

async function createMagicRequest(msgType: string, payload: JsonRpcRequestPayload | JsonRpcRequestPayload[]) {
async function createMagicRequest(
msgType: string,
payload: JsonRpcRequestPayload | JsonRpcRequestPayload[],
networkHash: string,
) {
const rt = await getItem<string>('rt');
let jwt;

Expand All @@ -69,15 +86,22 @@ async function createMagicRequest(msgType: string, payload: JsonRpcRequestPayloa
}
}

if (!jwt) {
return { msgType, payload };
const request: StandardizedMagicRequest = { msgType, payload };

if (jwt) {
request.jwt = jwt;
}
if (jwt && rt) {
request.rt = rt;
}

if (!rt) {
return { msgType, payload, jwt };
// Grab the device share if it exists for the network
const decryptedDeviceShare = await getDecryptedDeviceShare(networkHash);
if (decryptedDeviceShare) {
request.deviceShare = decryptedDeviceShare;
}

return { msgType, payload, jwt, rt };
return request;
}

async function persistMagicEventRefreshToken(event: MagicMessageEvent) {
Expand All @@ -99,8 +123,14 @@ export abstract class ViewController {
* @param endpoint - The URL for the relevant iframe context.
* @param parameters - The unique, encoded query parameters for the
* relevant iframe context.
* @param networkHash - The hash of the network that this sdk instance is connected to
* for multi-chain scenarios
*/
constructor(protected readonly endpoint: string, protected readonly parameters: string) {
constructor(
protected readonly endpoint: string,
protected readonly parameters: string,
protected readonly networkHash: string,
) {
this.checkIsReadyForRequest = this.waitForReady();
this.listen();
}
Expand Down Expand Up @@ -141,7 +171,7 @@ export abstract class ViewController {

const batchData: JsonRpcResponse[] = [];
const batchIds = Array.isArray(payload) ? payload.map((p) => p.id) : [];
const msg = await createMagicRequest(`${msgType}-${this.parameters}`, payload);
const msg = await createMagicRequest(`${msgType}-${this.parameters}`, payload, this.networkHash);

await this._post(msg);

Expand All @@ -151,7 +181,12 @@ export abstract class ViewController {
const acknowledgeResponse = (removeEventListener: RemoveEventListenerFunction) => (event: MagicMessageEvent) => {
const { id, response } = standardizeResponse(payload, event);
persistMagicEventRefreshToken(event);

if (response?.payload.error?.message === 'User denied account access.') {
clearDeviceShares();
} else if (event.data.deviceShare) {
const { deviceShare } = event.data;
encryptAndPersistDeviceShare(deviceShare, this.networkHash);
}
if (id && response && Array.isArray(payload) && batchIds.includes(id)) {
batchData.push(response);

Expand Down
2 changes: 2 additions & 0 deletions packages/@magic-sdk/provider/src/modules/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { BaseModule } from './base-module';
import { createJsonRpcRequestPayload } from '../core/json-rpc';
import { createDeprecationWarning } from '../core/sdk-exceptions';
import { ProductConsolidationMethodRemovalVersions } from './auth';
import { clearDeviceShares } from '../util/device-share-web-crypto';

export type UpdateEmailEvents = {
'email-sent': () => void;
Expand Down Expand Up @@ -54,6 +55,7 @@ export class UserModule extends BaseModule {

public logout() {
removeItem(this.localForageKey);
clearDeviceShares();
const requestPayload = createJsonRpcRequestPayload(
this.sdk.testMode ? MagicPayloadMethod.LogoutTestMode : MagicPayloadMethod.Logout,
);
Expand Down
2 changes: 2 additions & 0 deletions packages/@magic-sdk/provider/src/modules/wallet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { createDeprecationWarning } from '../core/sdk-exceptions';
import { setItem, getItem, removeItem } from '../util/storage';
import { ProductConsolidationMethodRemovalVersions } from './auth';
import { createPromiEvent } from '../util';
import { clearDeviceShares } from '../util/device-share-web-crypto';

export type ConnectWithUiEvents = {
'id-token-created': (params: { idToken: string }) => void;
Expand Down Expand Up @@ -111,6 +112,7 @@ export class WalletModule extends BaseModule {
useInstead: 'user.logout()',
}).log();
removeItem(this.localForageKey);
clearDeviceShares();
const requestPayload = createJsonRpcRequestPayload(MagicPayloadMethod.Disconnect);
return this.request<boolean>(requestPayload);
}
Expand Down
120 changes: 120 additions & 0 deletions packages/@magic-sdk/provider/src/util/device-share-web-crypto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import { getItem, iterate, removeItem, setItem } from './storage';
import { isWebCryptoSupported } from './web-crypto';

export const DEVICE_SHARE_KEY = 'ds';
export const ENCRYPTION_KEY_KEY = 'ek';
export const INITIALIZATION_VECTOR_KEY = 'iv';

const ALGO_NAME = 'AES-GCM'; // for encryption
const ALGO_LENGTH = 256;

export async function clearDeviceShares() {
const keysToRemove: string[] = [];
await iterate((value, key, iterationNumber) => {
if (key.startsWith(`${DEVICE_SHARE_KEY}_`)) {
keysToRemove.push(key);
}
});
for (const key of keysToRemove) {
// eslint-disable-next-line no-await-in-loop
await removeItem(key);
}
}

function arrayBufferToBase64(buffer: ArrayBuffer) {
let binary = '';
const bytes = new Uint8Array(buffer);
const len = bytes.byteLength;
for (let i = 0; i < len; i++) {
binary += String.fromCharCode(bytes[i]);
}
return window.btoa(binary);
}

export function base64ToArrayBuffer(base64: string) {
const binaryString = window.atob(base64);
const len = binaryString.length;
const bytes = new Uint8Array(len);
for (let i = 0; i < len; i++) {
bytes[i] = binaryString.charCodeAt(i);
}
return bytes.buffer;
}

async function getOrCreateInitVector() {
if (!isWebCryptoSupported()) {
console.info('webcrypto is not supported');
return undefined;
}
const { crypto } = window;
const existingIv = (await getItem(INITIALIZATION_VECTOR_KEY)) as Uint8Array;
if (existingIv) {
return existingIv;
}

const iv = crypto.getRandomValues(new Uint8Array(12)); // 12 bytes for AES-GCM
return iv;
}

async function getOrCreateEncryptionKey() {
if (!isWebCryptoSupported()) {
console.info('webcrypto is not supported');
return undefined;
}
const { subtle } = window.crypto;
const existingKey = (await getItem(ENCRYPTION_KEY_KEY)) as CryptoKey;
if (existingKey) {
return existingKey;
}

const key = await subtle.generateKey(
{ name: ALGO_NAME, length: ALGO_LENGTH },
false, // non-extractable
['encrypt', 'decrypt'],
);
return key;
}

export async function encryptAndPersistDeviceShare(deviceShareBase64: string, networkHash: string): Promise<void> {
const iv = await getOrCreateInitVector();
const encryptionKey = await getOrCreateEncryptionKey();

if (!iv || !encryptionKey || !deviceShareBase64) {
return;
}
const decodedDeviceShare = base64ToArrayBuffer(deviceShareBase64);

const { subtle } = window.crypto;

const encryptedData = await subtle.encrypt(
{
name: ALGO_NAME,
iv,
},
encryptionKey,
decodedDeviceShare,
);

// The encrypted device share we store is a base64 encoded string representation
// of the magic kms encrypted client share encrypted with webcrypto
const encryptedDeviceShare = arrayBufferToBase64(encryptedData);

await setItem(`${DEVICE_SHARE_KEY}_${networkHash}`, encryptedDeviceShare);
await setItem(ENCRYPTION_KEY_KEY, encryptionKey);
await setItem(INITIALIZATION_VECTOR_KEY, iv);
}

export async function getDecryptedDeviceShare(networkHash: string): Promise<string | undefined> {
const encryptedDeviceShare = await getItem<string>(`${DEVICE_SHARE_KEY}_${networkHash}`);
const iv = (await getItem(INITIALIZATION_VECTOR_KEY)) as Uint8Array; // use existing encryption key and initialization vector
const encryptionKey = (await getItem(ENCRYPTION_KEY_KEY)) as CryptoKey;

if (!iv || !encryptedDeviceShare || !encryptionKey || !isWebCryptoSupported()) {
return undefined;
}

const { subtle } = window.crypto;
const ab = await subtle.decrypt({ name: ALGO_NAME, iv }, encryptionKey, base64ToArrayBuffer(encryptedDeviceShare));

return arrayBufferToBase64(ab);
}
14 changes: 7 additions & 7 deletions packages/@magic-sdk/provider/src/util/web-crypto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ const EC_GEN_PARAMS: EcKeyGenParams = {
namedCurve: ALGO_CURVE,
};

export function isWebCryptoSupported() {
const hasCrypto = typeof window !== 'undefined' && !!(window.crypto as any);
const hasSubtleCrypto = hasCrypto && !!(window.crypto.subtle as any);

return hasCrypto && hasSubtleCrypto;
}

export function clearKeys() {
removeItem(STORE_KEY_PUBLIC_JWK);
removeItem(STORE_KEY_PRIVATE_KEY);
Expand Down Expand Up @@ -87,13 +94,6 @@ async function generateWCKP() {
await setItem(STORE_KEY_PUBLIC_JWK, jwkPublicKey);
}

function isWebCryptoSupported() {
const hasCrypto = typeof window !== 'undefined' && !!(window.crypto as any);
const hasSubtleCrypto = hasCrypto && !!(window.crypto.subtle as any);

return hasCrypto && hasSubtleCrypto;
}

function strToUrlBase64(str: string) {
return binToUrlBase64(utf8ToBinaryString(str));
}
Expand Down
1 change: 1 addition & 0 deletions packages/@magic-sdk/provider/test/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ export const MAGIC_RELAYER_FULL_URL = 'https://auth.magic.link';
export const TEST_API_KEY = 'pk_test_123';
export const LIVE_API_KEY = 'pk_live_123';
export const ENCODED_QUERY_PARAMS = 'testqueryparams';
export const TEST_NETWORK_HASH = 'eth_mainnet';
export const MSG_TYPES = (parameters = ENCODED_QUERY_PARAMS) => ({
MAGIC_HANDLE_RESPONSE: `MAGIC_HANDLE_RESPONSE-${parameters}`,
MAGIC_OVERLAY_READY: `MAGIC_OVERLAY_READY-${parameters}`,
Expand Down
4 changes: 2 additions & 2 deletions packages/@magic-sdk/provider/test/factories.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import * as memoryDriver from 'localforage-driver-memory';
import localForage from 'localforage';
import { MAGIC_RELAYER_FULL_URL, ENCODED_QUERY_PARAMS, TEST_API_KEY } from './constants';
import { MAGIC_RELAYER_FULL_URL, ENCODED_QUERY_PARAMS, TEST_API_KEY, TEST_NETWORK_HASH } from './constants';
import { ViewController } from '../src';
import type { SDKEnvironment } from '../src/core/sdk-environment';

Expand All @@ -27,7 +27,7 @@ export class TestViewController extends ViewController {
}

export function createViewController(endpoint = MAGIC_RELAYER_FULL_URL) {
const viewController = new TestViewController(endpoint, ENCODED_QUERY_PARAMS);
const viewController = new TestViewController(endpoint, ENCODED_QUERY_PARAMS, TEST_NETWORK_HASH);
viewController.init();
return viewController;
}
Expand Down
Loading

0 comments on commit 82f1080

Please sign in to comment.