diff --git a/src/vanilla/store.ts b/src/vanilla/store.ts index 6c20762127..3af9793916 100644 --- a/src/vanilla/store.ts +++ b/src/vanilla/store.ts @@ -173,11 +173,11 @@ type BatchPriority = 0 | 1 | 2 type Batch = [ /** finish recompute */ - priority0: Set<() => void>, + priority0: Set<(batch: Batch) => void>, /** atom listeners */ - priority1: Set<() => void>, + priority1: Set<(batch: Batch) => void>, /** atom mount hooks */ - priority2: Set<() => void>, + priority2: Set<(batch: Batch) => void>, ] & { /** changed Atoms */ C: Set @@ -189,7 +189,7 @@ const createBatch = (): Batch => const addBatchFunc = ( batch: Batch, priority: BatchPriority, - fn: () => void, + fn: (batch: Batch) => void, ) => { batch[priority].add(fn) } @@ -203,7 +203,9 @@ const registerBatchAtom = ( batch.C.add(atom) atomState.u?.(batch) const scheduleListeners = () => { - atomState.m?.l.forEach((listener) => addBatchFunc(batch, 1, listener)) + atomState.m?.l.forEach((listener) => + addBatchFunc(batch, 1, () => listener()), + ) } addBatchFunc(batch, 1, scheduleListeners) } @@ -212,9 +214,9 @@ const registerBatchAtom = ( const flushBatch = (batch: Batch) => { let error: AnyError let hasError = false - const call = (fn: () => void) => { + const call = (fn: (batch: Batch) => void) => { try { - fn() + fn(batch) } catch (e) { if (!hasError) { error = e @@ -223,7 +225,6 @@ const flushBatch = (batch: Batch) => { } } while (batch.C.size || batch.some((channel) => channel.size)) { - batch.C.clear() for (const channel of batch) { channel.forEach(call) channel.clear() @@ -315,7 +316,6 @@ const buildStore = (...storeArgs: StoreArgs): Store => { atomState.v = valueOrPromise } delete atomState.e - delete atomState.x if (!hasPrevValue || !Object.is(prevValue, atomState.v)) { ++atomState.n if (pendingPromise) { @@ -428,7 +428,6 @@ const buildStore = (...storeArgs: StoreArgs): Store => { } catch (error) { delete atomState.v atomState.e = error - delete atomState.x ++atomState.n return atomState } finally { @@ -458,11 +457,26 @@ const buildStore = (...storeArgs: StoreArgs): Store => { return dependents } - const recomputeDependents = ( - batch: Batch, - atom: Atom, - atomState: AtomState, - ) => { + const dirtyDependents = (atomState: AtomState) => { + const dependents = new Set([atomState]) + const stack: AtomState[] = [atomState] + while (stack.length > 0) { + const aState = stack.pop()! + if (aState.x) { + // already dirty + continue + } + aState.x = true + for (const [, s] of getMountedOrBatchDependents(aState)) { + if (!dependents.has(s)) { + dependents.add(s) + stack.push(s) + } + } + } + } + + const recomputeDependents = (batch: Batch) => { // Step 1: traverse the dependency graph to build the topsorted atom list // We don't bother to check for cycles, which simplifies the algorithm. // This is a topological sort via depth-first search, slightly modified from @@ -477,7 +491,10 @@ const buildStore = (...storeArgs: StoreArgs): Store => { const visited = new Set() // Visit the root atom. This is the only atom in the dependency graph // without incoming edges, which is one reason we can simplify the algorithm - const stack: [a: AnyAtom, aState: AtomState][] = [[atom, atomState]] + const stack: [a: AnyAtom, aState: AtomState][] = Array.from( + batch.C, + (atom) => [atom, ensureAtomState(atom)], + ) while (stack.length > 0) { const [a, aState] = stack[stack.length - 1]! if (visited.has(a)) { @@ -492,8 +509,6 @@ const buildStore = (...storeArgs: StoreArgs): Store => { topSortedReversed.push([a, aState, aState.n]) // Atom has been visited but not yet processed visited.add(a) - // Mark atom dirty - aState.x = true stack.pop() continue } @@ -508,29 +523,25 @@ const buildStore = (...storeArgs: StoreArgs): Store => { // Step 2: use the topSortedReversed atom list to recompute all affected atoms // Track what's changed, so that we can short circuit when possible - const finishRecompute = () => { - const changedAtoms = new Set([atom]) - for (let i = topSortedReversed.length - 1; i >= 0; --i) { - const [a, aState, prevEpochNumber] = topSortedReversed[i]! - let hasChangedDeps = false - for (const dep of aState.d.keys()) { - if (dep !== a && changedAtoms.has(dep)) { - hasChangedDeps = true - break - } + for (let i = topSortedReversed.length - 1; i >= 0; --i) { + const [a, aState, prevEpochNumber] = topSortedReversed[i]! + let hasChangedDeps = false + for (const dep of aState.d.keys()) { + if (dep !== a && batch.C.has(dep)) { + hasChangedDeps = true + break } - if (hasChangedDeps) { - readAtomState(batch, a) - mountDependencies(batch, a, aState) - if (prevEpochNumber !== aState.n) { - registerBatchAtom(batch, a, aState) - changedAtoms.add(a) - } + } + if (hasChangedDeps) { + readAtomState(batch, a) + mountDependencies(batch, a, aState) + if (prevEpochNumber !== aState.n) { + registerBatchAtom(batch, a, aState) } - delete aState.x } + delete aState.x } - addBatchFunc(batch, 0, finishRecompute) + batch.C.clear() } const writeAtomState = ( @@ -557,8 +568,9 @@ const buildStore = (...storeArgs: StoreArgs): Store => { setAtomStateValueOrPromise(a, aState, v) mountDependencies(batch, a, aState) if (prevEpochNumber !== aState.n) { + dirtyDependents(aState) registerBatchAtom(batch, a, aState) - recomputeDependents(batch, a, aState) + addBatchFunc(batch, 0, recomputeDependents) } return undefined as R } else { @@ -679,7 +691,7 @@ const buildStore = (...storeArgs: StoreArgs): Store => { // unmount self const onUnmount = atomState.m.u if (onUnmount) { - addBatchFunc(batch, 2, () => onUnmount(batch)) + addBatchFunc(batch, 2, onUnmount) } delete atomState.m atomState.h?.(batch) diff --git a/tests/vanilla/effect.test.ts b/tests/vanilla/effect.test.ts index 697d6a6733..7ebe1b8d77 100644 --- a/tests/vanilla/effect.test.ts +++ b/tests/vanilla/effect.test.ts @@ -17,8 +17,6 @@ type Ref = { cleanup?: Cleanup | undefined } -const syncEffectChannelSymbol = Symbol() - function syncEffect(effect: Effect): Atom { const refAtom = atom(() => ({ inProgress: 0, epoch: 0 })) const refreshAtom = atom(0) @@ -62,8 +60,7 @@ function syncEffect(effect: Effect): Atom { store.set(refreshAtom, (v) => v + 1) } else { // unmount - const syncEffectChannel = ensureBatchChannel(batch) - syncEffectChannel.add(() => { + scheduleListenerAfterRecompute(batch, () => { ref.cleanup?.() delete ref.cleanup }) @@ -73,8 +70,22 @@ function syncEffect(effect: Effect): Atom { internalAtomState.u = (batch) => { originalUpdateHook?.(batch) // update - const syncEffectChannel = ensureBatchChannel(batch) - syncEffectChannel.add(runEffect) + scheduleListenerAfterRecompute(batch, runEffect) + } + function scheduleListenerAfterRecompute( + batch: Batch, + listener: () => void, + ) { + const scheduleListener = () => { + batch[0].add(listener) + } + if (batch[0].size === 0) { + // no other listeners + // schedule after recomputeDependents + batch[0].add(scheduleListener) + } else { + scheduleListener() + } } } return atom((get) => { @@ -82,39 +93,6 @@ function syncEffect(effect: Effect): Atom { }) } -type BatchWithSyncEffect = Batch & { - [syncEffectChannelSymbol]?: Set<() => void> -} -function ensureBatchChannel(batch: BatchWithSyncEffect) { - // ensure continuation of the flushBatch while loop - const originalQueue = batch[1] - if (!originalQueue) { - throw new Error('batch[1] must be present') - } - if (!batch[syncEffectChannelSymbol]) { - batch[syncEffectChannelSymbol] = new Set<() => void>() - batch[1] = { - ...originalQueue, - add(item) { - originalQueue.add(item) - return this - }, - clear() { - batch[syncEffectChannelSymbol]!.clear() - originalQueue.clear() - }, - forEach(callback) { - batch[syncEffectChannelSymbol]!.forEach(callback) - originalQueue.forEach(callback) - }, - get size() { - return batch[syncEffectChannelSymbol]!.size + originalQueue.size - }, - } - } - return batch[syncEffectChannelSymbol]! -} - const getAtomStateMap = new WeakMap() /** diff --git a/tests/vanilla/store.test.tsx b/tests/vanilla/store.test.tsx index 425063617e..6ee45b2680 100644 --- a/tests/vanilla/store.test.tsx +++ b/tests/vanilla/store.test.tsx @@ -1108,3 +1108,21 @@ it('recomputes dependents of unmounted atoms', () => { store.set(w) expect(bRead).not.toHaveBeenCalled() }) + +it('recomputes all changed atom dependents together', async () => { + const a = atom([0]) + const b = atom([0]) + const a0 = atom((get) => get(a)[0]!) + const b0 = atom((get) => get(b)[0]!) + const a0b0 = atom((get) => [get(a0), get(b0)]) + const w = atom(null, (_, set) => { + set(a, [0]) + set(b, [1]) + }) + const store = createStore() + store.sub(a0b0, () => {}) + store.set(w) + expect(store.get(a0)).toBe(0) + expect(store.get(b0)).toBe(1) + expect(store.get(a0b0)).toEqual([0, 1]) +})