diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index 04664885776..8da1d7bb3d0 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -1,15 +1,14 @@ -import type { AccountsControllerSelectedAccountChangeEvent } from '@metamask/accounts-controller'; +import type { + AccountsControllerGetSelectedAccountAction, + AccountsControllerSelectedAccountChangeEvent, +} from '@metamask/accounts-controller'; import type { RestrictedControllerMessenger, ControllerGetStateAction, ControllerStateChangeEvent, } from '@metamask/base-controller'; import contractMap from '@metamask/contract-metadata'; -import { - ChainId, - safelyExecute, - toChecksumHexAddress, -} from '@metamask/controller-utils'; +import { ChainId, safelyExecute } from '@metamask/controller-utils'; import type { KeyringControllerGetStateAction, KeyringControllerLockEvent, @@ -17,8 +16,10 @@ import type { } from '@metamask/keyring-controller'; import type { NetworkClientId, - NetworkControllerNetworkDidChangeEvent, + NetworkControllerFindNetworkClientIdByChainIdAction, NetworkControllerGetNetworkConfigurationByNetworkClientId, + NetworkControllerGetProviderConfigAction, + NetworkControllerNetworkDidChangeEvent, } from '@metamask/network-controller'; import { StaticIntervalPollingController } from '@metamask/polling-controller'; import type { @@ -43,13 +44,23 @@ import type { const DEFAULT_INTERVAL = 180000; /** - * Finds a case insensitive match in an array of strings - * @param source - An array of strings to search. - * @param target - The target string to search for. - * @returns The first match that is found. + * Compare 2 given strings and return boolean + * eg: "foo" and "FOO" => true + * eg: "foo" and "bar" => false + * eg: "foo" and 123 => false + * + * @param value1 - first string to compare + * @param value2 - first string to compare + * @returns true if 2 strings are identical when they are lowercase */ -function findCaseInsensitiveMatch(source: string[], target: string) { - return source.find((e: string) => e.toLowerCase() === target.toLowerCase()); +export function isEqualCaseInsensitive( + value1: string, + value2: string, +): boolean { + if (typeof value1 !== 'string' || typeof value2 !== 'string') { + return false; + } + return value1.toLowerCase() === value2.toLowerCase(); } type LegacyToken = Omit< @@ -95,7 +106,10 @@ export type TokenDetectionControllerActions = TokenDetectionControllerGetStateAction; export type AllowedActions = + | AccountsControllerGetSelectedAccountAction | NetworkControllerGetNetworkConfigurationByNetworkClientId + | NetworkControllerGetProviderConfigAction + | NetworkControllerFindNetworkClientIdByChainIdAction | GetTokenListState | KeyringControllerGetStateAction | PreferencesControllerGetStateAction @@ -130,7 +144,7 @@ export type TokenDetectionControllerMessenger = RestrictedControllerMessenger< * @property chainId - The chain ID of the current network * @property selectedAddress - Vault selected address * @property networkClientId - The network client ID of the current selected network - * @property disabled - Boolean to track if network requests are blocked + * @property disableLegacyInterval - Boolean to track if network requests are blocked * @property isUnlocked - Boolean to track if the keyring state is unlocked * @property isDetectionEnabledFromPreferences - Boolean to track if detection is enabled from PreferencesController * @property isDetectionEnabledForNetwork - Boolean to track if detected is enabled for current network @@ -148,7 +162,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< #networkClientId: NetworkClientId; - #disabled: boolean; + #disableLegacyInterval: boolean; #isUnlocked: boolean; @@ -173,7 +187,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< * * @param options - The controller options. * @param options.messenger - The controller messaging system. - * @param options.disabled - If set to true, all network requests are blocked. + * @param options.disableLegacyInterval - If set to true, all network requests are blocked. * @param options.interval - Polling interval used to fetch new token rates * @param options.networkClientId - The selected network client ID of the current network * @param options.selectedAddress - Vault selected address @@ -182,9 +196,9 @@ export class TokenDetectionController extends StaticIntervalPollingController< */ constructor({ networkClientId, - selectedAddress = '', + selectedAddress, interval = DEFAULT_INTERVAL, - disabled = true, + disableLegacyInterval = true, getBalancesInSingleCall, trackMetaMetricsEvent, messenger, @@ -192,7 +206,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< networkClientId: NetworkClientId; selectedAddress?: string; interval?: number; - disabled?: boolean; + disableLegacyInterval?: boolean; getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; trackMetaMetricsEvent: (options: { event: string; @@ -212,12 +226,19 @@ export class TokenDetectionController extends StaticIntervalPollingController< metadata: {}, }); - this.#disabled = disabled; - this.setIntervalLength(interval); + this.#disableLegacyInterval = disableLegacyInterval; + if (!this.#disableLegacyInterval) { + this.setIntervalLength(interval); + } - this.#networkClientId = networkClientId; - this.#selectedAddress = selectedAddress; - this.#chainId = this.#getCorrectChainId(networkClientId); + this.#selectedAddress = + selectedAddress ?? + this.messagingSystem.call('AccountsController:getSelectedAccount') + .address; + const { chainId, networkClientId: correctNetworkClientId } = + this.#getCorrectChainIdAndNetworkClientId(networkClientId); + this.#chainId = chainId; + this.#networkClientId = correctNetworkClientId; const { useTokenDetection: defaultUseTokenDetection } = this.messagingSystem.call('PreferencesController:getState'); @@ -308,7 +329,8 @@ export class TokenDetectionController extends StaticIntervalPollingController< const isNetworkClientIdChanged = this.#networkClientId !== selectedNetworkClientId; - const newChainId = this.#getCorrectChainId(selectedNetworkClientId); + const { chainId: newChainId } = + this.#getCorrectChainIdAndNetworkClientId(selectedNetworkClientId); this.#isDetectionEnabledForNetwork = isTokenDetectionSupportedForNetwork(newChainId); @@ -325,15 +347,15 @@ export class TokenDetectionController extends StaticIntervalPollingController< /** * Allows controller to make active and passive polling requests */ - enable() { - this.#disabled = false; + enable(): void { + this.#disableLegacyInterval = false; } /** * Blocks controller from making network calls */ - disable() { - this.#disabled = true; + disable(): void { + this.#disableLegacyInterval = true; } /** @@ -341,14 +363,14 @@ export class TokenDetectionController extends StaticIntervalPollingController< * * @type {object} */ - get isActive() { - return !this.#disabled && this.#isUnlocked; + get isActive(): boolean { + return !this.#disableLegacyInterval && this.#isUnlocked; } /** * Start polling for detected tokens. */ - async start() { + async start(): Promise { this.enable(); await this.#startPolling(); } @@ -356,12 +378,12 @@ export class TokenDetectionController extends StaticIntervalPollingController< /** * Stop polling for detected tokens. */ - stop() { + stop(): void { this.disable(); this.#stopPolling(); } - #stopPolling() { + #stopPolling(): void { if (this.#intervalId) { clearInterval(this.#intervalId); } @@ -381,25 +403,45 @@ export class TokenDetectionController extends StaticIntervalPollingController< }, this.getIntervalLength()); } - #getCorrectChainId(networkClientId?: NetworkClientId) { - const { chainId } = - this.messagingSystem.call( + #getCorrectChainIdAndNetworkClientId(networkClientId?: NetworkClientId): { + chainId: Hex; + networkClientId: NetworkClientId; + } { + if (networkClientId) { + const networkConfiguration = this.messagingSystem.call( 'NetworkController:getNetworkConfigurationByNetworkClientId', - networkClientId ?? this.#networkClientId, - ) ?? {}; - return chainId ?? this.#chainId; + networkClientId, + ); + if (networkConfiguration) { + return { + chainId: networkConfiguration.chainId, + networkClientId, + }; + } + } + const { chainId } = this.messagingSystem.call( + 'NetworkController:getProviderConfig', + ); + const newNetworkClientId = this.messagingSystem.call( + 'NetworkController:findNetworkClientIdByChainId', + this.#chainId, + ); + return { + chainId, + networkClientId: newNetworkClientId, + }; } async _executePoll( - networkClientId: string, + networkClientId: NetworkClientId, options: { address: string }, ): Promise { if (!this.isActive) { return; } await this.detectTokens({ + ...options, networkClientId, - accountAddress: options.address, }); } @@ -414,7 +456,10 @@ export class TokenDetectionController extends StaticIntervalPollingController< async #restartTokenDetection({ selectedAddress, networkClientId, - }: { selectedAddress?: string; networkClientId?: string } = {}) { + }: { + selectedAddress?: string; + networkClientId?: NetworkClientId; + } = {}): Promise { await this.detectTokens({ networkClientId, accountAddress: selectedAddress, @@ -440,8 +485,10 @@ export class TokenDetectionController extends StaticIntervalPollingController< if (!this.isActive || !this.#isDetectionEnabledForNetwork) { return; } + const selectedAddress = accountAddress ?? this.#selectedAddress; - const chainId = this.#getCorrectChainId(networkClientId); + const { chainId, networkClientId: selectedNetworkClientId } = + this.#getCorrectChainIdAndNetworkClientId(networkClientId); if ( !this.#isDetectionEnabledFromPreferences && @@ -454,28 +501,30 @@ export class TokenDetectionController extends StaticIntervalPollingController< const { tokensChainsCache } = this.messagingSystem.call( 'TokenListController:getState', ); - const tokenList = tokensChainsCache[chainId]?.data || {}; - - const tokenListUsed = isTokenDetectionInactiveInMainnet + const tokenList = isTokenDetectionInactiveInMainnet ? STATIC_MAINNET_TOKEN_LIST - : tokenList; + : tokensChainsCache[chainId]?.data ?? {}; const { allTokens, allDetectedTokens, allIgnoredTokens } = this.messagingSystem.call('TokensController:getState'); - const tokens = allTokens[chainId]?.[selectedAddress] || []; - const detectedTokens = allDetectedTokens[chainId]?.[selectedAddress] || []; - const ignoredTokens = allIgnoredTokens[chainId]?.[selectedAddress] || []; + const [tokensAddresses, detectedTokensAddresses, ignoredTokensAddreses] = [ + allTokens, + allDetectedTokens, + allIgnoredTokens, + ].map((tokens) => + (tokens[chainId]?.[selectedAddress] ?? []).map((value) => + typeof value === 'string' ? value : value.address, + ), + ); const tokensToDetect: string[] = []; - for (const tokenAddress of Object.keys(tokenListUsed)) { + for (const tokenAddress of Object.keys(tokenList)) { if ( - !findCaseInsensitiveMatch( - tokens.map(({ address }) => address), - tokenAddress, - ) && - !findCaseInsensitiveMatch( - detectedTokens.map(({ address }) => address), - tokenAddress, + [tokensAddresses, detectedTokensAddresses, ignoredTokensAddreses].every( + (addresses) => + !addresses.find((address) => + isEqualCaseInsensitive(address, tokenAddress), + ), ) ) { tokensToDetect.push(tokenAddress); @@ -502,40 +551,22 @@ export class TokenDetectionController extends StaticIntervalPollingController< const balances = await this.#getBalancesInSingleCall( selectedAddress, tokensSlice, + selectedNetworkClientId, ); - const tokensToAdd: Token[] = []; + + const tokensWithBalance: Token[] = []; const eventTokensDetails: string[] = []; - let ignored; - for (const tokenAddress of Object.keys(balances)) { - if (ignoredTokens.length) { - ignored = ignoredTokens.find( - (ignoredTokenAddress) => - ignoredTokenAddress === toChecksumHexAddress(tokenAddress), - ); - } - const caseInsensitiveTokenKey = - findCaseInsensitiveMatch( - Object.keys(tokenListUsed), - tokenAddress, - ) ?? ''; - - if (ignored === undefined) { - const { decimals, symbol, aggregators, iconUrl, name } = - tokenListUsed[caseInsensitiveTokenKey]; - eventTokensDetails.push(`${symbol} - ${tokenAddress}`); - tokensToAdd.push({ - address: tokenAddress, - decimals, - symbol, - aggregators, - image: iconUrl, - isERC721: false, - name, - }); - } + for (const nonZeroTokenAddress of Object.keys(balances)) { + const { address, symbol, decimals } = tokenList[nonZeroTokenAddress]; + eventTokensDetails.push(`${symbol} - ${nonZeroTokenAddress}`); + tokensWithBalance.push({ + address, + symbol, + decimals, + }); } - if (tokensToAdd.length) { + if (tokensWithBalance.length) { this.#trackMetaMetricsEvent({ event: 'Token Detected', category: 'Wallet', @@ -547,7 +578,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< }); await this.messagingSystem.call( 'TokensController:addDetectedTokens', - tokensToAdd, + tokensWithBalance, { selectedAddress, chainId,