Skip to content

Commit

Permalink
Merge pull request #2328 from finos/fix-expression-update-overcalc
Browse files Browse the repository at this point in the history
Fix `update()` with `expressions` overcalc
  • Loading branch information
texodus authored Aug 7, 2023
2 parents f4a9774 + eca996a commit c75d8b0
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 62 deletions.
8 changes: 0 additions & 8 deletions cpp/perspective/src/cpp/context_grouped_pkey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,6 @@ t_ctx_grouped_pkey::get_column_dtype(t_uindex idx) const {

void
t_ctx_grouped_pkey::compute_expressions(std::shared_ptr<t_data_table> master,
std::shared_ptr<t_data_table> flattened,
t_expression_vocab& expression_vocab, t_regex_mapping& regex_mapping) {
// Clear the transitional expression tables on the context so they are
// ready for the next update.
Expand All @@ -693,10 +692,6 @@ t_ctx_grouped_pkey::compute_expressions(std::shared_ptr<t_data_table> master,
std::shared_ptr<t_data_table> master_expression_table
= m_expression_tables->m_master;

t_uindex flattened_num_rows = flattened->size();
m_expression_tables->reserve_transitional_table_size(flattened_num_rows);
m_expression_tables->set_transitional_table_size(flattened_num_rows);

// Set the master table to the right size.
t_uindex num_rows = master->size();
master_expression_table->reserve(num_rows);
Expand All @@ -707,9 +702,6 @@ t_ctx_grouped_pkey::compute_expressions(std::shared_ptr<t_data_table> master,
// Compute the expressions on the master table.
expr->compute(
master, master_expression_table, expression_vocab, regex_mapping);

expr->compute(flattened, m_expression_tables->m_flattened,
expression_vocab, regex_mapping);
}
}

Expand Down
8 changes: 0 additions & 8 deletions cpp/perspective/src/cpp/context_one.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,6 @@ t_ctx1::get_trav_depth(t_index idx) const {

void
t_ctx1::compute_expressions(std::shared_ptr<t_data_table> master,
std::shared_ptr<t_data_table> flattened,
t_expression_vocab& expression_vocab, t_regex_mapping& regex_mapping) {
// Clear the transitional expression tables on the context so they are
// ready for the next update.
Expand All @@ -621,10 +620,6 @@ t_ctx1::compute_expressions(std::shared_ptr<t_data_table> master,
std::shared_ptr<t_data_table> master_expression_table
= m_expression_tables->m_master;

t_uindex flattened_num_rows = flattened->size();
m_expression_tables->reserve_transitional_table_size(flattened_num_rows);
m_expression_tables->set_transitional_table_size(flattened_num_rows);

// Set the master table to the right size.
t_uindex num_rows = master->size();
master_expression_table->reserve(num_rows);
Expand All @@ -635,9 +630,6 @@ t_ctx1::compute_expressions(std::shared_ptr<t_data_table> master,
// Compute the expressions on the master table.
expr->compute(
master, master_expression_table, expression_vocab, regex_mapping);

expr->compute(flattened, m_expression_tables->m_flattened,
expression_vocab, regex_mapping);
}
}

Expand Down
8 changes: 0 additions & 8 deletions cpp/perspective/src/cpp/context_two.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,6 @@ t_ctx2::get_column_dtype(t_uindex idx) const {

void
t_ctx2::compute_expressions(std::shared_ptr<t_data_table> master,
std::shared_ptr<t_data_table> flattened,
t_expression_vocab& expression_vocab, t_regex_mapping& regex_mapping) {
// Clear the transitional expression tables on the context so they are
// ready for the next update.
Expand All @@ -1063,10 +1062,6 @@ t_ctx2::compute_expressions(std::shared_ptr<t_data_table> master,
std::shared_ptr<t_data_table> master_expression_table
= m_expression_tables->m_master;

t_uindex flattened_num_rows = flattened->size();
m_expression_tables->reserve_transitional_table_size(flattened_num_rows);
m_expression_tables->set_transitional_table_size(flattened_num_rows);

// Set the master table to the right size.
t_uindex num_rows = master->size();
master_expression_table->reserve(num_rows);
Expand All @@ -1077,9 +1072,6 @@ t_ctx2::compute_expressions(std::shared_ptr<t_data_table> master,
// Compute the expressions on the master table.
expr->compute(
master, master_expression_table, expression_vocab, regex_mapping);

expr->compute(flattened, m_expression_tables->m_flattened,
expression_vocab, regex_mapping);
}
}

Expand Down
8 changes: 0 additions & 8 deletions cpp/perspective/src/cpp/context_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,6 @@ t_ctx0::get_step_delta(t_index bidx, t_index eidx) {

void
t_ctx0::compute_expressions(std::shared_ptr<t_data_table> master,
std::shared_ptr<t_data_table> flattened,
t_expression_vocab& expression_vocab, t_regex_mapping& regex_mapping) {
// Clear the transitional expression tables on the context so they are
// ready for the next update.
Expand All @@ -629,10 +628,6 @@ t_ctx0::compute_expressions(std::shared_ptr<t_data_table> master,
std::shared_ptr<t_data_table> master_expression_table
= m_expression_tables->m_master;

t_uindex flattened_num_rows = flattened->size();
m_expression_tables->reserve_transitional_table_size(flattened_num_rows);
m_expression_tables->set_transitional_table_size(flattened_num_rows);

// Set the master table to the right size.
t_uindex num_rows = master->size();
master_expression_table->reserve(num_rows);
Expand All @@ -643,9 +638,6 @@ t_ctx0::compute_expressions(std::shared_ptr<t_data_table> master,
// Compute the expressions on the master table.
expr->compute(
master, master_expression_table, expression_vocab, regex_mapping);

expr->compute(flattened, m_expression_tables->m_flattened,
expression_vocab, regex_mapping);
}
}

Expand Down
13 changes: 13 additions & 0 deletions cpp/perspective/src/cpp/expression_tables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ t_expression_tables::get_table() const {
return m_master.get();
}

void
t_expression_tables::set_flattened(std::shared_ptr<t_data_table> flattened) {
t_uindex flattened_num_rows = flattened->size();
reserve_transitional_table_size(flattened_num_rows);
set_transitional_table_size(flattened_num_rows);
const t_schema& schema = m_flattened->get_schema();
const std::vector<std::string>& column_names = schema.m_columns;
for (const auto& colname : column_names) {
m_flattened->set_column(
colname, flattened->get_column(colname)->clone());
}
}

void
t_expression_tables::calculate_transitions(
std::shared_ptr<t_data_table> existed) {
Expand Down
55 changes: 43 additions & 12 deletions cpp/perspective/src/cpp/gnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -871,8 +871,13 @@ t_gnode::_register_context(
ctx->reset();

if (should_update) {
ctx->compute_expressions(m_gstate->get_table(), pkeyed_table,
ctx->compute_expressions(m_gstate->get_table(),
expression_vocab, expression_regex_mapping);
ctx->get_expression_tables()->set_flattened(
m_gstate->get_pkeyed_table(
ctx->get_expression_tables()->m_master->get_schema(),
ctx->get_expression_tables()->m_master));

update_context_from_state<t_ctx2>(ctx, name, pkeyed_table);
}
} break;
Expand All @@ -881,8 +886,12 @@ t_gnode::_register_context(
t_ctx1* ctx = static_cast<t_ctx1*>(ptr_);
ctx->reset();
if (should_update) {
ctx->compute_expressions(m_gstate->get_table(), pkeyed_table,
ctx->compute_expressions(m_gstate->get_table(),
expression_vocab, expression_regex_mapping);
ctx->get_expression_tables()->set_flattened(
m_gstate->get_pkeyed_table(
ctx->get_expression_tables()->m_master->get_schema(),
ctx->get_expression_tables()->m_master));

update_context_from_state<t_ctx1>(ctx, name, pkeyed_table);
}
Expand All @@ -892,8 +901,13 @@ t_gnode::_register_context(
t_ctx0* ctx = static_cast<t_ctx0*>(ptr_);
ctx->reset();
if (should_update) {
ctx->compute_expressions(m_gstate->get_table(), pkeyed_table,
ctx->compute_expressions(m_gstate->get_table(),
expression_vocab, expression_regex_mapping);
ctx->get_expression_tables()->set_flattened(
m_gstate->get_pkeyed_table(
ctx->get_expression_tables()->m_master->get_schema(),
ctx->get_expression_tables()->m_master));

update_context_from_state<t_ctx0>(ctx, name, pkeyed_table);
}
} break;
Expand All @@ -913,8 +927,13 @@ t_gnode::_register_context(
ctx->reset();

if (should_update) {
ctx->compute_expressions(m_gstate->get_table(), pkeyed_table,
ctx->compute_expressions(m_gstate->get_table(),
expression_vocab, expression_regex_mapping);
ctx->get_expression_tables()->set_flattened(
m_gstate->get_pkeyed_table(
ctx->get_expression_tables()->m_master->get_schema(),
ctx->get_expression_tables()->m_master));

update_context_from_state<t_ctx_grouped_pkey>(
ctx, name, pkeyed_table);
}
Expand Down Expand Up @@ -999,27 +1018,39 @@ t_gnode::_compute_expressions(std::shared_ptr<t_data_table> flattened_masked) {
case TWO_SIDED_CONTEXT: {
t_ctx2* ctx = static_cast<t_ctx2*>(ctxh.m_ctx);
ctx->compute_expressions(m_gstate->get_table(),
flattened_masked, expression_vocab,
expression_regex_mapping);
expression_vocab, expression_regex_mapping);
ctx->get_expression_tables()->set_flattened(
m_gstate->get_pkeyed_table(
ctx->get_expression_tables()->m_master->get_schema(),
ctx->get_expression_tables()->m_master));
} break;
case ONE_SIDED_CONTEXT: {
t_ctx1* ctx = static_cast<t_ctx1*>(ctxh.m_ctx);
ctx->compute_expressions(m_gstate->get_table(),
flattened_masked, expression_vocab,
expression_regex_mapping);
expression_vocab, expression_regex_mapping);
ctx->get_expression_tables()->set_flattened(
m_gstate->get_pkeyed_table(
ctx->get_expression_tables()->m_master->get_schema(),
ctx->get_expression_tables()->m_master));
} break;
case ZERO_SIDED_CONTEXT: {
t_ctx0* ctx = static_cast<t_ctx0*>(ctxh.m_ctx);
ctx->compute_expressions(m_gstate->get_table(),
flattened_masked, expression_vocab,
expression_regex_mapping);
expression_vocab, expression_regex_mapping);
ctx->get_expression_tables()->set_flattened(
m_gstate->get_pkeyed_table(
ctx->get_expression_tables()->m_master->get_schema(),
ctx->get_expression_tables()->m_master));
} break;
case GROUPED_PKEY_CONTEXT: {
t_ctx_grouped_pkey* ctx
= static_cast<t_ctx_grouped_pkey*>(ctxh.m_ctx);
ctx->compute_expressions(m_gstate->get_table(),
flattened_masked, expression_vocab,
expression_regex_mapping);
expression_vocab, expression_regex_mapping);
ctx->get_expression_tables()->set_flattened(
m_gstate->get_pkeyed_table(
ctx->get_expression_tables()->m_master->get_schema(),
ctx->get_expression_tables()->m_master));
} break;
case UNIT_CONTEXT:
break;
Expand Down
21 changes: 12 additions & 9 deletions cpp/perspective/src/cpp/gnode_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -547,33 +547,36 @@ t_gstate::get_pkey_dtype() const {

// Zero-argument convenience overload: forwards to the two-argument
// `get_pkeyed_table`, passing this gstate's own `m_input_schema` and
// `m_table`, and returns whatever that overload returns.
std::shared_ptr<t_data_table>
t_gstate::get_pkeyed_table() const {
return get_pkeyed_table(m_input_schema, m_table);
}

std::shared_ptr<t_data_table>
t_gstate::get_pkeyed_table(
const t_schema& schema, const std::shared_ptr<t_data_table> table) const {
// If there are no removes, just return the gstate table. Removes would
// cause m_mapping to be smaller than m_table.
if (m_mapping.size() == m_table->size())
return m_table;
if (m_mapping.size() == table->size())
return table;

// Otherwise mask out the removed rows and return the table.
auto mask = get_cpp_mask();

// count = total number of rows - number of removed rows
t_uindex table_size = mask.count();

const auto& schema_columns = m_input_schema.m_columns;
const auto& schema_columns = schema.m_columns;
t_uindex num_columns = schema_columns.size();

// Clone from the gstate master table
const std::shared_ptr<t_data_table>& master_table = m_table;

std::shared_ptr<t_data_table> rval
= std::make_shared<t_data_table>(m_input_schema, table_size);
= std::make_shared<t_data_table>(schema, table_size);
rval->init();
rval->set_size(table_size);

parallel_for(int(num_columns),
[&schema_columns, rval, master_table, &mask](int colidx) {
[&schema_columns, rval, table, &mask](int colidx) {
const std::string& colname = schema_columns[colidx];
rval->set_column(
colname, master_table->get_const_column(colname)->clone(mask));
colname, table->get_const_column(colname)->clone(mask));
}

);
Expand Down
24 changes: 16 additions & 8 deletions cpp/perspective/src/cpp/view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1580,10 +1580,10 @@ View<t_ctxunit>::to_columns(t_uindex start_row, t_uindex end_row,
bool get_pkeys, bool get_ids, bool _leaves_only, t_uindex num_sides,
bool _has_row_path, std::string nidx, t_uindex columns_length,
t_uindex group_by_length) const {

PSP_GIL_UNLOCK();
PSP_READ_LOCK(get_lock());
auto slice = get_data(start_row, end_row, start_col, end_col);
auto col_names = slice->get_column_names();
auto schema = m_ctx->get_schema();
auto& col_names = slice->get_column_names();

rapidjson::StringBuffer s;
rapidjson::Writer<rapidjson::StringBuffer> writer(s);
Expand Down Expand Up @@ -1628,9 +1628,12 @@ View<t_ctx0>::to_columns(t_uindex start_row, t_uindex end_row,
bool get_pkeys, bool get_ids, bool _leaves_only, t_uindex num_sides,
bool _has_row_path, std::string nidx, t_uindex columns_length,
t_uindex group_by_length) const {
PSP_GIL_UNLOCK();
PSP_READ_LOCK(get_lock());
auto slice = get_data(start_row, end_row, start_col, end_col);
auto col_names = slice->get_column_names();
auto schema = m_ctx->get_schema();
const std::vector<std::vector<t_tscalar>>& col_names
= slice->get_column_names();

rapidjson::StringBuffer s;
rapidjson::Writer<rapidjson::StringBuffer> writer(s);
writer.StartObject();
Expand Down Expand Up @@ -1671,8 +1674,11 @@ View<t_ctx1>::to_columns(t_uindex start_row, t_uindex end_row,
bool get_pkeys, bool get_ids, bool leaves_only, t_uindex num_sides,
bool has_row_path, std::string nidx, t_uindex columns_length,
t_uindex group_by_length) const {
PSP_GIL_UNLOCK();
PSP_READ_LOCK(get_lock());

auto slice = get_data(start_row, end_row, start_col, end_col);
auto col_names = slice->get_column_names();
const auto& col_names = slice->get_column_names();
rapidjson::StringBuffer s;
rapidjson::Writer<rapidjson::StringBuffer> writer(s);
writer.StartObject();
Expand Down Expand Up @@ -1721,8 +1727,10 @@ View<t_ctx2>::to_columns(t_uindex start_row, t_uindex end_row,
bool get_pkeys, bool get_ids, bool leaves_only, t_uindex num_sides,
bool has_row_path, std::string nidx, t_uindex columns_length,
t_uindex group_by_length) const {
auto slice = get_data(start_row, end_row, start_col, end_col);
auto col_names = slice->get_column_names();
PSP_GIL_UNLOCK();
PSP_READ_LOCK(get_lock());
const auto slice = get_data(start_row, end_row, start_col, end_col);
const auto& col_names = slice->get_column_names();
rapidjson::StringBuffer s;
rapidjson::Writer<rapidjson::StringBuffer> writer(s);
writer.StartObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ std::shared_ptr<t_expression_tables> get_expression_tables() const;
// Given shared pointers to data tables from the gnode, use them to
// compute the results of expression columns.
void compute_expressions(std::shared_ptr<t_data_table> master,
std::shared_ptr<t_data_table> flattened_masked,
t_expression_vocab& expression_vocab, t_regex_mapping& regex_mapping);

void compute_expressions(std::shared_ptr<t_data_table> master,
Expand Down
2 changes: 2 additions & 0 deletions cpp/perspective/src/include/perspective/expression_tables.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ struct t_expression_tables {
// Calculate the `t_transitions` value for each row.
void calculate_transitions(std::shared_ptr<t_data_table> existed);

void set_flattened(std::shared_ptr<t_data_table> flattened);

void reset();

t_data_table* get_table() const;
Expand Down
2 changes: 2 additions & 0 deletions cpp/perspective/src/include/perspective/gnode_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ class PERSPECTIVE_EXPORT t_gstate {
// Getters
std::shared_ptr<t_data_table> get_table() const;
std::shared_ptr<t_data_table> get_pkeyed_table() const;
std::shared_ptr<t_data_table> get_pkeyed_table(const t_schema& schema,
const std::shared_ptr<t_data_table> table) const;

const t_schema& get_input_schema() const;
const t_schema& get_output_schema() const;
Expand Down
Loading

0 comments on commit c75d8b0

Please sign in to comment.