Skip to content

Commit

Permalink
much better tests + len support
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning committed Jan 14, 2025
1 parent 76f89e3 commit a5a7e4a
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 9 deletions.
70 changes: 70 additions & 0 deletions tests/trace/test_weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,76 @@ def test_get_calls_complete(client):
assert call1.inputs["s"] == call2.inputs["s"]


def test_get_calls_len(client):
for i in range(10):
client.create_call("x", {"a": i})

# test len first
calls = client.get_calls()
assert len(calls) == 10

calls = client.get_calls(limit=5)
assert len(calls) == 5

calls = client.get_calls(limit=5, offset=5)
assert len(calls) == 5

calls = client.get_calls(offset=10)
assert len(calls) == 0

calls = client.get_calls(offset=10, limit=10)
assert len(calls) == 0

with pytest.raises(ValueError):
client.get_calls(limit=-1)

with pytest.raises(ValueError):
client.get_calls(limit=0)

with pytest.raises(ValueError):
client.get_calls(offset=-1)


def test_get_calls_limit_offset(client):
for i in range(10):
client.create_call("x", {"a": i})

calls = client.get_calls(limit=3)
assert len(calls) == 3
for i, call in enumerate(calls):
assert call.inputs["a"] == i

calls = client.get_calls(limit=5, offset=5)
assert len(calls) == 5

for i, call in enumerate(calls):
assert call.inputs["a"] == i + 5

calls = client.get_calls(offset=9)
assert len(calls) == 1
assert calls[0].inputs["a"] == 9

# now test indexing
calls = client.get_calls()
assert calls[0].inputs["a"] == 0
assert calls[1].inputs["a"] == 1
assert calls[2].inputs["a"] == 2
assert calls[3].inputs["a"] == 3
assert calls[4].inputs["a"] == 4

calls = client.get_calls(offset=5)
assert calls[0].inputs["a"] == 5
assert calls[1].inputs["a"] == 6
assert calls[2].inputs["a"] == 7
assert calls[3].inputs["a"] == 8
assert calls[4].inputs["a"] == 9

# slicing
calls = client.get_calls(offset=5)
for i, call in enumerate(calls[2:]):
assert call.inputs["a"] == 7 + i


def test_calls_delete(client):
call0 = client.create_call("x", {"a": 5, "b": 10})
call0_child1 = client.create_call("x", {"a": 5, "b": 11}, call0)
Expand Down
23 changes: 14 additions & 9 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(

if page_size <= 0:
raise ValueError("page_size must be greater than 0")
if limit is not None and limit < 0:
if limit is not None and limit <= 0:
raise ValueError("limit must be greater than 0")
if offset is not None and offset < 0:
raise ValueError("offset must be greater than or equal to 0")
Expand All @@ -158,12 +158,12 @@ def _get_one(self, index: int) -> T | R:
if index < 0:
raise IndexError("Negative indexing not supported")

if self.limit is not None and index >= self.limit + (self.offset or 0):
raise IndexError(f"Index {index} out of range")

if self.offset is not None:
index += self.offset

if self.limit is not None and index >= self.limit:
raise IndexError(f"Index {index} out of range")

page_index = index // self.page_size
page_offset = index % self.page_size

Expand All @@ -188,17 +188,17 @@ def _get_slice(self, key: slice) -> Iterator[T] | Iterator[R]:
if (step := key.step or 1) < 0:
raise ValueError("Negative step not supported")

# Apply limit if provided
if self.limit is not None:
if stop is None or stop > self.limit:
stop = self.limit

# Apply offset if provided
if self.offset is not None:
start += self.offset
if stop is not None:
stop += self.offset

# Apply limit if provided
if self.limit is not None:
if stop is None or stop > self.limit:
stop = self.limit

i = start
while stop is None or i < stop:
try:
Expand Down Expand Up @@ -278,6 +278,11 @@ def size_func() -> int:
response = server.calls_query_stats(
CallsQueryStatsReq(project_id=project_id, filter=filter)
)
if limit_override is not None:
offset = offset_override or 0
return min(limit_override, response.count - offset)
if offset_override is not None:
return response.count - offset_override
return response.count

return PaginatedIterator(
Expand Down

0 comments on commit a5a7e4a

Please sign in to comment.