diff --git a/tests/test_client.py b/tests/test_client.py index ebbfa2e..204c256 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,7 +3,13 @@ import arxiv from datetime import datetime, timedelta from pytest import approx +from requests import Response +def empty_response(code: int) -> Response: + r = Response() + r.status_code = code + r._content = b'' + return r class TestClient(unittest.TestCase): def test_invalid_format_id(self): @@ -90,10 +96,10 @@ def test_no_duplicates(self): self.assertFalse(r.entry_id in ids) ids.add(r.entry_id) + @patch('requests.Session.get', return_value=empty_response(500)) @patch("time.sleep", return_value=None) - def test_retry(self, patched_time_sleep): - broken_client = TestClient.get_code_client(500) - + def test_retry(self, mock_sleep, mock_get): + broken_client = arxiv.Client() def broken_get(): search = arxiv.Search(query="quantum") return next(broken_client.results(search)) @@ -109,77 +115,73 @@ def broken_get(): self.assertEqual(e.status, 500) self.assertEqual(e.retry, broken_client.num_retries) + @patch('requests.Session.get', return_value=empty_response(200)) @patch("time.sleep", return_value=None) - def test_sleep_standard(self, patched_time_sleep): - client = TestClient.get_code_client(200) + def test_sleep_standard(self, mock_sleep, mock_get): + client = arxiv.Client() url = client._format_url(arxiv.Search(query="quantum"), 0, 1) # A client should sleep until delay_seconds have passed. client._parse_feed(url) - patched_time_sleep.assert_not_called() + mock_sleep.assert_not_called() # Overwrite _last_request_dt to minimize flakiness: different # environments will have different page fetch times. client._last_request_dt = datetime.now() client._parse_feed(url) - patched_time_sleep.assert_called_once_with(approx(client.delay_seconds, rel=1e-3)) + mock_sleep.assert_called_once_with(approx(client.delay_seconds, rel=1e-3)) + @patch('requests.Session.get', return_value=empty_response(200)) @patch("time.sleep", return_value=None) - def test_sleep_multiple_requests(self, patched_time_sleep): - client = TestClient.get_code_client(200) + def test_sleep_multiple_requests(self, mock_sleep, mock_get): + client = arxiv.Client() url1 = client._format_url(arxiv.Search(query="quantum"), 0, 1) url2 = client._format_url(arxiv.Search(query="testing"), 0, 1) # Rate limiting is URL-independent; expect same behavior as in # `test_sleep_standard`. client._parse_feed(url1) - patched_time_sleep.assert_not_called() + mock_sleep.assert_not_called() client._last_request_dt = datetime.now() client._parse_feed(url2) - patched_time_sleep.assert_called_once_with(approx(client.delay_seconds, rel=1e-3)) + mock_sleep.assert_called_once_with(approx(client.delay_seconds, rel=1e-3)) + @patch('requests.Session.get', return_value=empty_response(200)) @patch("time.sleep", return_value=None) - def test_sleep_elapsed(self, patched_time_sleep): - client = TestClient.get_code_client(200) + def test_sleep_elapsed(self, mock_sleep, mock_get): + client = arxiv.Client() url = client._format_url(arxiv.Search(query="quantum"), 0, 1) # If _last_request_dt is less than delay_seconds ago, sleep. client._last_request_dt = datetime.now() - timedelta(seconds=client.delay_seconds - 1) client._parse_feed(url) - patched_time_sleep.assert_called_once() - patched_time_sleep.reset_mock() + mock_sleep.assert_called_once() + mock_sleep.reset_mock() # If _last_request_dt is at least delay_seconds ago, don't sleep. client._last_request_dt = datetime.now() - timedelta(seconds=client.delay_seconds) client._parse_feed(url) - patched_time_sleep.assert_not_called() + mock_sleep.assert_not_called() + @patch('requests.Session.get', return_value=empty_response(200)) @patch("time.sleep", return_value=None) - def test_sleep_zero_delay(self, patched_time_sleep): - client = TestClient.get_code_client(code=200, delay_seconds=0) + def test_sleep_zero_delay(self, mock_sleep, mock_get): + client = arxiv.Client(delay_seconds=0) url = client._format_url(arxiv.Search(query="quantum"), 0, 1) client._parse_feed(url) client._parse_feed(url) - patched_time_sleep.assert_not_called() + mock_sleep.assert_not_called() + @patch('requests.Session.get', return_value=empty_response(500)) @patch("time.sleep", return_value=None) - def test_sleep_between_errors(self, patched_time_sleep): - client = TestClient.get_code_client(500) + def test_sleep_between_errors(self, mock_sleep, mock_get): + client = arxiv.Client() url = client._format_url(arxiv.Search(query="quantum"), 0, 1) try: client._parse_feed(url) except arxiv.HTTPError: pass # Should sleep between retries. - patched_time_sleep.assert_called() - self.assertEqual(patched_time_sleep.call_count, client.num_retries) - patched_time_sleep.assert_has_calls( + mock_sleep.assert_called() + self.assertEqual(mock_sleep.call_count, client.num_retries) + mock_sleep.assert_has_calls( [ call(approx(client.delay_seconds, abs=1e-2)), ] * client.num_retries ) - - def get_code_client(code: int, delay_seconds=0.1, num_retries=3) -> arxiv.Client: - """ - get_code_client returns an arxiv.Cient with HTTP requests routed to - httpstat.us. - """ - client = arxiv.Client(delay_seconds=delay_seconds, num_retries=num_retries) - client.query_url_format = "https://teapot.fly.dev/{}?".format(code) + "{}" - return client