diff --git a/torchdata/datapipes/iter/transform/bucketbatcher.py b/torchdata/datapipes/iter/transform/bucketbatcher.py index 9e0b8b596..e2e04b002 100644 --- a/torchdata/datapipes/iter/transform/bucketbatcher.py +++ b/torchdata/datapipes/iter/transform/bucketbatcher.py @@ -211,7 +211,7 @@ def __iter__(self) -> Iterator[DataChunk[T_co]]: if len(buffer) == self.buffer_size: length, token = heapq.heappop(buffer) if batch_size + length > self.max_token_count: - yield batch + yield DataChunk(batch) batch = [] batch_size = 0 batch.append(token) @@ -219,10 +219,10 @@ def __iter__(self) -> Iterator[DataChunk[T_co]]: while buffer: length, token = heapq.heappop(buffer) if batch_size + length > self.max_token_count: - yield batch + yield DataChunk(batch) batch = [] batch_size = 0 batch.append(token) batch_size += length if batch: - yield batch + yield DataChunk(batch)