Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow inserting subschemas #3041

Merged
merged 19 commits into from
Nov 7, 2024
Merged
18 changes: 0 additions & 18 deletions python/python/tests/test_balanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,21 +271,3 @@ def test_unsupported(balanced_dataset, big_val):
balanced_dataset.merge_insert("idx").when_not_matched_insert_all().execute(
make_table(0, 1, big_val)
)


# TODO: Once https://github.com/lancedb/lance/pull/3041 merges we will
# want to test partial appends. We need to make sure an append of
# non-blob data is supported. In order to do this we need to make
# sure a blob tx is created that marks the row ids as used so that
# the two row id sequences stay in sync.
#
# def test_one_sided_append(balanced_dataset, tmp_path):
# # Write new data, but only to the idx column
# ds = lance.write_dataset(
# pa.table({"idx": pa.array(range(128, 256), pa.uint64())}),
# tmp_path / "test_ds",
# max_bytes_per_file=32 * 1024 * 1024,
# mode="append",
# )

# print(ds.to_table())
7 changes: 7 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ def test_dataset_append(tmp_path: Path):
with pytest.raises(OSError):
lance.write_dataset(table2, base_dir, mode="append")

# But we can append subschemas
table3 = pa.Table.from_pydict({"colA": [4, 5, 6]})
dataset = lance.write_dataset(table3, base_dir, mode="append")
assert dataset.to_table() == pa.table(
{"colA": [1, 2, 3, 4, 5, 6], "colB": [4, 5, 6, None, None, None]}
)


def test_dataset_from_record_batch_iterable(tmp_path: Path):
base_dir = tmp_path / "test"
Expand Down
2 changes: 1 addition & 1 deletion python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ impl Dataset {
let dict = PyDict::new(py);
let schema = self_.ds.schema();

let idx_schema = schema.project_by_ids(idx.fields.as_slice());
let idx_schema = schema.project_by_ids(idx.fields.as_slice(), true);

let is_vector = idx_schema
.fields
Expand Down
2 changes: 1 addition & 1 deletion python/src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl PrettyPrintableFragment {
.files
.iter()
.map(|file| {
let schema = schema.project_by_ids(&file.fields);
let schema = schema.project_by_ids(&file.fields, false);
PrettyPrintableDataFile {
path: file.path.clone(),
fields: file.fields.clone(),
Expand Down
5 changes: 4 additions & 1 deletion rust/lance-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ mod field;
mod schema;

use crate::{Error, Result};
pub use field::{Encoding, Field, NullabilityComparison, SchemaCompareOptions, StorageClass};
pub use field::{
Encoding, Field, NullabilityComparison, SchemaCompareOptions, StorageClass,
LANCE_STORAGE_CLASS_SCHEMA_META_KEY,
};
pub use schema::Schema;

pub const COMPRESSION_META_KEY: &str = "lance-encoding:compression";
Expand Down
127 changes: 45 additions & 82 deletions rust/lance-core/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use std::{
cmp::max,
collections::{HashMap, HashSet},
collections::HashMap,
fmt::{self, Display},
str::FromStr,
sync::Arc,
Expand All @@ -23,7 +23,10 @@ use deepsize::DeepSizeOf;
use lance_arrow::{bfloat16::ARROW_EXT_NAME_KEY, *};
use snafu::{location, Location};

use super::{Dictionary, LogicalType};
use super::{
schema::{compare_fields, explain_fields_difference},
Dictionary, LogicalType,
};
use crate::{Error, Result};

pub const LANCE_STORAGE_CLASS_SCHEMA_META_KEY: &str = "lance-schema:storage-class";
Expand All @@ -49,6 +52,14 @@ pub struct SchemaCompareOptions {
pub compare_field_ids: bool,
/// Should nullability be compared (default Strict)
pub compare_nullability: NullabilityComparison,
/// Allow fields in the expected schema to be missing from the schema being tested if
/// they are nullable (default false)
///
/// Fields in the schema being tested must always be present in the expected schema
/// regardless of this flag.
pub allow_missing_if_nullable: bool,
/// Allow out of order fields (default false)
pub ignore_field_order: bool,
}
/// Encoding enum.
#[derive(Debug, Clone, PartialEq, Eq, DeepSizeOf)]
Expand Down Expand Up @@ -151,7 +162,7 @@ impl Field {
self.storage_class
}

fn explain_differences(
pub(crate) fn explain_differences(
&self,
expected: &Self,
options: &SchemaCompareOptions,
Expand Down Expand Up @@ -210,61 +221,19 @@ impl Field {
self_name
));
}
if self.children.len() != expected.children.len()
|| !self
.children
.iter()
.zip(expected.children.iter())
.all(|(child, expected)| child.name == expected.name)
{
let self_children = self
.children
.iter()
.map(|child| child.name.clone())
.collect::<HashSet<_>>();
let expected_children = expected
.children
.iter()
.map(|child| child.name.clone())
.collect::<HashSet<_>>();
let missing = expected_children
.difference(&self_children)
.cloned()
.collect::<Vec<_>>();
let unexpected = self_children
.difference(&expected_children)
.cloned()
.collect::<Vec<_>>();
if missing.is_empty() && unexpected.is_empty() {
differences.push(format!(
"`{}` field order mismatch, expected [{}] but was [{}]",
self_name,
expected
.children
.iter()
.map(|child| child.name.clone())
.collect::<Vec<_>>()
.join(", "),
self.children
.iter()
.map(|child| child.name.clone())
.collect::<Vec<_>>()
.join(", "),
));
} else {
differences.push(format!(
"`{}` had mismatched children, missing=[{}] unexpected=[{}]",
self_name,
missing.join(", "),
unexpected.join(", ")
));
}
} else {
differences.extend(self.children.iter().zip(expected.children.iter()).flat_map(
|(child, expected_child)| {
child.explain_differences(expected_child, options, Some(&self_name))
},
));
let children_differences = explain_fields_difference(
&self.children,
&expected.children,
options,
Some(&self_name),
);
if !children_differences.is_empty() {
let children_differences = format!(
"`{}` had mismatched children: {}",
self_name,
children_differences.join(", ")
);
differences.push(children_differences);
}
differences
}
Expand Down Expand Up @@ -295,22 +264,13 @@ impl Field {
}

pub fn compare_with_options(&self, expected: &Self, options: &SchemaCompareOptions) -> bool {
if self.children.len() != expected.children.len() {
false
} else {
self.name == expected.name
&& self.logical_type == expected.logical_type
&& Self::compare_nullability(expected.nullable, self.nullable, options)
&& self.children.len() == expected.children.len()
&& self
.children
.iter()
.zip(&expected.children)
.all(|(left, right)| left.compare_with_options(right, options))
&& (!options.compare_field_ids || self.id == expected.id)
&& (!options.compare_dictionary || self.dictionary == expected.dictionary)
&& (!options.compare_metadata || self.metadata == expected.metadata)
}
self.name == expected.name
&& self.logical_type == expected.logical_type
&& Self::compare_nullability(expected.nullable, self.nullable, options)
&& compare_fields(&self.children, &expected.children, options)
&& (!options.compare_field_ids || self.id == expected.id)
&& (!options.compare_dictionary || self.dictionary == expected.dictionary)
&& (!options.compare_metadata || self.metadata == expected.metadata)
}

pub fn extension_name(&self) -> Option<&str> {
Expand Down Expand Up @@ -476,13 +436,13 @@ impl Field {
///
/// If the ids are `[2]`, then this will include the parent `0` and the
/// child `3`.
pub(crate) fn project_by_ids(&self, ids: &[i32]) -> Option<Self> {
wjones127 marked this conversation as resolved.
Show resolved Hide resolved
pub(crate) fn project_by_ids(&self, ids: &[i32], include_all_children: bool) -> Option<Self> {
let children = self
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super minor nit: I'm guessing the optimizer catches this but it might be faster to only calculate children if we need it...

pub(crate) fn project_by_ids(&self, ids: &[i32], include_all_children: bool) -> Option<Self> {
  if !ids.contains(&self.id) {
      return None;
  }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually don't want to early return, because even if a field isn't selected, we want to check if it has children that are.

.children
.iter()
.filter_map(|c| c.project_by_ids(ids))
.filter_map(|c| c.project_by_ids(ids, include_all_children))
.collect::<Vec<_>>();
if ids.contains(&self.id) {
if ids.contains(&self.id) && (children.is_empty() || include_all_children) {
Some(self.clone())
} else if !children.is_empty() {
Some(Self {
Expand Down Expand Up @@ -1177,7 +1137,10 @@ mod tests {
.unwrap();
assert_eq!(
wrong_child.explain_difference(&expected, &opts),
Some("`a.b` should have nullable=true but nullable=false".to_string())
Some(
"`a` had mismatched children: `a.b` should have nullable=true but nullable=false"
.to_string()
)
);

let mismatched_children: Field = ArrowField::new(
Expand All @@ -1192,13 +1155,13 @@ mod tests {
.unwrap();
assert_eq!(
mismatched_children.explain_difference(&expected, &opts),
Some("`a` had mismatched children, missing=[c] unexpected=[d]".to_string())
Some("`a` had mismatched children: fields did not match, missing=[a.c], unexpected=[a.d]".to_string())
);

let reordered_children: Field = ArrowField::new(
"a",
DataType::Struct(Fields::from(vec![
ArrowField::new("c", DataType::Int32, false),
ArrowField::new("c", DataType::Int32, true),
ArrowField::new("b", DataType::Int32, true),
])),
true,
Expand All @@ -1207,7 +1170,7 @@ mod tests {
.unwrap();
assert_eq!(
reordered_children.explain_difference(&expected, &opts),
Some("`a` field order mismatch, expected [b, c] but was [c, b]".to_string())
Some("`a` had mismatched children: fields in different order, expected: [b, c], actual: [c, b]".to_string())
);

let multiple_wrongs: Field = ArrowField::new(
Expand All @@ -1223,7 +1186,7 @@ mod tests {
assert_eq!(
multiple_wrongs.explain_difference(&expected, &opts),
Some(
"expected name 'a' but name was 'c', `c.c` should have type int32 but type was float"
"expected name 'a' but name was 'c', `c` had mismatched children: `c.c` should have type int32 but type was float"
.to_string()
)
);
Expand Down
Loading
Loading