diff --git a/programs/drift/src/instructions/keeper.rs b/programs/drift/src/instructions/keeper.rs index 8685c6380..16fd76375 100644 --- a/programs/drift/src/instructions/keeper.rs +++ b/programs/drift/src/instructions/keeper.rs @@ -61,11 +61,11 @@ use crate::state::user::{ MarginMode, MarketType, OrderStatus, OrderTriggerCondition, OrderType, User, UserStats, }; use crate::state::user_map::{load_user_map, load_user_maps, UserMap, UserStatsMap}; -use crate::validation::sig_verification::{extract_ed25519_ix_signature, verify_ed25519_msg}; +use crate::validation::sig_verification::verify_ed25519_msg; use crate::validation::user::{validate_user_deletion, validate_user_is_idle}; use crate::{ - controller, digest_struct, digest_struct_hex, load, math, print_error, safe_decrement, - OracleSource, GOV_SPOT_MARKET_INDEX, MARGIN_PRECISION, + controller, load, math, print_error, safe_decrement, OracleSource, GOV_SPOT_MARKET_INDEX, + MARGIN_PRECISION, }; use crate::{load_mut, QUOTE_PRECISION_U64}; use crate::{validate, QUOTE_PRECISION_I128}; @@ -603,9 +603,6 @@ pub fn handle_place_swift_taker_order<'c: 'info, 'info>( ctx: Context<'_, '_, 'c, 'info, PlaceSwiftTakerOrder<'info>>, swift_order_params_message_bytes: Vec, ) -> Result<()> { - let taker_order_params_message: SwiftOrderParamsMessage = - SwiftOrderParamsMessage::deserialize(&mut &swift_order_params_message_bytes[..]).unwrap(); - let state = &ctx.accounts.state; // TODO: generalize to support multiple market types @@ -629,7 +626,7 @@ pub fn handle_place_swift_taker_order<'c: 'info, 'info>( taker_key, &mut taker, &mut swift_taker, - taker_order_params_message, + swift_order_params_message_bytes, &ctx.accounts.ix_sysvar.to_account_info(), &perp_market_map, &spot_market_map, @@ -643,7 +640,7 @@ pub fn place_swift_taker_order<'c: 'info, 'info>( taker_key: Pubkey, taker: &mut RefMut, swift_account: &mut SwiftUserOrdersZeroCopyMut, - taker_order_params_message: SwiftOrderParamsMessage, + taker_order_params_message_bytes: Vec, ix_sysvar: &AccountInfo<'info>, perp_market_map: &PerpMarketMap, spot_market_map: &SpotMarketMap, @@ -664,15 +661,19 @@ pub fn place_swift_taker_order<'c: 'info, 'info>( )?; // Verify data from verify ix - let digest_hex = digest_struct_hex!(taker_order_params_message); let ix: Instruction = load_instruction_at_checked(ix_idx as usize - 1, ix_sysvar)?; - verify_ed25519_msg( + let verified_message_and_signature = verify_ed25519_msg( &ix, + ix_idx, &taker.authority.to_bytes(), - arrayref::array_ref!(digest_hex, 0, 64), + &taker_order_params_message_bytes[..], + 12, )?; - let signature = extract_ed25519_ix_signature(&ix.data)?; + let taker_order_params_message: SwiftOrderParamsMessage = + verified_message_and_signature.swift_order_params_message; + + let signature = verified_message_and_signature.signature; let clock = &Clock::get()?; // First order must be a taker order diff --git a/programs/drift/src/validation/sig_verification.rs b/programs/drift/src/validation/sig_verification.rs index 7a840d572..574dccad4 100644 --- a/programs/drift/src/validation/sig_verification.rs +++ b/programs/drift/src/validation/sig_verification.rs @@ -1,9 +1,45 @@ use crate::error::ErrorCode; +use crate::state::order_params::SwiftOrderParamsMessage; use anchor_lang::prelude::*; +use bytemuck::try_cast_slice; +use bytemuck::{Pod, Zeroable}; +use byteorder::ByteOrder; +use byteorder::LE; use solana_program::ed25519_program::ID as ED25519_ID; use solana_program::instruction::Instruction; use std::convert::TryInto; +const ED25519_PROGRAM_INPUT_HEADER_LEN: usize = 2; + +const SIGNATURE_LEN: u16 = 64; +const PUBKEY_LEN: u16 = 32; +const MESSAGE_SIZE_LEN: u16 = 2; + +/// Part of the inputs to the built-in `ed25519_program` on Solana that represents a single +/// signature verification request. +/// +/// `ed25519_program` does not receive the signature data directly. Instead, it receives +/// these fields that indicate the location of the signature data within data of other +/// instructions within the same transaction. +#[derive(Debug, Clone, Copy, Zeroable, Pod)] +#[repr(C)] +pub struct Ed25519SignatureOffsets { + /// Offset to the ed25519 signature within the instruction data. + pub signature_offset: u16, + /// Index of the instruction that contains the signature. + pub signature_instruction_index: u16, + /// Offset to the public key within the instruction data. + pub public_key_offset: u16, + /// Index of the instruction that contains the public key. + pub public_key_instruction_index: u16, + /// Offset to the signed payload within the instruction data. + pub message_data_offset: u16, + // Size of the signed payload. + pub message_data_size: u16, + /// Index of the instruction that contains the signed payload. + pub message_instruction_index: u16, +} + /// Verify Ed25519Program instruction fields pub fn verify_ed25519_ix(ix: &Instruction, pubkey: &[u8], msg: &[u8], sig: &[u8]) -> Result<()> { if ix.program_id != ED25519_ID || // The program id we expect @@ -79,6 +115,11 @@ fn check_ed25519_data(data: &[u8], pubkey: &[u8], msg: &[u8], sig: &[u8]) -> Res Ok(()) } +pub struct VerifiedMessage { + pub swift_order_params_message: SwiftOrderParamsMessage, + pub signature: [u8; 64], +} + /// Check Ed25519Program instruction data verifies the given msg /// /// `ix` an Ed25519Program instruction [see](https://github.com/solana-labs/solana/blob/master/sdk/src/ed25519_instruction.rs)) @@ -87,11 +128,13 @@ fn check_ed25519_data(data: &[u8], pubkey: &[u8], msg: &[u8], sig: &[u8]) -> Res /// /// `pubkey` expected pubkey of the signer /// -pub fn verify_ed25519_msg( +pub fn verify_ed25519_msg( ix: &Instruction, - pubkey: &[u8; 32], - msg: &[u8; N], -) -> Result<()> { + current_ix_index: u16, + signer: &[u8; 32], + msg: &[u8], + message_offset: u16, +) -> Result { if ix.program_id != ED25519_ID || ix.accounts.len() != 0 { msg!("Invalid Ix: program ID: {:?}", ix.program_id); msg!("Invalid Ix: accounts: {:?}", ix.accounts.len()); @@ -100,72 +143,143 @@ pub fn verify_ed25519_msg( let ix_data = &ix.data; // According to this layout used by the Ed25519Program] - if ix_data.len() <= 112 { + if ix_data.len() < 2 { msg!( - "Invalid Ix: data: {:?}, len: {:?}", + "Invalid Ix, should be header len = 2. data: {:?}", ix.data.len(), - 16 + 64 + 32 + N ); - return Err(ErrorCode::SigVerificationFailed.into()); + return Err(SignatureVerificationError::InvalidEd25519InstructionDataLength.into()); } - // Check the ed25519 verify ix header is sound - let num_signatures = ix_data[0]; - let padding = ix_data[1]; - let signature_offset = u16::from_le_bytes(ix_data[2..=3].try_into().unwrap()); - let signature_instruction_index = u16::from_le_bytes(ix_data[4..=5].try_into().unwrap()); - let public_key_offset = u16::from_le_bytes(ix_data[6..=7].try_into().unwrap()); - let public_key_instruction_index = u16::from_le_bytes(ix_data[8..=9].try_into().unwrap()); - let message_data_offset = u16::from_le_bytes(ix_data[10..=11].try_into().unwrap()); - let message_data_size = u16::from_le_bytes(ix_data[12..=13].try_into().unwrap()); - let message_instruction_index = u16::from_le_bytes(ix_data[14..=15].try_into().unwrap()); + // Parse the ix data into the offsets + let args: &[Ed25519SignatureOffsets] = + try_cast_slice(&ix_data[ED25519_PROGRAM_INPUT_HEADER_LEN..]).map_err(|_| { + msg!("Invalid Ix: failed to cast slice"); + ErrorCode::SigVerificationFailed + })?; - // Expected values - let exp_public_key_offset: u16 = 16; - let exp_signature_offset: u16 = exp_public_key_offset + 32_u16; - let exp_message_data_offset: u16 = exp_signature_offset + 64_u16; - let exp_num_signatures: u8 = 1; - - // Header - if num_signatures != exp_num_signatures - || padding != 0 - || signature_offset != exp_signature_offset - || signature_instruction_index != u16::MAX - || public_key_offset != exp_public_key_offset - || public_key_instruction_index != u16::MAX - || message_data_offset != exp_message_data_offset - || message_instruction_index != u16::MAX - { + let offsets = &args[0]; + if offsets.signature_offset != message_offset { + msg!( + "Invalid Ix: signature offset: {:?}", + offsets.signature_offset + ); return Err(ErrorCode::SigVerificationFailed.into()); } - // verify data is for digest and pubkey - let ix_msg_data = &ix_data[112..]; - if ix_msg_data != msg || message_data_size != N as u16 { + let expected_public_key_offset = message_offset + .checked_add(SIGNATURE_LEN) + .ok_or(ErrorCode::SigVerificationFailed)?; + if offsets.public_key_offset != expected_public_key_offset { + msg!( + "Invalid Ix: public key offset: {:?}, expected: {:?}", + offsets.public_key_offset, + expected_public_key_offset + ); return Err(ErrorCode::SigVerificationFailed.into()); } - let ix_pubkey = &ix_data[16..16 + 32]; - if ix_pubkey != pubkey { - msg!("Invalid Ix: pubkey: {:?}", ix_pubkey); - msg!("Invalid Ix: expected pubkey: {:?}", pubkey); - return Err(ErrorCode::SigVerificationFailed.into()); + let expected_message_size_offset = expected_public_key_offset + .checked_add(PUBKEY_LEN) + .ok_or(ErrorCode::SigVerificationFailed)?; + + let expected_message_data_offset = expected_message_size_offset + .checked_add(MESSAGE_SIZE_LEN) + .ok_or(SignatureVerificationError::MessageOffsetOverflow)?; + if offsets.message_data_offset != expected_message_data_offset { + return Err(SignatureVerificationError::InvalidMessageOffset.into()); } - Ok(()) -} + let expected_message_size: u16 = { + let start = usize::from( + expected_message_size_offset + .checked_sub(message_offset) + .unwrap(), + ); + let end = usize::from( + expected_message_data_offset + .checked_sub(message_offset) + .unwrap(), + ); + LE::read_u16(&msg[start..end]) + }; + if offsets.message_data_size != expected_message_size { + return Err(SignatureVerificationError::InvalidMessageDataSize.into()); + } + if offsets.signature_instruction_index != current_ix_index + || offsets.public_key_instruction_index != current_ix_index + || offsets.message_instruction_index != current_ix_index + { + return Err(SignatureVerificationError::InvalidInstructionIndex.into()); + } -/// Extract pubkey from serialized Ed25519Program instruction data -pub fn extract_ed25519_ix_pubkey(ix_data: &[u8]) -> Result<[u8; 32]> { - match ix_data[16..16 + 32].try_into() { - Ok(raw) => Ok(raw), - Err(_) => Err(ErrorCode::SigVerificationFailed.into()), + let public_key = { + let start = usize::from( + expected_public_key_offset + .checked_sub(message_offset) + .unwrap(), + ); + let end = start + .checked_add(anchor_lang::solana_program::pubkey::PUBKEY_BYTES) + .ok_or(SignatureVerificationError::MessageOffsetOverflow)?; + &msg[start..end] + }; + let mut payload = { + let start = usize::from( + expected_message_data_offset + .checked_sub(message_offset) + .unwrap(), + ); + let end = start + .checked_add(expected_message_size.into()) + .ok_or(SignatureVerificationError::MessageOffsetOverflow)?; + &msg[start..end] + }; + + if public_key != signer { + msg!("Invalid Ix: message signed by: {:?}", public_key); + msg!("Invalid Ix: expected pubkey: {:?}", signer); + return Err(ErrorCode::SigVerificationFailed.into()); } + + let signature = { + let start = usize::from( + offsets + .signature_offset + .checked_sub(message_offset) + .unwrap(), + ); + let end = start + .checked_add(SIGNATURE_LEN.into()) + .ok_or(SignatureVerificationError::InvalidSignatureOffset)?; + &msg[start..end].try_into().unwrap() + }; + + Ok(VerifiedMessage { + swift_order_params_message: SwiftOrderParamsMessage::deserialize(&mut payload).unwrap(), + signature: *signature, + }) } -pub fn extract_ed25519_ix_signature(ix_data: &[u8]) -> Result<[u8; 64]> { - match ix_data[48..48 + 64].try_into() { - Ok(raw) => Ok(raw), - Err(_) => Err(ErrorCode::SigVerificationFailed.into()), - } +#[error_code] +#[derive(PartialEq, Eq)] +pub enum SignatureVerificationError { + #[msg("invalid ed25519 instruction program")] + InvalidEd25519InstructionProgramId, + #[msg("invalid ed25519 instruction data length")] + InvalidEd25519InstructionDataLength, + #[msg("invalid signature index")] + InvalidSignatureIndex, + #[msg("invalid signature offset")] + InvalidSignatureOffset, + #[msg("invalid public key offset")] + InvalidPublicKeyOffset, + #[msg("invalid message offset")] + InvalidMessageOffset, + #[msg("invalid message data size")] + InvalidMessageDataSize, + #[msg("invalid instruction index")] + InvalidInstructionIndex, + #[msg("message offset overflow")] + MessageOffsetOverflow, } diff --git a/sdk/src/driftClient.ts b/sdk/src/driftClient.ts index fa19604d9..17adb35fa 100644 --- a/sdk/src/driftClient.ts +++ b/sdk/src/driftClient.ts @@ -179,7 +179,6 @@ import pythSolanaReceiverIdl from './idl/pyth_solana_receiver.json'; import { asV0Tx, PullFeed } from '@switchboard-xyz/on-demand'; import { gprcDriftClientAccountSubscriber } from './accounts/grpcDriftClientAccountSubscriber'; import nacl from 'tweetnacl'; -import { digest } from './util/digest'; import { Slothash } from './slot/SlothashSubscriber'; import { getOracleId } from './oracles/oracleId'; @@ -5886,9 +5885,7 @@ export class DriftClient { ): Buffer { const takerOrderParamsMessage = this.encodeSwiftOrderParamsMessage(orderParamsMessage); - return this.signMessage( - new TextEncoder().encode(digest(takerOrderParamsMessage).toString('hex')) - ); + return this.signMessage(takerOrderParamsMessage); } public encodeSwiftOrderParamsMessage( @@ -5925,13 +5922,18 @@ export class DriftClient { takerStats: PublicKey; takerUserAccount: UserAccount; }, + precedingIxs: TransactionInstruction[] = [], + overrideIxCount?: number, txParams?: TxParams ): Promise { const ixs = await this.getPlaceSwiftTakerPerpOrderIxs( swiftOrderParamsMessage, swiftOrderParamsSignature, marketIndex, - takerInfo + takerInfo, + undefined, + precedingIxs, + overrideIxCount ); const { txSig } = await this.sendTransaction( await this.buildTransaction(ixs, txParams), @@ -5950,7 +5952,9 @@ export class DriftClient { takerStats: PublicKey; takerUserAccount: UserAccount; }, - authority?: PublicKey + authority?: PublicKey, + precedingIxs: TransactionInstruction[] = [], + overrideIxCount?: number ): Promise { if (!authority && !takerInfo.takerUserAccount) { throw new Error('authority or takerUserAccount must be provided'); @@ -5963,33 +5967,39 @@ export class DriftClient { }); const authorityToUse = authority || takerInfo.takerUserAccount.authority; - const swiftOrderParamsSignatureIx = - Ed25519Program.createInstructionWithPublicKey({ - publicKey: authorityToUse.toBytes(), - signature: Uint8Array.from(swiftOrderParamsSignature), - message: new TextEncoder().encode( - digest(encodedSwiftOrderParamsMessage).toString('hex') - ), - }); + + const messageLengthBuffer = Buffer.alloc(2); + messageLengthBuffer.writeUInt16LE(encodedSwiftOrderParamsMessage.length); + + const swiftIxData = Buffer.concat([ + swiftOrderParamsSignature, + authorityToUse.toBytes(), + messageLengthBuffer, + encodedSwiftOrderParamsMessage, + ]); + + const swiftOrderParamsSignatureIx = createMinimalEd25519VerifyIx( + overrideIxCount || precedingIxs.length + 1, + 12, + swiftIxData, + 0 + ); const placeTakerSwiftPerpOrderIx = - await this.program.instruction.placeSwiftTakerOrder( - encodedSwiftOrderParamsMessage, - { - accounts: { - state: await this.getStatePublicKey(), - user: takerInfo.taker, - userStats: takerInfo.takerStats, - swiftUserOrders: getSwiftUserAccountPublicKey( - this.program.programId, - takerInfo.taker - ), - authority: this.wallet.publicKey, - ixSysvar: SYSVAR_INSTRUCTIONS_PUBKEY, - }, - remainingAccounts, - } - ); + this.program.instruction.placeSwiftTakerOrder(swiftIxData, { + accounts: { + state: await this.getStatePublicKey(), + user: takerInfo.taker, + userStats: takerInfo.takerStats, + swiftUserOrders: getSwiftUserAccountPublicKey( + this.program.programId, + takerInfo.taker + ), + authority: this.wallet.publicKey, + ixSysvar: SYSVAR_INSTRUCTIONS_PUBKEY, + }, + remainingAccounts, + }); return [swiftOrderParamsSignatureIx, placeTakerSwiftPerpOrderIx]; } @@ -6006,7 +6016,9 @@ export class DriftClient { orderParams: OptionalOrderParams, referrerInfo?: ReferrerInfo, txParams?: TxParams, - subAccountId?: number + subAccountId?: number, + precedingIxs: TransactionInstruction[] = [], + overrideIxCount?: number ): Promise { const ixs = await this.getPlaceAndMakeSwiftPerpOrderIxs( encodedSwiftOrderParamsMessage, @@ -6015,7 +6027,9 @@ export class DriftClient { takerInfo, orderParams, referrerInfo, - subAccountId + subAccountId, + precedingIxs, + overrideIxCount ); const { txSig, slot } = await this.sendTransaction( await this.buildTransaction(ixs, txParams), @@ -6038,14 +6052,19 @@ export class DriftClient { }, orderParams: OptionalOrderParams, referrerInfo?: ReferrerInfo, - subAccountId?: number + subAccountId?: number, + precedingIxs: TransactionInstruction[] = [], + overrideIxCount?: number ): Promise { const [swiftOrderSignatureIx, placeTakerSwiftPerpOrderIx] = await this.getPlaceSwiftTakerPerpOrderIxs( encodedSwiftOrderParamsMessage, swiftOrderParamsSignature, orderParams.marketIndex, - takerInfo + takerInfo, + undefined, + precedingIxs, + overrideIxCount ); orderParams = getOrderParams(orderParams, { marketType: MarketType.PERP }); @@ -8952,7 +8971,6 @@ export class DriftClient { preSigned?: boolean ): Promise { const isVersionedTx = this.isVersionedTransaction(tx); - if (isVersionedTx) { return this.txSender.sendVersionedTransaction( tx as VersionedTransaction, diff --git a/sdk/src/idl/drift.json b/sdk/src/idl/drift.json index 2478883af..5054baad3 100644 --- a/sdk/src/idl/drift.json +++ b/sdk/src/idl/drift.json @@ -11749,6 +11749,41 @@ } ] } + }, + { + "name": "SignatureVerificationError", + "type": { + "kind": "enum", + "variants": [ + { + "name": "InvalidEd25519InstructionProgramId" + }, + { + "name": "InvalidEd25519InstructionDataLength" + }, + { + "name": "InvalidSignatureIndex" + }, + { + "name": "InvalidSignatureOffset" + }, + { + "name": "InvalidPublicKeyOffset" + }, + { + "name": "InvalidMessageOffset" + }, + { + "name": "InvalidMessageDataSize" + }, + { + "name": "InvalidInstructionIndex" + }, + { + "name": "MessageOffsetOverflow" + } + ] + } } ], "events": [ diff --git a/sdk/src/util/pythOracleUtils.ts b/sdk/src/util/pythOracleUtils.ts index a25212ce6..f84a54604 100644 --- a/sdk/src/util/pythOracleUtils.ts +++ b/sdk/src/util/pythOracleUtils.ts @@ -102,9 +102,11 @@ const ED25519_INSTRUCTION_LAYOUT = BufferLayout.struct< export function createMinimalEd25519VerifyIx( customInstructionIndex: number, messageOffset: number, - customInstructionData: Uint8Array + customInstructionData: Uint8Array, + magicLen?: number ): TransactionInstruction { - const signatureOffset = messageOffset + MAGIC_LEN; + const signatureOffset = + messageOffset + (magicLen === undefined ? MAGIC_LEN : magicLen); const publicKeyOffset = signatureOffset + SIGNATURE_LEN; const messageDataSizeOffset = publicKeyOffset + PUBKEY_LEN; const messageDataOffset = messageDataSizeOffset + MESSAGE_SIZE_LEN; diff --git a/test-scripts/run-anchor-local-validator-tests.sh b/test-scripts/run-anchor-local-validator-tests.sh index 9b811ed80..e4640c321 100644 --- a/test-scripts/run-anchor-local-validator-tests.sh +++ b/test-scripts/run-anchor-local-validator-tests.sh @@ -6,7 +6,7 @@ fi export ANCHOR_WALLET=~/.config/solana/id.json test_files=( - pythLazer.ts + # pythLazer.ts placeAndMakeSwiftPerp.ts ) diff --git a/test-scripts/single-anchor-test.sh b/test-scripts/single-anchor-test.sh index 62eebc239..00ae718d9 100644 --- a/test-scripts/single-anchor-test.sh +++ b/test-scripts/single-anchor-test.sh @@ -6,7 +6,7 @@ fi export ANCHOR_WALLET=~/.config/solana/id.json -test_files=(pythLazerBankrun.ts) +test_files=(placeAndMakeSwiftPerpBankrun.ts) for test_file in ${test_files[@]}; do ts-mocha -t 300000 ./tests/${test_file} diff --git a/tests/placeAndMakeSwiftPerp.ts b/tests/placeAndMakeSwiftPerp.ts index 1b197c5bb..a3931144b 100644 --- a/tests/placeAndMakeSwiftPerp.ts +++ b/tests/placeAndMakeSwiftPerp.ts @@ -6,7 +6,6 @@ import { Program } from '@coral-xyz/anchor'; import { ComputeBudgetProgram, Keypair, - SendTransactionError, Transaction, TransactionMessage, VersionedTransaction, @@ -28,7 +27,6 @@ import { loadKeypair, getMarketOrderParams, MarketType, - DriftClient, } from '../sdk/src'; import { @@ -102,10 +100,11 @@ describe('place and make swift order', () => { spotMarketIndexes = [0, 1]; oracleInfos = [{ publicKey: solUsd, source: OracleSource.PYTH }]; + const wallet = new Wallet(loadKeypair(process.env.ANCHOR_WALLET)); makerDriftClient = new TestClient({ connection, //@ts-ignore - wallet: new Wallet(loadKeypair(process.env.ANCHOR_WALLET)), + wallet, programID: chProgram.programId, opts: { commitment: 'confirmed', @@ -118,6 +117,7 @@ describe('place and make swift order', () => { type: 'polling', accountLoader: bulkAccountLoader, }, + txVersion: 'legacy', }); await makerDriftClient.initialize(usdcMint.publicKey, true); await makerDriftClient.subscribe(); @@ -244,23 +244,27 @@ describe('place and make swift order', () => { takerOrderParamsMessage ); - let ixs = await makerDriftClient.getPlaceAndMakeSwiftPerpOrderIxs( - takerDriftClient.encodeSwiftOrderParamsMessage(takerOrderParamsMessage), - takerOrderParamsSig, - uuid, - { - taker: await takerDriftClient.getUserAccountPublicKey(), - takerUserAccount: takerDriftClient.getUserAccount(), - takerStats: takerDriftClient.getUserStatsAccountPublicKey(), - }, - makerOrderParams - ); - ixs = [ + const ixs = [ ComputeBudgetProgram.setComputeUnitLimit({ units: 10_000_000, }), - ...ixs, ]; + ixs.push( + ...(await makerDriftClient.getPlaceAndMakeSwiftPerpOrderIxs( + takerDriftClient.encodeSwiftOrderParamsMessage(takerOrderParamsMessage), + takerOrderParamsSig, + uuid, + { + taker: await takerDriftClient.getUserAccountPublicKey(), + takerUserAccount: takerDriftClient.getUserAccount(), + takerStats: takerDriftClient.getUserStatsAccountPublicKey(), + }, + makerOrderParams, + undefined, + undefined, + ixs + )) + ); const message = new TransactionMessage({ instructions: ixs, @@ -288,7 +292,7 @@ describe('place and make swift order', () => { provider, keypair.publicKey ); - const takerDriftClient = new DriftClient({ + const takerDriftClient = new TestClient({ connection, wallet, programID: chProgram.programId, @@ -305,11 +309,13 @@ describe('place and make swift order', () => { accountLoader: bulkAccountLoader, }, }); + await takerDriftClient.subscribe(); await takerDriftClient.initializeUserAccountAndDepositCollateral( usdcAmount, userUSDCAccount.publicKey ); + const takerDriftClientUser = new User({ driftClient: takerDriftClient, userAccountPublicKey: await takerDriftClient.getUserAccountPublicKey(), @@ -329,41 +335,51 @@ describe('place and make swift order', () => { const takerOrderParams = getMarketOrderParams({ marketIndex, direction: PositionDirection.LONG, - baseAssetAmount, + baseAssetAmount: baseAssetAmount.muln(2), price: new BN(34).mul(PRICE_PRECISION), auctionStartPrice: new BN(33).mul(PRICE_PRECISION), auctionEndPrice: new BN(34).mul(PRICE_PRECISION), auctionDuration: 10, userOrderId: 1, postOnly: PostOnlyParams.NONE, + marketType: MarketType.PERP, }); + const uuid = Uint8Array.from(Buffer.from(nanoid(8))); + const takerOrderParamsMessage: SwiftOrderParamsMessage = { + swiftOrderParams: takerOrderParams, + subAccountId: 0, + slot: new BN(await connection.getSlot()), + uuid, + takeProfitOrderParams: null, + stopLossOrderParams: null, + }; await takerDriftClientUser.fetchAccounts(); + const makerOrderParams = getLimitOrderParams({ marketIndex, direction: PositionDirection.SHORT, - baseAssetAmount, + baseAssetAmount: BASE_PRECISION, price: new BN(33).mul(PRICE_PRECISION), userOrderId: 1, postOnly: PostOnlyParams.MUST_POST_ONLY, immediateOrCancel: true, }); - const uuid = Uint8Array.from(Buffer.from(nanoid(8))); - const takerOrderParamsMessage: SwiftOrderParamsMessage = { - swiftOrderParams: takerOrderParams, - takeProfitOrderParams: null, - subAccountId: 0, - slot: new BN(await connection.getSlot()), - uuid, - stopLossOrderParams: null, - }; - const takerOrderParamsSig = makerDriftClient.signSwiftOrderParamsMessage( - takerOrderParamsMessage + const takerOrderParamsMessageEncoded = + takerDriftClient.encodeSwiftOrderParamsMessage(takerOrderParamsMessage); + const takerOrderParamsSig = takerDriftClient.signMessage( + takerOrderParamsMessageEncoded, + makerDriftClient.wallet.payer ); - try { - let ixs = await makerDriftClient.getPlaceAndMakeSwiftPerpOrderIxs( + const ixs = [ + ComputeBudgetProgram.setComputeUnitLimit({ + units: 10_000_000, + }), + ]; + ixs.push( + ...(await makerDriftClient.getPlaceAndMakeSwiftPerpOrderIxs( takerDriftClient.encodeSwiftOrderParamsMessage(takerOrderParamsMessage), takerOrderParamsSig, uuid, @@ -372,173 +388,22 @@ describe('place and make swift order', () => { takerUserAccount: takerDriftClient.getUserAccount(), takerStats: takerDriftClient.getUserStatsAccountPublicKey(), }, - makerOrderParams - ); - ixs = [ - ComputeBudgetProgram.setComputeUnitLimit({ - units: 10_000_000, - }), - ...ixs, - ]; - await makerDriftClient.sendTransaction(new Transaction().add(...ixs)); - assert.fail('Should have failed'); - } catch (error) { - assert.equal( - error.transactionMessage, - 'Transaction precompile verification failure InvalidAccountIndex' - ); - } - }); - - it('should fail if diff signed message to verify ixs vs drift ixs', async () => { - const keypair = new Keypair(); - await provider.connection.requestAirdrop(keypair.publicKey, 10 ** 9); - await sleep(1000); - const wallet = new Wallet(keypair); - const userUSDCAccount = await mockUserUSDCAccount( - usdcMint, - usdcAmount, - provider, - keypair.publicKey - ); - const takerDriftClient = new TestClient({ - connection, - wallet, - programID: chProgram.programId, - opts: { - commitment: 'confirmed', - }, - activeSubAccountId: 0, - perpMarketIndexes: marketIndexes, - spotMarketIndexes: spotMarketIndexes, - oracleInfos, - userStats: true, - accountSubscription: { - type: 'polling', - accountLoader: bulkAccountLoader, - }, - }); - - await takerDriftClient.subscribe(); - await takerDriftClient.initializeUserAccountAndDepositCollateral( - usdcAmount, - userUSDCAccount.publicKey - ); - - const takerDriftClientUser = new User({ - driftClient: takerDriftClient, - userAccountPublicKey: await takerDriftClient.getUserAccountPublicKey(), - accountSubscription: { - type: 'polling', - accountLoader: bulkAccountLoader, - }, - }); - await takerDriftClientUser.subscribe(); - await takerDriftClient.initializeSwiftUserOrders( - takerDriftClientUser.userAccountPublicKey, - 32 - ); - - const marketIndex = 0; - const baseAssetAmount = BASE_PRECISION; - const takerOrderParams = getMarketOrderParams({ - marketIndex, - direction: PositionDirection.LONG, - baseAssetAmount: baseAssetAmount.muln(2), - price: new BN(34).mul(PRICE_PRECISION), - auctionStartPrice: new BN(33).mul(PRICE_PRECISION), - auctionEndPrice: new BN(34).mul(PRICE_PRECISION), - auctionDuration: 10, - userOrderId: 1, - postOnly: PostOnlyParams.NONE, - marketType: MarketType.PERP, - }); - const takerOrderParamsMessage: SwiftOrderParamsMessage = { - swiftOrderParams: takerOrderParams, - subAccountId: 0, - uuid: Uint8Array.from(Buffer.from(nanoid(8))), - slot: new BN(await connection.getSlot()), - takeProfitOrderParams: null, - stopLossOrderParams: null, - }; - - await takerDriftClientUser.fetchAccounts(); - - // Auth for legit order - const takerOrderParamsSig = takerDriftClient.signSwiftOrderParamsMessage( - takerOrderParamsMessage - ); - - // Auth for non-legit order - const takerOrderParamsMessage2: SwiftOrderParamsMessage = { - swiftOrderParams: Object.assign({}, takerOrderParams, { - direction: PositionDirection.SHORT, - auctionStartPrice: new BN(34).mul(PRICE_PRECISION), - auctionEndPrice: new BN(33).mul(PRICE_PRECISION), - price: new BN(33).mul(PRICE_PRECISION), - }), - subAccountId: 0, - takeProfitOrderParams: null, - stopLossOrderParams: null, - uuid: Uint8Array.from(Buffer.from(nanoid(8))), - slot: new BN(await connection.getSlot()), - }; - - const takerOrderParamsSig2 = takerDriftClient.signSwiftOrderParamsMessage( - takerOrderParamsMessage2 - ); - - const computeBudgetIx = ComputeBudgetProgram.setComputeUnitLimit({ - units: 10_000_000, - }); - - const ixs = await makerDriftClient.getPlaceSwiftTakerPerpOrderIxs( - takerDriftClient.encodeSwiftOrderParamsMessage(takerOrderParamsMessage), - takerOrderParamsSig, - 0, - { - taker: await takerDriftClient.getUserAccountPublicKey(), - takerUserAccount: takerDriftClient.getUserAccount(), - takerStats: takerDriftClient.getUserStatsAccountPublicKey(), - } + makerOrderParams, + undefined, + undefined, + ixs + )) ); - const ixsForOrder2 = await makerDriftClient.getPlaceSwiftTakerPerpOrderIxs( - takerDriftClient.encodeSwiftOrderParamsMessage(takerOrderParamsMessage2), - takerOrderParamsSig2, - 0, - { - taker: await takerDriftClient.getUserAccountPublicKey(), - takerUserAccount: takerDriftClient.getUserAccount(), - takerStats: takerDriftClient.getUserStatsAccountPublicKey(), - } - ); - - const badOrderTx = new Transaction(); - badOrderTx.add(...[computeBudgetIx, ixs[0], ixsForOrder2[1]]); try { - await makerDriftClient.sendTransaction(badOrderTx); - assert.fail('Should have failed'); + const normalTx = new Transaction(); + normalTx.add(...ixs); + await makerDriftClient.sendTransaction(normalTx); + assert.fail('should have thrown'); } catch (error) { - console.log((error as SendTransactionError).logs); - assert( - (error as SendTransactionError).logs.some((log) => - log.includes('SigVerificationFailed') - ) - ); - } - - const badSwiftTx = new Transaction(); - badSwiftTx.add(...[computeBudgetIx, ixsForOrder2[0], ixs[1]]); - try { - await makerDriftClient.sendTransaction(badSwiftTx); - assert.fail('Should have failed'); - } catch (error) { - console.log((error as SendTransactionError).logs); - assert( - (error as SendTransactionError).logs.some((log) => - log.includes('SigVerificationFailed') - ) + assert.equal( + error.transactionMessage, + 'Transaction precompile verification failure InvalidAccountIndex' ); } }); diff --git a/tests/placeAndMakeSwiftPerpBankrun.ts b/tests/placeAndMakeSwiftPerpBankrun.ts index 0ba963308..40907fde0 100644 --- a/tests/placeAndMakeSwiftPerpBankrun.ts +++ b/tests/placeAndMakeSwiftPerpBankrun.ts @@ -273,7 +273,12 @@ describe('place and make swift order', () => { takerUserAccount: takerDriftClient.getUserAccount(), takerStats: takerDriftClient.getUserStatsAccountPublicKey(), }, - makerOrderParams + makerOrderParams, + undefined, + undefined, + undefined, + undefined, + 2 ); const makerPosition = makerDriftClient.getUser().getPerpPosition(0); @@ -302,7 +307,12 @@ describe('place and make swift order', () => { takerUserAccount: takerDriftClient.getUserAccount(), takerStats: takerDriftClient.getUserStatsAccountPublicKey(), }, - makerOrderParams + makerOrderParams, + undefined, + undefined, + undefined, + undefined, + 2 ); const takerPositionAfter = takerDriftClient.getUser().getPerpPosition(0); @@ -431,7 +441,9 @@ describe('place and make swift order', () => { taker: await takerDriftClient.getUserAccountPublicKey(), takerUserAccount: takerDriftClient.getUserAccount(), takerStats: takerDriftClient.getUserStatsAccountPublicKey(), - } + }, + undefined, + pythLazerCrankIxs ); const swiftOrder: Order = { @@ -596,7 +608,11 @@ describe('place and make swift order', () => { takerUserAccount: takerDriftClient.getUserAccount(), takerStats: takerDriftClient.getUserStatsAccountPublicKey(), }, - makerOrderParams + makerOrderParams, + undefined, + undefined, + undefined, + 2 ); /* @@ -703,7 +719,12 @@ describe('place and make swift order', () => { takerUserAccount: takerDriftClient.getUserAccount(), takerStats: takerDriftClient.getUserStatsAccountPublicKey(), }, - makerOrderParams + makerOrderParams, + undefined, + undefined, + undefined, + undefined, + 2 ); } catch (e) { assert(e); @@ -768,7 +789,9 @@ describe('place and make swift order', () => { taker: await takerDriftClient.getUserAccountPublicKey(), takerUserAccount: takerDriftClient.getUserAccount(), takerStats: takerDriftClient.getUserStatsAccountPublicKey(), - } + }, + undefined, + 2 ); assert(takerDriftClient.getOrderByUserId(1) !== undefined); @@ -791,7 +814,12 @@ describe('place and make swift order', () => { takerUserAccount: takerDriftClient.getUserAccount(), takerStats: takerDriftClient.getUserStatsAccountPublicKey(), }, - makerOrderParams + makerOrderParams, + undefined, + undefined, + undefined, + undefined, + 2 ); const takerPosition = takerDriftClient.getUser().getPerpPosition(0); @@ -851,7 +879,9 @@ describe('place and make swift order', () => { taker: await takerDriftClient.getUserAccountPublicKey(), takerUserAccount: takerDriftClient.getUserAccount(), takerStats: takerDriftClient.getUserStatsAccountPublicKey(), - } + }, + undefined, + 2 ); assert.fail('Should have failed'); } catch (error) { @@ -913,7 +943,9 @@ describe('place and make swift order', () => { taker: await takerDriftClient.getUserAccountPublicKey(), takerUserAccount: takerDriftClient.getUserAccount(), takerStats: takerDriftClient.getUserStatsAccountPublicKey(), - } + }, + undefined, + 2 ); assert(