Skip to content

Commit

Permalink
Merge pull request #1124 from drift-labs/chester/improve-async-logic
Browse files Browse the repository at this point in the history
refactor(sdk): improve async logic in drift client account sub
  • Loading branch information
ChesterSim committed Jul 8, 2024
2 parents a074ac4 + 30f9eb9 commit d54b70e
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 61 deletions.
61 changes: 38 additions & 23 deletions sdk/src/accounts/pollingDriftClientAccountSubscriber.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ export class PollingDriftClientAccountSubscriber
}

await this.updateAccountsToPoll();
await this.updateOraclesToPoll();
this.updateOraclesToPoll();
await this.addToAccountLoader();

let subscriptionSucceeded = false;
Expand All @@ -116,8 +116,7 @@ export class PollingDriftClientAccountSubscriber
this.eventEmitter.emit('update');
}

await this.setPerpOracleMap();
await this.setSpotOracleMap();
await Promise.all([this.setPerpOracleMap(), this.setSpotOracleMap()]);

this.isSubscribing = false;
this.isSubscribed = subscriptionSucceeded;
Expand All @@ -141,14 +140,18 @@ export class PollingDriftClientAccountSubscriber
eventType: 'stateAccountUpdate',
});

await this.updatePerpMarketAccountsToPoll();
await this.updateSpotMarketAccountsToPoll();
await Promise.all([
this.updatePerpMarketAccountsToPoll(),
this.updateSpotMarketAccountsToPoll(),
]);
}

async updatePerpMarketAccountsToPoll(): Promise<boolean> {
for (const marketIndex of this.perpMarketIndexes) {
await this.addPerpMarketAccountToPoll(marketIndex);
}
await Promise.all(
this.perpMarketIndexes.map((marketIndex) => {
return this.addPerpMarketAccountToPoll(marketIndex);
})
);
return true;
}

Expand All @@ -169,9 +172,11 @@ export class PollingDriftClientAccountSubscriber
}

async updateSpotMarketAccountsToPoll(): Promise<boolean> {
for (const marketIndex of this.spotMarketIndexes) {
await this.addSpotMarketAccountToPoll(marketIndex);
}
await Promise.all(
this.spotMarketIndexes.map(async (marketIndex) => {
await this.addSpotMarketAccountToPoll(marketIndex);
})
);

return true;
}
Expand Down Expand Up @@ -209,16 +214,19 @@ export class PollingDriftClientAccountSubscriber

return true;
}

async addToAccountLoader(): Promise<void> {
const accountPromises = [];
for (const [_, accountToPoll] of this.accountsToPoll) {
await this.addAccountToAccountLoader(accountToPoll);
accountPromises.push(this.addAccountToAccountLoader(accountToPoll));
}

const oraclePromises = [];
for (const [_, oracleToPoll] of this.oraclesToPoll) {
await this.addOracleToAccountLoader(oracleToPoll);
oraclePromises.push(this.addOracleToAccountLoader(oracleToPoll));
}

await Promise.all([...accountPromises, ...oraclePromises]);

this.errorCallbackId = this.accountLoader.addErrorCallbacks((error) => {
this.eventEmitter.emit('error', error);
});
Expand Down Expand Up @@ -446,37 +454,44 @@ export class PollingDriftClientAccountSubscriber
}
console.log(`Pausing to find oracle ${oracle} failed`);
}

async setPerpOracleMap() {
const perpMarkets = this.getMarketAccountsAndSlots();
const oraclePromises = [];
for (const perpMarket of perpMarkets) {
const perpMarketAccount = perpMarket.data;
const perpMarketIndex = perpMarketAccount.marketIndex;
const oracle = perpMarketAccount.amm.oracle;
if (!this.oracles.has(oracle.toBase58())) {
await this.addOracle({
publicKey: oracle,
source: perpMarketAccount.amm.oracleSource,
});
oraclePromises.push(
this.addOracle({
publicKey: oracle,
source: perpMarketAccount.amm.oracleSource,
})
);
}
this.perpOracleMap.set(perpMarketIndex, oracle);
}
await Promise.all(oraclePromises);
}

async setSpotOracleMap() {
const spotMarkets = this.getSpotMarketAccountsAndSlots();
const oraclePromises = [];
for (const spotMarket of spotMarkets) {
const spotMarketAccount = spotMarket.data;
const spotMarketIndex = spotMarketAccount.marketIndex;
const oracle = spotMarketAccount.oracle;
if (!this.oracles.has(oracle.toBase58())) {
await this.addOracle({
publicKey: oracle,
source: spotMarketAccount.oracleSource,
});
oraclePromises.push(
this.addOracle({
publicKey: oracle,
source: spotMarketAccount.oracleSource,
})
);
}
this.spotOracleMap.set(spotMarketIndex, oracle);
}
await Promise.all(oraclePromises);
}

assertIsSubscribed(): void {
Expand Down
93 changes: 55 additions & 38 deletions sdk/src/accounts/webSocketDriftClientAccountSubscriber.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,18 @@ export class WebSocketDriftClientAccountSubscriber
this.eventEmitter.emit('update');
});

// subscribe to market accounts
await this.subscribeToPerpMarketAccounts();

// subscribe to spot market accounts
await this.subscribeToSpotMarketAccounts();

// subscribe to oracles
await this.subscribeToOracles();
await Promise.all([
// subscribe to market accounts
this.subscribeToPerpMarketAccounts(),
// subscribe to spot market accounts
this.subscribeToSpotMarketAccounts(),
// subscribe to oracles
this.subscribeToOracles(),
]);

this.eventEmitter.emit('update');

await this.setPerpOracleMap();
await this.setSpotOracleMap();
await Promise.all([this.setPerpOracleMap(), this.setSpotOracleMap()]);

this.isSubscribing = false;
this.isSubscribed = true;
Expand All @@ -137,9 +136,11 @@ export class WebSocketDriftClientAccountSubscriber
}

async subscribeToPerpMarketAccounts(): Promise<boolean> {
for (const marketIndex of this.perpMarketIndexes) {
await this.subscribeToPerpMarketAccount(marketIndex);
}
await Promise.all(
this.perpMarketIndexes.map((marketIndex) =>
this.subscribeToPerpMarketAccount(marketIndex)
)
);
return true;
}

Expand All @@ -165,9 +166,11 @@ export class WebSocketDriftClientAccountSubscriber
}

async subscribeToSpotMarketAccounts(): Promise<boolean> {
for (const marketIndex of this.spotMarketIndexes) {
await this.subscribeToSpotMarketAccount(marketIndex);
}
await Promise.all(
this.spotMarketIndexes.map((marketIndex) =>
this.subscribeToSpotMarketAccount(marketIndex)
)
);
return true;
}

Expand All @@ -193,11 +196,11 @@ export class WebSocketDriftClientAccountSubscriber
}

async subscribeToOracles(): Promise<boolean> {
for (const oracleInfo of this.oracleInfos) {
if (!oracleInfo.publicKey.equals(PublicKey.default)) {
await this.subscribeToOracle(oracleInfo);
}
}
await Promise.all(
this.oracleInfos
.filter((oracleInfo) => !oracleInfo.publicKey.equals(PublicKey.default))
.map((oracleInfo) => this.subscribeToOracle(oracleInfo))
);

return true;
}
Expand Down Expand Up @@ -232,21 +235,27 @@ export class WebSocketDriftClientAccountSubscriber
}

async unsubscribeFromMarketAccounts(): Promise<void> {
for (const accountSubscriber of this.perpMarketAccountSubscribers.values()) {
await accountSubscriber.unsubscribe();
}
await Promise.all(
Array.from(this.perpMarketAccountSubscribers.values()).map(
(accountSubscriber) => accountSubscriber.unsubscribe()
)
);
}

async unsubscribeFromSpotMarketAccounts(): Promise<void> {
for (const accountSubscriber of this.spotMarketAccountSubscribers.values()) {
await accountSubscriber.unsubscribe();
}
await Promise.all(
Array.from(this.spotMarketAccountSubscribers.values()).map(
(accountSubscriber) => accountSubscriber.unsubscribe()
)
);
}

async unsubscribeFromOracles(): Promise<void> {
for (const accountSubscriber of this.oracleSubscribers.values()) {
await accountSubscriber.unsubscribe();
}
await Promise.all(
Array.from(this.oracleSubscribers.values()).map((accountSubscriber) =>
accountSubscriber.unsubscribe()
)
);
}

public async fetch(): Promise<void> {
Expand Down Expand Up @@ -315,6 +324,7 @@ export class WebSocketDriftClientAccountSubscriber

async setPerpOracleMap() {
const perpMarkets = this.getMarketAccountsAndSlots();
const addOraclePromises = [];
for (const perpMarket of perpMarkets) {
if (!perpMarket) {
continue;
Expand All @@ -323,17 +333,21 @@ export class WebSocketDriftClientAccountSubscriber
const perpMarketIndex = perpMarketAccount.marketIndex;
const oracle = perpMarketAccount.amm.oracle;
if (!this.oracleSubscribers.has(oracle.toBase58())) {
await this.addOracle({
publicKey: oracle,
source: perpMarket.data.amm.oracleSource,
});
addOraclePromises.push(
this.addOracle({
publicKey: oracle,
source: perpMarket.data.amm.oracleSource,
})
);
}
this.perpOracleMap.set(perpMarketIndex, oracle);
}
await Promise.all(addOraclePromises);
}

async setSpotOracleMap() {
const spotMarkets = this.getSpotMarketAccountsAndSlots();
const addOraclePromises = [];
for (const spotMarket of spotMarkets) {
if (!spotMarket) {
continue;
Expand All @@ -342,13 +356,16 @@ export class WebSocketDriftClientAccountSubscriber
const spotMarketIndex = spotMarketAccount.marketIndex;
const oracle = spotMarketAccount.oracle;
if (!this.oracleSubscribers.has(oracle.toBase58())) {
await this.addOracle({
publicKey: oracle,
source: spotMarketAccount.oracleSource,
});
addOraclePromises.push(
this.addOracle({
publicKey: oracle,
source: spotMarketAccount.oracleSource,
})
);
}
this.spotOracleMap.set(spotMarketIndex, oracle);
}
await Promise.all(addOraclePromises);
}

assertIsSubscribed(): void {
Expand Down

0 comments on commit d54b70e

Please sign in to comment.