diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index b52b8ea7e1..fbfeb9aec2 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -959,20 +959,17 @@ def scan( Returns: A DataScan based on the table's current metadata. """ - scan = DataScan( + return DataScan( table_metadata=self.metadata, io=self.io, row_filter=row_filter, selected_fields=selected_fields, case_sensitive=case_sensitive, snapshot_id=snapshot_id, + ref_name=ref_name, options=options, limit=limit, ) - if ref_name is not None: - scan = scan.use_ref(ref_name) - - return scan @property def format_version(self) -> TableVersion: @@ -1438,6 +1435,7 @@ class TableScan(ABC): selected_fields: Tuple[str, ...] case_sensitive: bool snapshot_id: Optional[int] + ref_name: Optional[str] options: Properties limit: Optional[int] @@ -1449,6 +1447,7 @@ def __init__( selected_fields: Tuple[str, ...] = ("*",), case_sensitive: bool = True, snapshot_id: Optional[int] = None, + ref_name: Optional[str] = None, options: Properties = EMPTY_DICT, limit: Optional[int] = None, ): @@ -1458,12 +1457,20 @@ def __init__( self.selected_fields = selected_fields self.case_sensitive = case_sensitive self.snapshot_id = snapshot_id + self.ref_name = ref_name self.options = options self.limit = limit def snapshot(self) -> Optional[Snapshot]: + if self.snapshot_id and self.ref_name is not None: + raise ValueError("Cannot specify both snapshot_id and ref_name.") if self.snapshot_id: return self.table_metadata.snapshot_by_id(self.snapshot_id) + if self.ref_name is not None: + if snapshot := self.table_metadata.snapshot_by_name(self.ref_name): + return snapshot + else: + raise ValueError(f"Cannot scan unknown ref={self.ref_name}") return self.table_metadata.current_snapshot() def projection(self) -> Schema: @@ -1503,12 +1510,7 @@ def update(self: S, **overrides: Any) -> S: return type(self)(**{**self.__dict__, **overrides}) def use_ref(self: S, name: str) -> S: - if self.snapshot_id: - raise ValueError(f"Cannot override ref, already set snapshot id={self.snapshot_id}") - if snapshot := self.table_metadata.snapshot_by_name(name): - return self.update(snapshot_id=snapshot.snapshot_id) - - raise ValueError(f"Cannot scan unknown ref={name}") + return self.update(ref_name=name) def select(self: S, *field_names: str) -> S: if "*" in self.selected_fields: diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 69bbab527e..a938ec12bf 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -318,14 +318,23 @@ def test_table_scan_row_filter(table_v2: Table) -> None: def test_table_scan_ref(table_v2: Table) -> None: scan = table_v2.scan() - assert scan.use_ref("test").snapshot_id == 3051729675574597004 + assert scan.use_ref("test").ref_name == "test" + + +def test_table_scan_ref_and_snapshot_id(table_v2: Table) -> None: + scan = table_v2.scan(snapshot_id=123) + + with pytest.raises(ValueError) as exc_info: + _ = scan.use_ref("test").snapshot() + + assert "Cannot specify both snapshot_id and ref_name" in str(exc_info.value) def test_table_scan_ref_does_not_exists(table_v2: Table) -> None: scan = table_v2.scan() with pytest.raises(ValueError) as exc_info: - _ = scan.use_ref("boom") + _ = scan.use_ref("boom").snapshot() assert "Cannot scan unknown ref=boom" in str(exc_info.value)