Skip to content

Commit

Permalink
feat(connect): try to reuse initialized session
Browse files Browse the repository at this point in the history
DeviceState contains all the information to call `Initialize` message
  • Loading branch information
szymonlesisz authored and mroz22 committed Oct 21, 2024
1 parent bdac1d3 commit 422476c
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 25 deletions.
37 changes: 36 additions & 1 deletion packages/connect/e2e/tests/device/keepSession.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import TrezorConnect from '../../../src';
import TrezorConnect, { StaticSessionId } from '../../../src';

import { getController, setup, conditionalTest, initTrezorConnect } from '../../common.setup';

Expand All @@ -9,6 +9,7 @@ describe('keepSession common param', () => {
await TrezorConnect.dispose();
await setup(controller, {
mnemonic: 'mnemonic_all',
passphrase_protection: true,
});
await initTrezorConnect(controller);
});
Expand All @@ -19,6 +20,10 @@ describe('keepSession common param', () => {
});

conditionalTest(['1', '<2.3.2'], 'keepSession with changing useCardanoDerivation', async () => {
TrezorConnect.on('ui-request_passphrase', () => {
TrezorConnect.uiResponse({ type: 'ui-receive_passphrase', payload: { value: 'a' } });
});

const noDerivation = await TrezorConnect.getAccountDescriptor({
coin: 'ada',
path: "m/1852'/1815'/0'/0/0",
Expand All @@ -38,5 +43,35 @@ describe('keepSession common param', () => {
});
if (!enableDerivation.success) throw new Error(enableDerivation.payload.error);
expect(enableDerivation.payload.descriptor).toBeDefined();

const { device } = enableDerivation;
if (!device || !device.state) throw new Error('Device not found');

// change device instance to simulate app reload
// passphrase request should not be called
TrezorConnect.removeAllListeners('ui-request_passphrase');
// modify instance in staticSessionId
const staticSessionId = device.state.staticSessionId?.replace(
':0',
':1',
) as StaticSessionId;
const keepCardanoDerivation = await TrezorConnect.getAccountDescriptor({
coin: 'ada',
path: "m/1852'/1815'/0'/0/0",
device: {
// change instance to new but use already initialized state
instance: 1,
state: {
...device.state,
staticSessionId,
},
path: device.path,
},
// useCardanoDerivation: true, // NOTE: not required, its in the state
});
if (!keepCardanoDerivation.success) throw new Error(keepCardanoDerivation.payload.error);
expect(keepCardanoDerivation.payload.descriptor).toEqual(
enableDerivation.payload.descriptor,
);
});
});
39 changes: 25 additions & 14 deletions packages/connect/src/core/AbstractMethod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,30 @@ function validateStaticSessionId(input: unknown): StaticSessionId {
'DeviceState: invalid staticSessionId: ' + input,
);
}
// validate expected state from method parameter.
// it could be undefined
function validateDeviceState(input: unknown): DeviceState | undefined {
if (typeof input === 'string') {
return { staticSessionId: validateStaticSessionId(input) };
}
if (input && typeof input === 'object') {
const state: DeviceState = {};
if ('staticSessionId' in input) {
state.staticSessionId = validateStaticSessionId(input.staticSessionId);
}
if ('sessionId' in input && typeof input.sessionId === 'string') {
state.sessionId = input.sessionId;
}
if ('deriveCardano' in input && typeof input.deriveCardano === 'boolean') {
state.deriveCardano = input.deriveCardano;
}

return state;
}

return undefined;
}

export abstract class AbstractMethod<Name extends CallMethodPayload['method'], Params = undefined> {
responseID: number;

Expand Down Expand Up @@ -131,20 +155,7 @@ export abstract class AbstractMethod<Name extends CallMethodPayload['method'], P
this.payload = payload;
this.responseID = message.id || 0;
this.devicePath = payload.device?.path;

// expected state from method parameter.
// it could be undefined
this.deviceState =
// eslint-disable-next-line no-nested-ternary
typeof payload.device?.state === 'string'
? { staticSessionId: validateStaticSessionId(payload.device.state) }
: payload.device?.state?.staticSessionId
? {
staticSessionId: validateStaticSessionId(
payload.device.state.staticSessionId,
),
}
: undefined;
this.deviceState = validateDeviceState(payload.device?.state);
this.hasExpectedDeviceState = payload.device
? Object.prototype.hasOwnProperty.call(payload.device, 'state')
: false;
Expand Down
16 changes: 8 additions & 8 deletions packages/connect/src/device/Device.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,6 @@ export class Device extends TypedEmitter<DeviceEvents> {
firmwareHash: null,
};

private useCardanoDerivation = false;

constructor(transport: Transport, descriptor: Descriptor) {
super();

Expand Down Expand Up @@ -366,11 +364,12 @@ export class Device extends TypedEmitter<DeviceEvents> {
await this.releasePromise;
}

const { staticSessionId, deriveCardano } = this.getState() || {};
if (
!this.isUsedHere() ||
this.commands?.disposed ||
!this.getState()?.staticSessionId ||
this.useCardanoDerivation != !!options.useCardanoDerivation
!staticSessionId ||
(!deriveCardano && options.useCardanoDerivation)
) {
// acquire session
await this.acquire();
Expand Down Expand Up @@ -549,19 +548,20 @@ export class Device extends TypedEmitter<DeviceEvents> {
async initialize(useCardanoDerivation: boolean) {
let payload: PROTO.Initialize | undefined;
if (this.features) {
const sessionId = this.getState()?.sessionId;
payload = {};
const { sessionId, deriveCardano } = this.getState() || {};
// If the user has BIP-39 seed, and Initialize(derive_cardano=True) is not sent,
// all Cardano calls will fail because the root secret will not be available.
payload.derive_cardano = useCardanoDerivation;
this.useCardanoDerivation = useCardanoDerivation;
payload = {
derive_cardano: deriveCardano || useCardanoDerivation,
};
if (sessionId) {
payload.session_id = sessionId;
}
}

const { message } = await this.getCommands().typedCall('Initialize', 'Features', payload);
this._updateFeatures(message);
this.setState({ deriveCardano: payload?.derive_cardano });
}

initStorage(storage: IStateStorage) {
Expand Down
1 change: 1 addition & 0 deletions packages/connect/src/types/device.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export type DeviceState = {
sessionId?: string; // dynamic value: Features.session_id
// ${first testnet address}@${device.features.device_id}:${device.instance}
staticSessionId?: StaticSessionId;
deriveCardano?: boolean;
};

// NOTE: unavailableCapabilities is an object with information what is NOT supported by this device.
Expand Down
4 changes: 2 additions & 2 deletions packages/connect/src/types/params.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import { ErrorCode } from '../constants/errors';

export interface DeviceIdentity {
path?: string;
state?: string | DeviceState;
state?: DeviceState;
instance?: number;
}

export interface CommonParams {
device?: DeviceIdentity;
device?: DeviceIdentity & { state?: DeviceState | string }; // Note: state as string should be removed https://github.com/trezor/trezor-suite/issues/12710
useEmptyPassphrase?: boolean;
useEventListener?: boolean; // this param is set automatically in factory
allowSeedlessDevice?: boolean;
Expand Down

0 comments on commit 422476c

Please sign in to comment.