diff --git a/Python/ceval.c b/Python/ceval.c
index a63b80395dcfcf..75984bbf1ba95d 100644
--- a/Python/ceval.c
+++ b/Python/ceval.c
@@ -5619,6 +5619,30 @@ initialize_locals(PyThreadState *tstate, PyFrameConstructor *con,
     return 0;
 
 fail: /* Jump here from prelude on failure */
+    if (steal_args) {
+        // If we failed to initialize locals, make sure the caller still own all the
+        // arguments that were on the stack. We need to increment the reference count
+        // of everything we copied (everything in localsplus) that came from the stack
+        // (everything that is present in the "args" array).
+        Py_ssize_t kwcount = kwnames != NULL ? PyTuple_GET_SIZE(kwnames) : 0;
+        for (Py_ssize_t k=0; k < total_args; k++) {
+            PyObject* arg = localsplus[k];
+            for (Py_ssize_t j=0; j < argcount + kwcount; j++) {
+                if (args[j] == arg) {
+                    Py_XINCREF(arg);
+                    break;
+                }
+            }
+        }
+        // Restore all the **kwargs we placed into the kwargs dictionary
+        if (kwdict) {
+            PyObject *key, *value;
+            Py_ssize_t pos = 0;
+            while (PyDict_Next(kwdict, &pos, &key, &value)) {
+                Py_INCREF(value);
+            }
+        }
+    }
     return -1;
 
 }
@@ -5683,16 +5707,6 @@ _PyEvalFramePushAndInit(PyThreadState *tstate, PyFrameConstructor *con,
     }
     PyObject **localsarray = _PyFrame_GetLocalsArray(frame);
     if (initialize_locals(tstate, con, localsarray, args, argcount, kwnames, steal_args)) {
-        if (steal_args) {
-            // If we failed to initialize locals, make sure the caller still own all the
-            // arguments. Notice that we only need to increase the reference count of the
-            // *valid* arguments (i.e. the ones that fit into the frame).
-            PyCodeObject *co = (PyCodeObject*)con->fc_code;
-            const size_t total_args = co->co_argcount + co->co_kwonlyargcount;
-            for (size_t i = 0; i < Py_MIN(argcount, total_args); i++) {
-                Py_XINCREF(frame->localsplus[i]);
-            }
-        }
         _PyFrame_Clear(frame, 0);
         return NULL;
     }