Skip to content

Commit

Permalink
stash
Browse files Browse the repository at this point in the history
  • Loading branch information
MajorLift committed Feb 16, 2024
1 parent 5872606 commit e139f86
Showing 1 changed file with 121 additions and 90 deletions.
211 changes: 121 additions & 90 deletions packages/assets-controllers/src/TokenDetectionController.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
import type { AccountsControllerSelectedAccountChangeEvent } from '@metamask/accounts-controller';
import type {
AccountsControllerGetSelectedAccountAction,
AccountsControllerSelectedAccountChangeEvent,
} from '@metamask/accounts-controller';
import type {
RestrictedControllerMessenger,
ControllerGetStateAction,
ControllerStateChangeEvent,
} from '@metamask/base-controller';
import contractMap from '@metamask/contract-metadata';
import {
ChainId,
safelyExecute,
toChecksumHexAddress,
} from '@metamask/controller-utils';
import { ChainId, safelyExecute } from '@metamask/controller-utils';
import type {
KeyringControllerGetStateAction,
KeyringControllerLockEvent,
KeyringControllerUnlockEvent,
} from '@metamask/keyring-controller';
import type {
NetworkClientId,
NetworkControllerNetworkDidChangeEvent,
NetworkControllerFindNetworkClientIdByChainIdAction,
NetworkControllerGetNetworkConfigurationByNetworkClientId,
NetworkControllerGetProviderConfigAction,
NetworkControllerNetworkDidChangeEvent,
} from '@metamask/network-controller';
import { StaticIntervalPollingController } from '@metamask/polling-controller';
import type {
Expand All @@ -43,13 +44,23 @@ import type {
const DEFAULT_INTERVAL = 180000;

/**
* Finds a case insensitive match in an array of strings
* @param source - An array of strings to search.
* @param target - The target string to search for.
* @returns The first match that is found.
* Compare 2 given strings and return boolean
* eg: "foo" and "FOO" => true
* eg: "foo" and "bar" => false
* eg: "foo" and 123 => false
*
* @param value1 - first string to compare
* @param value2 - first string to compare
* @returns true if 2 strings are identical when they are lowercase
*/
function findCaseInsensitiveMatch(source: string[], target: string) {
return source.find((e: string) => e.toLowerCase() === target.toLowerCase());
export function isEqualCaseInsensitive(
value1: string,
value2: string,
): boolean {
if (typeof value1 !== 'string' || typeof value2 !== 'string') {
return false;
}
return value1.toLowerCase() === value2.toLowerCase();
}

type LegacyToken = Omit<
Expand Down Expand Up @@ -95,7 +106,10 @@ export type TokenDetectionControllerActions =
TokenDetectionControllerGetStateAction;

export type AllowedActions =
| AccountsControllerGetSelectedAccountAction
| NetworkControllerGetNetworkConfigurationByNetworkClientId
| NetworkControllerGetProviderConfigAction
| NetworkControllerFindNetworkClientIdByChainIdAction
| GetTokenListState
| KeyringControllerGetStateAction
| PreferencesControllerGetStateAction
Expand Down Expand Up @@ -130,7 +144,7 @@ export type TokenDetectionControllerMessenger = RestrictedControllerMessenger<
* @property chainId - The chain ID of the current network
* @property selectedAddress - Vault selected address
* @property networkClientId - The network client ID of the current selected network
* @property disabled - Boolean to track if network requests are blocked
* @property disableLegacyInterval - Boolean to track if network requests are blocked
* @property isUnlocked - Boolean to track if the keyring state is unlocked
* @property isDetectionEnabledFromPreferences - Boolean to track if detection is enabled from PreferencesController
* @property isDetectionEnabledForNetwork - Boolean to track if detected is enabled for current network
Expand All @@ -148,7 +162,7 @@ export class TokenDetectionController extends StaticIntervalPollingController<

#networkClientId: NetworkClientId;

#disabled: boolean;
#disableLegacyInterval: boolean;

#isUnlocked: boolean;

Expand All @@ -173,7 +187,7 @@ export class TokenDetectionController extends StaticIntervalPollingController<
*
* @param options - The controller options.
* @param options.messenger - The controller messaging system.
* @param options.disabled - If set to true, all network requests are blocked.
* @param options.disableLegacyInterval - If set to true, all network requests are blocked.
* @param options.interval - Polling interval used to fetch new token rates
* @param options.networkClientId - The selected network client ID of the current network
* @param options.selectedAddress - Vault selected address
Expand All @@ -182,17 +196,17 @@ export class TokenDetectionController extends StaticIntervalPollingController<
*/
constructor({
networkClientId,
selectedAddress = '',
selectedAddress,
interval = DEFAULT_INTERVAL,
disabled = true,
disableLegacyInterval = true,
getBalancesInSingleCall,
trackMetaMetricsEvent,
messenger,
}: {
networkClientId: NetworkClientId;
selectedAddress?: string;
interval?: number;
disabled?: boolean;
disableLegacyInterval?: boolean;
getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall'];
trackMetaMetricsEvent: (options: {
event: string;
Expand All @@ -212,12 +226,19 @@ export class TokenDetectionController extends StaticIntervalPollingController<
metadata: {},
});

this.#disabled = disabled;
this.setIntervalLength(interval);
this.#disableLegacyInterval = disableLegacyInterval;
if (!this.#disableLegacyInterval) {
this.setIntervalLength(interval);
}

this.#networkClientId = networkClientId;
this.#selectedAddress = selectedAddress;
this.#chainId = this.#getCorrectChainId(networkClientId);
this.#selectedAddress =
selectedAddress ??
this.messagingSystem.call('AccountsController:getSelectedAccount')
.address;
const { chainId, networkClientId: correctNetworkClientId } =
this.#getCorrectChainIdAndNetworkClientId(networkClientId);
this.#chainId = chainId;
this.#networkClientId = correctNetworkClientId;

const { useTokenDetection: defaultUseTokenDetection } =
this.messagingSystem.call('PreferencesController:getState');
Expand Down Expand Up @@ -308,7 +329,8 @@ export class TokenDetectionController extends StaticIntervalPollingController<
const isNetworkClientIdChanged =
this.#networkClientId !== selectedNetworkClientId;

const newChainId = this.#getCorrectChainId(selectedNetworkClientId);
const { chainId: newChainId } =
this.#getCorrectChainIdAndNetworkClientId(selectedNetworkClientId);
this.#isDetectionEnabledForNetwork =
isTokenDetectionSupportedForNetwork(newChainId);

Expand All @@ -325,43 +347,43 @@ export class TokenDetectionController extends StaticIntervalPollingController<
/**
* Allows controller to make active and passive polling requests
*/
enable() {
this.#disabled = false;
enable(): void {
this.#disableLegacyInterval = false;
}

/**
* Blocks controller from making network calls
*/
disable() {
this.#disabled = true;
disable(): void {
this.#disableLegacyInterval = true;
}

/**
* Internal isActive state
*
* @type {object}
*/
get isActive() {
return !this.#disabled && this.#isUnlocked;
get isActive(): boolean {
return !this.#disableLegacyInterval && this.#isUnlocked;
}

/**
* Start polling for detected tokens.
*/
async start() {
async start(): Promise<void> {
this.enable();
await this.#startPolling();
}

/**
* Stop polling for detected tokens.
*/
stop() {
stop(): void {
this.disable();
this.#stopPolling();
}

#stopPolling() {
#stopPolling(): void {
if (this.#intervalId) {
clearInterval(this.#intervalId);
}
Expand All @@ -381,25 +403,45 @@ export class TokenDetectionController extends StaticIntervalPollingController<
}, this.getIntervalLength());
}

#getCorrectChainId(networkClientId?: NetworkClientId) {
const { chainId } =
this.messagingSystem.call(
#getCorrectChainIdAndNetworkClientId(networkClientId?: NetworkClientId): {
chainId: Hex;
networkClientId: NetworkClientId;
} {
if (networkClientId) {
const networkConfiguration = this.messagingSystem.call(
'NetworkController:getNetworkConfigurationByNetworkClientId',
networkClientId ?? this.#networkClientId,
) ?? {};
return chainId ?? this.#chainId;
networkClientId,
);
if (networkConfiguration) {
return {
chainId: networkConfiguration.chainId,
networkClientId,
};
}
}
const { chainId } = this.messagingSystem.call(
'NetworkController:getProviderConfig',
);
const newNetworkClientId = this.messagingSystem.call(
'NetworkController:findNetworkClientIdByChainId',
this.#chainId,
);
return {
chainId,
networkClientId: newNetworkClientId,
};
}

async _executePoll(
networkClientId: string,
networkClientId: NetworkClientId,
options: { address: string },
): Promise<void> {
if (!this.isActive) {
return;
}
await this.detectTokens({
...options,
networkClientId,
accountAddress: options.address,
});
}

Expand All @@ -414,7 +456,10 @@ export class TokenDetectionController extends StaticIntervalPollingController<
async #restartTokenDetection({
selectedAddress,
networkClientId,
}: { selectedAddress?: string; networkClientId?: string } = {}) {
}: {
selectedAddress?: string;
networkClientId?: NetworkClientId;
} = {}): Promise<void> {
await this.detectTokens({
networkClientId,
accountAddress: selectedAddress,
Expand All @@ -440,8 +485,10 @@ export class TokenDetectionController extends StaticIntervalPollingController<
if (!this.isActive || !this.#isDetectionEnabledForNetwork) {
return;
}

const selectedAddress = accountAddress ?? this.#selectedAddress;
const chainId = this.#getCorrectChainId(networkClientId);
const { chainId, networkClientId: selectedNetworkClientId } =
this.#getCorrectChainIdAndNetworkClientId(networkClientId);

if (
!this.#isDetectionEnabledFromPreferences &&
Expand All @@ -454,28 +501,30 @@ export class TokenDetectionController extends StaticIntervalPollingController<
const { tokensChainsCache } = this.messagingSystem.call(
'TokenListController:getState',
);
const tokenList = tokensChainsCache[chainId]?.data || {};

const tokenListUsed = isTokenDetectionInactiveInMainnet
const tokenList = isTokenDetectionInactiveInMainnet
? STATIC_MAINNET_TOKEN_LIST
: tokenList;
: tokensChainsCache[chainId]?.data ?? {};

const { allTokens, allDetectedTokens, allIgnoredTokens } =
this.messagingSystem.call('TokensController:getState');
const tokens = allTokens[chainId]?.[selectedAddress] || [];
const detectedTokens = allDetectedTokens[chainId]?.[selectedAddress] || [];
const ignoredTokens = allIgnoredTokens[chainId]?.[selectedAddress] || [];
const [tokensAddresses, detectedTokensAddresses, ignoredTokensAddreses] = [
allTokens,
allDetectedTokens,
allIgnoredTokens,
].map((tokens) =>
(tokens[chainId]?.[selectedAddress] ?? []).map((value) =>
typeof value === 'string' ? value : value.address,
),
);

const tokensToDetect: string[] = [];
for (const tokenAddress of Object.keys(tokenListUsed)) {
for (const tokenAddress of Object.keys(tokenList)) {
if (
!findCaseInsensitiveMatch(
tokens.map(({ address }) => address),
tokenAddress,
) &&
!findCaseInsensitiveMatch(
detectedTokens.map(({ address }) => address),
tokenAddress,
[tokensAddresses, detectedTokensAddresses, ignoredTokensAddreses].every(
(addresses) =>
!addresses.find((address) =>
isEqualCaseInsensitive(address, tokenAddress),
),
)
) {
tokensToDetect.push(tokenAddress);
Expand All @@ -502,40 +551,22 @@ export class TokenDetectionController extends StaticIntervalPollingController<
const balances = await this.#getBalancesInSingleCall(
selectedAddress,
tokensSlice,
selectedNetworkClientId,
);
const tokensToAdd: Token[] = [];

const tokensWithBalance: Token[] = [];
const eventTokensDetails: string[] = [];
let ignored;
for (const tokenAddress of Object.keys(balances)) {
if (ignoredTokens.length) {
ignored = ignoredTokens.find(
(ignoredTokenAddress) =>
ignoredTokenAddress === toChecksumHexAddress(tokenAddress),
);
}
const caseInsensitiveTokenKey =
findCaseInsensitiveMatch(
Object.keys(tokenListUsed),
tokenAddress,
) ?? '';

if (ignored === undefined) {
const { decimals, symbol, aggregators, iconUrl, name } =
tokenListUsed[caseInsensitiveTokenKey];
eventTokensDetails.push(`${symbol} - ${tokenAddress}`);
tokensToAdd.push({
address: tokenAddress,
decimals,
symbol,
aggregators,
image: iconUrl,
isERC721: false,
name,
});
}
for (const nonZeroTokenAddress of Object.keys(balances)) {
const { address, symbol, decimals } = tokenList[nonZeroTokenAddress];
eventTokensDetails.push(`${symbol} - ${nonZeroTokenAddress}`);
tokensWithBalance.push({
address,
symbol,
decimals,
});
}

if (tokensToAdd.length) {
if (tokensWithBalance.length) {
this.#trackMetaMetricsEvent({
event: 'Token Detected',
category: 'Wallet',
Expand All @@ -547,7 +578,7 @@ export class TokenDetectionController extends StaticIntervalPollingController<
});
await this.messagingSystem.call(
'TokensController:addDetectedTokens',
tokensToAdd,
tokensWithBalance,
{
selectedAddress,
chainId,
Expand Down

0 comments on commit e139f86

Please sign in to comment.