diff --git a/packages/enzyme-test-suite/test/ShallowWrapper-spec.jsx b/packages/enzyme-test-suite/test/ShallowWrapper-spec.jsx
index 22e14d897..fcad363fe 100644
--- a/packages/enzyme-test-suite/test/ShallowWrapper-spec.jsx
+++ b/packages/enzyme-test-suite/test/ShallowWrapper-spec.jsx
@@ -4666,4 +4666,75 @@ describe('shallow', () => {
});
});
});
+ describe('setState through a props method', () => {
+ it('should be able to get the latest state value', () => {
+ const Child = props => ;
+ class App extends React.Component {
+ constructor(props) {
+ super(props);
+ this.state = {
+ count: 0,
+ };
+ }
+ onIncrement() {
+ this.setState({
+ count: this.state.count + 1,
+ });
+ }
+ render() {
+ return (
+
+
this.onIncrement()} />
+ {this.state.count}
+
+ );
+ }
+ }
+ const wrapper = shallow();
+ const p = wrapper.find('p');
+ expect(wrapper.find('p').text()).to.equal('0');
+ wrapper.find(Child).prop('onClick')();
+ // this is still 0 because the wrapper won't be updated
+ expect(p.text()).to.equal('0');
+ expect(wrapper.find('p').text()).to.equal('1');
+ });
+ });
+ describe('setState through a props method in async', () => {
+ it('should be able to get the latest state value', (done) => {
+ const Child = props => ;
+ let App;
+ const promise = new Promise((resolve) => {
+ App = class extends React.Component {
+ constructor(props) {
+ super(props);
+ this.state = {
+ count: 0,
+ };
+ }
+ onIncrement() {
+ setTimeout(() => {
+ this.setState({
+ count: this.state.count + 1,
+ }, resolve);
+ });
+ }
+ render() {
+ return (
+
+
this.onIncrement()} />
+ {this.state.count}
+
+ );
+ }
+ };
+ });
+ const wrapper = shallow();
+ promise.then(() => {
+ expect(wrapper.find('p').text()).to.equal('1');
+ done();
+ });
+ expect(wrapper.find('p').text()).to.equal('0');
+ wrapper.find(Child).prop('onClick')();
+ });
+ });
});
diff --git a/packages/enzyme/src/ShallowWrapper.js b/packages/enzyme/src/ShallowWrapper.js
index 2144bc630..f6f049433 100644
--- a/packages/enzyme/src/ShallowWrapper.js
+++ b/packages/enzyme/src/ShallowWrapper.js
@@ -150,11 +150,17 @@ class ShallowWrapper {
return this[ROOT];
}
+ getRootNodeInternal() {
+ return this[ROOT][NODE];
+ }
getNodeInternal() {
if (this.length !== 1) {
throw new Error('ShallowWrapper::getNode() can only be called when wrapping one node');
}
+ if (this[ROOT] === this) {
+ this.update();
+ }
return this[NODE];
}
@@ -167,7 +173,7 @@ class ShallowWrapper {
if (this.length !== 1) {
throw new Error('ShallowWrapper::getElement() can only be called when wrapping one node');
}
- return getAdapter(this[OPTIONS]).nodeToElement(this[NODE]);
+ return getAdapter(this[OPTIONS]).nodeToElement(this.getNodeInternal());
}
/**
@@ -176,7 +182,7 @@ class ShallowWrapper {
* @return {Array}
*/
getElements() {
- return this[NODES].map(getAdapter(this[OPTIONS]).nodeToElement);
+ return this.getNodesInternal().map(getAdapter(this[OPTIONS]).nodeToElement);
}
// eslint-disable-next-line class-methods-use-this
@@ -185,6 +191,9 @@ class ShallowWrapper {
}
getNodesInternal() {
+ if (this[ROOT] === this && this.length === 1) {
+ this.update();
+ }
return this[NODES];
}
@@ -225,9 +234,10 @@ class ShallowWrapper {
if (this[ROOT] !== this) {
throw new Error('ShallowWrapper::update() can only be called on the root');
}
- this.single('update', () => {
- privateSetNodes(this, getRootNode(this[RENDERER].getNode()));
- });
+ if (this.length !== 1) {
+ throw new Error('ShallowWrapper::update() can only be called when wrapping one node');
+ }
+ privateSetNodes(this, getRootNode(this[RENDERER].getNode()));
return this;
}
@@ -793,7 +803,7 @@ class ShallowWrapper {
* @returns {ShallowWrapper}
*/
parents(selector) {
- const allParents = this.wrap(this.single('parents', n => parentsOfNode(n, this[ROOT][NODE])));
+ const allParents = this.wrap(this.single('parents', n => parentsOfNode(n, this.getRootNodeInternal())));
return selector ? allParents.filter(selector) : allParents;
}
@@ -1188,7 +1198,7 @@ if (ITERATOR_SYMBOL) {
Object.defineProperty(ShallowWrapper.prototype, ITERATOR_SYMBOL, {
configurable: true,
value: function iterator() {
- const iter = this[NODES][ITERATOR_SYMBOL]();
+ const iter = this.getNodesInternal()[ITERATOR_SYMBOL]();
const adapter = getAdapter(this[OPTIONS]);
return {
next() {