From 8992d0986b0361d7676cfc3acd62be23b7fac833 Mon Sep 17 00:00:00 2001
From: Erjia Guan
Date: Tue, 15 Mar 2022 10:53:53 -0700
Subject: [PATCH] Fix Lint for Bucketizer (#296)

Summary:
Fix mypy Error introduced by `max_token_bucketize`

Pull Request resolved: https://github.com/pytorch/data/pull/296

Reviewed By: NivekT

Differential Revision: D34895420

Pulled By: ejguan

fbshipit-source-id: 531210ea9e008320998231df8c3569d8ef7696fc
---
 torchdata/datapipes/iter/transform/bucketbatcher.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchdata/datapipes/iter/transform/bucketbatcher.py b/torchdata/datapipes/iter/transform/bucketbatcher.py
index 9e0b8b596..e2e04b002 100644
--- a/torchdata/datapipes/iter/transform/bucketbatcher.py
+++ b/torchdata/datapipes/iter/transform/bucketbatcher.py
@@ -211,7 +211,7 @@ def __iter__(self) -> Iterator[DataChunk[T_co]]:
             if len(buffer) == self.buffer_size:
                 length, token = heapq.heappop(buffer)
                 if batch_size + length > self.max_token_count:
-                    yield batch
+                    yield DataChunk(batch)
                     batch = []
                     batch_size = 0
                 batch.append(token)
@@ -219,10 +219,10 @@ def __iter__(self) -> Iterator[DataChunk[T_co]]:
         while buffer:
             length, token = heapq.heappop(buffer)
             if batch_size + length > self.max_token_count:
-                yield batch
+                yield DataChunk(batch)
                 batch = []
                 batch_size = 0
             batch.append(token)
             batch_size += length
         if batch:
-            yield batch
+            yield DataChunk(batch)
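
Note: for readers unfamiliar with the error class this patch addresses, below is a minimal, self-contained sketch of the typing issue. The DataChunk class here is a stand-in list subclass rather than the real torchdata one, and the function name, token values, and max_token_count are illustrative. Because the generator is annotated to return Iterator[DataChunk[...]], yielding a plain list makes mypy report an incompatible yield type; wrapping each batch in DataChunk(...) satisfies the annotation, which is exactly what the hunks above do.

from typing import Iterator, List, TypeVar

T = TypeVar("T")


class DataChunk(List[T]):
    """Stand-in for torchdata's DataChunk: a list subclass used as the declared yield type."""


def bucketize_by_tokens(tokens: List[str], max_token_count: int) -> Iterator[DataChunk[str]]:
    """Yield batches whose summed token lengths stay within max_token_count."""
    batch: List[str] = []
    batch_size = 0
    for token in tokens:
        length = len(token)
        if batch and batch_size + length > max_token_count:
            # `yield batch` here would yield a plain List[str], and mypy would report
            # an incompatible yield type (expected "DataChunk[str]"); wrapping fixes it.
            yield DataChunk(batch)
            batch = []
            batch_size = 0
        batch.append(token)
        batch_size += length
    if batch:
        yield DataChunk(batch)


if __name__ == "__main__":
    # Each yielded chunk is a DataChunk, matching the declared return type.
    for chunk in bucketize_by_tokens(["a", "bb", "ccc", "dddd"], max_token_count=5):
        print(type(chunk).__name__, list(chunk))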