From 79fae3b0a15be30d35131420f030c9a31338b357 Mon Sep 17 00:00:00 2001
From: Sam Gross <colesbury@gmail.com>
Date: Mon, 3 Jun 2024 18:47:34 -0400
Subject: [PATCH] [3.13] gh-117657: Fix itertools.count thread safety
 (GH-119268) (#120007)

Fix itertools.count in free-threading mode
(cherry picked from commit 87939bd5790accea77c5a81093f16f28d3f0b429)

Co-authored-by: Arnon Yaari <wiggin15@yahoo.com>
---
 Lib/test/test_itertools.py                 | 24 ++++++++++++-
 Modules/itertoolsmodule.c                  | 40 +++++++++++++++++-----
 Tools/tsan/suppressions_free_threading.txt |  1 -
 3 files changed, 54 insertions(+), 11 deletions(-)

diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index e243da309f03d8..2c92d880c10cb3 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -644,7 +644,7 @@ def test_count(self):
         count(1, maxsize+5); sys.exc_info()
 
     @pickle_deprecated
-    def test_count_with_stride(self):
+    def test_count_with_step(self):
         self.assertEqual(lzip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)])
         self.assertEqual(lzip('abc',count(start=2,step=3)),
                          [('a', 2), ('b', 5), ('c', 8)])
@@ -699,6 +699,28 @@ def test_count_with_stride(self):
                 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                     self.pickletest(proto, count(i, j))
 
+    @threading_helper.requires_working_threading()
+    def test_count_threading(self, step=1):
+        # this test verifies multithreading consistency, which is
+        # mostly for testing builds without GIL, but nice to test anyway
+        count_to = 10_000
+        num_threads = 10
+        c = count(step=step)
+        def counting_thread():
+            for i in range(count_to):
+                next(c)
+        threads = []
+        for i in range(num_threads):
+            thread = threading.Thread(target=counting_thread)
+            thread.start()
+            threads.append(thread)
+        for thread in threads:
+            thread.join()
+        self.assertEqual(next(c), count_to * num_threads * step)
+
+    def test_count_with_step_threading(self):
+        self.test_count_threading(step=5)
+
     def test_cycle(self):
         self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
         self.assertEqual(list(cycle('')), [])
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index 8641c2f87e6db2..0d6ff20489aa2c 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -1,13 +1,14 @@
 #include "Python.h"
-#include "pycore_call.h"          // _PyObject_CallNoArgs()
-#include "pycore_ceval.h"         // _PyEval_GetBuiltin()
-#include "pycore_long.h"          // _PyLong_GetZero()
-#include "pycore_moduleobject.h"  // _PyModule_GetState()
-#include "pycore_typeobject.h"    // _PyType_GetModuleState()
-#include "pycore_object.h"        // _PyObject_GC_TRACK()
-#include "pycore_tuple.h"         // _PyTuple_ITEMS()
+#include "pycore_call.h"              // _PyObject_CallNoArgs()
+#include "pycore_ceval.h"             // _PyEval_GetBuiltin()
+#include "pycore_critical_section.h"  // Py_BEGIN_CRITICAL_SECTION()
+#include "pycore_long.h"              // _PyLong_GetZero()
+#include "pycore_moduleobject.h"      // _PyModule_GetState()
+#include "pycore_typeobject.h"        // _PyType_GetModuleState()
+#include "pycore_object.h"            // _PyObject_GC_TRACK()
+#include "pycore_tuple.h"             // _PyTuple_ITEMS()
 
-#include <stddef.h>               // offsetof()
+#include <stddef.h>                   // offsetof()
 
 /* Itertools module written and maintained
    by Raymond D. Hettinger <python@rcn.com>
@@ -4037,7 +4038,7 @@ fast_mode:  when cnt an integer < PY_SSIZE_T_MAX and no step is specified.
 
     assert(cnt != PY_SSIZE_T_MAX && long_cnt == NULL && long_step==PyLong(1));
     Advances with:  cnt += 1
-    When count hits Y_SSIZE_T_MAX, switch to slow_mode.
+    When count hits PY_SSIZE_T_MAX, switch to slow_mode.
 
 slow_mode:  when cnt == PY_SSIZE_T_MAX, step is not int(1), or cnt is a float.
 
@@ -4186,9 +4187,30 @@ count_nextlong(countobject *lz)
 static PyObject *
 count_next(countobject *lz)
 {
+#ifndef Py_GIL_DISABLED
     if (lz->cnt == PY_SSIZE_T_MAX)
         return count_nextlong(lz);
     return PyLong_FromSsize_t(lz->cnt++);
+#else
+    // free-threading version
+    // fast mode uses compare-exchange loop
+    // slow mode uses a critical section
+    PyObject *returned;
+    Py_ssize_t cnt;
+
+    cnt = _Py_atomic_load_ssize_relaxed(&lz->cnt);
+    for (;;) {
+        if (cnt == PY_SSIZE_T_MAX) {
+            Py_BEGIN_CRITICAL_SECTION(lz);
+            returned = count_nextlong(lz);
+            Py_END_CRITICAL_SECTION();
+            return returned;
+        }
+        if (_Py_atomic_compare_exchange_ssize(&lz->cnt, &cnt, cnt + 1)) {
+            return PyLong_FromSsize_t(cnt);
+        }
+    }
+#endif
 }
 
 static PyObject *
diff --git a/Tools/tsan/suppressions_free_threading.txt b/Tools/tsan/suppressions_free_threading.txt
index 0bb0183147b722..d5fcac61f0db04 100644
--- a/Tools/tsan/suppressions_free_threading.txt
+++ b/Tools/tsan/suppressions_free_threading.txt
@@ -47,7 +47,6 @@ race_top:_PyImport_AcquireLock
 race_top:_Py_dict_lookup_threadsafe
 race_top:_imp_release_lock
 race_top:_multiprocessing_SemLock_acquire_impl
-race_top:count_next
 race_top:dictiter_new
 race_top:dictresize
 race_top:insert_to_emptydict