Skip to content

Commit

Permalink
chore(dev): table rows query returns absolute index of each row in pa…
Browse files Browse the repository at this point in the history
…rent table
  • Loading branch information
bcsherma committed Jan 13, 2025
1 parent 1e1eaaf commit 2c66ab7
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ export const DatasetEditProvider: React.FC<DatasetEditProviderProps> = ({
const processRowUpdate = useCallback(
(newRow: DatasetRow, oldRow: DatasetRow): DatasetRow => {
const changedField = Object.keys(newRow).find(
key => newRow[key] !== oldRow[key] && key !== 'id'
key => newRow[key] !== oldRow[key]
);

if (changedField) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,6 @@ export const EditableDatasetView: FC<EditableDataTableViewProps> = ({
pageSize: 50,
});

// Reset sort model and pagination if we enter edit mode with sorting applied.
useEffect(() => {
if (isEditing && sortModel.length > 0) {
setPaginationModel({page: 0, pageSize: 50});
setSortModel([]);
setSortBy([]);
}
}, [isEditing, sortModel]);

const sharedRef = useContext(WeaveCHTableSourceRefContext);

const history = useHistory();
Expand Down Expand Up @@ -295,7 +286,6 @@ export const EditableDatasetView: FC<EditableDataTableViewProps> = ({
setAddedRows(prev => {
const updatedMap = new Map(prev);
const newId = `${ADDED_ROW_ID_PREFIX}${uuidv4()}`;
console.log(initialFields);
const newRow = {
___weave: {
id: newId,
Expand All @@ -310,24 +300,22 @@ export const EditableDatasetView: FC<EditableDataTableViewProps> = ({

const rows = useMemo(() => {
if (fetchQueryLoaded) {
return loadedRows.map((row, i) => {
return loadedRows.map(row => {
const digest = row.digest;
const absoluteIndex =
i + paginationModel.pageSize * paginationModel.page;
const editedRow = editedCellsMap.get(absoluteIndex);
const editedRow = editedCellsMap.get(row.original_index);
const value = flattenObjectPreservingWeaveTypes(row.val);
return {
___weave: {
id: `${digest}_${absoluteIndex}`,
index: absoluteIndex,
id: `${digest}_${row.original_index}`,
index: row.original_index,
isNew: false,
},
...(editedRow ? {...value, ...editedRow} : value),
};
});
}
return [];
}, [loadedRows, fetchQueryLoaded, editedCellsMap, paginationModel]);
}, [loadedRows, fetchQueryLoaded, editedCellsMap]);

const combinedRows = useMemo(() => {
if (
Expand Down Expand Up @@ -413,7 +401,7 @@ export const EditableDatasetView: FC<EditableDataTableViewProps> = ({
headerName: field as string,
flex: 1,
editable: isEditing,
sortable: !isEditing,
sortable: true,
filterable: false,
renderCell: (params: GridRenderCellParams) => {
const editedRow = editedCellsMap.get(params.row.___weave?.index);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ export type TraceTableQueryRes = {
rows: Array<{
digest: string;
val: any;
original_index?: number;
}>;
};

Expand Down
4 changes: 3 additions & 1 deletion weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,9 @@ def _table_query_stream(
res = self._query_stream(query, parameters=pb.get_params())

for row in res:
yield tsi.TableRowSchema(digest=row[0], val=json.loads(row[1]))
yield tsi.TableRowSchema(
digest=row[0], val=json.loads(row[1]), original_index=row[2]
)

def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes:
parameters: dict[str, Any] = {
Expand Down
22 changes: 12 additions & 10 deletions weave/trace_server/table_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,22 @@ def make_natural_sort_table_query(
row_digests_selection = f"arraySlice({row_digests_selection}, 1 + {{{pb.add_param(offset)}: Int64}}, {{{pb.add_param(limit)}: Int64}})"

query = f"""
SELECT DISTINCT tr.digest, tr.val_dump, t.row_order
SELECT DISTINCT tr.digest, tr.val_dump, t.original_index + {{{pb.add_param(offset or 0)}: Int64}} - 1 as original_index
FROM table_rows tr
INNER JOIN (
SELECT row_digest, row_number() OVER () AS row_order
SELECT row_digest, original_index
FROM (
SELECT {row_digests_selection} as row_digests
SELECT {row_digests_selection} as row_digests,
arrayEnumerate(row_digests) as original_indices
FROM tables
WHERE project_id = {{{project_id_name}: String}}
AND digest = {{{digest_name}: String}}
LIMIT 1
)
ARRAY JOIN row_digests AS row_digest
ARRAY JOIN row_digests AS row_digest, original_indices AS original_index
) AS t ON tr.digest = t.row_digest
WHERE tr.project_id = {{{project_id_name}: String}}
ORDER BY row_order ASC
ORDER BY original_index ASC
"""

return query
Expand Down Expand Up @@ -88,20 +89,21 @@ def make_standard_table_query(
)

query = f"""
SELECT tr.digest, tr.val_dump, tr.row_order FROM
SELECT tr.digest, tr.val_dump, tr.original_index FROM
(
SELECT DISTINCT tr.digest, tr.val_dump, t.row_order
SELECT DISTINCT tr.digest, tr.val_dump, t.original_index
FROM table_rows tr
INNER JOIN (
SELECT row_digest, row_number() OVER () AS row_order
SELECT row_digest, original_index - 1 as original_index
FROM (
SELECT row_digests
SELECT row_digests,
arrayEnumerate(row_digests) as original_indices
FROM tables
WHERE project_id = {{{project_id_name}: String}}
AND digest = {{{digest_name}: String}}
LIMIT 1
)
ARRAY JOIN row_digests AS row_digest
ARRAY JOIN row_digests AS row_digest, original_indices AS original_index
) AS t ON tr.digest = t.row_digest
WHERE tr.project_id = {{{project_id_name}: String}}
{sql_safe_filter_clause}
Expand Down
123 changes: 91 additions & 32 deletions weave/trace_server/trace_server_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ class TableUpdateRes(BaseModel):
class TableRowSchema(BaseModel):
digest: str
val: Any
original_index: Optional[int] = None


class TableCreateRes(BaseModel):
Expand Down Expand Up @@ -896,47 +897,105 @@ def ensure_project_exists(
return EnsureProjectExistsRes(project_name=project)

# Call API
def call_start(self, req: CallStartReq) -> CallStartRes: ...
def call_end(self, req: CallEndReq) -> CallEndRes: ...
def call_read(self, req: CallReadReq) -> CallReadRes: ...
def calls_query(self, req: CallsQueryReq) -> CallsQueryRes: ...
def calls_query_stream(self, req: CallsQueryReq) -> Iterator[CallSchema]: ...
def calls_delete(self, req: CallsDeleteReq) -> CallsDeleteRes: ...
def calls_query_stats(self, req: CallsQueryStatsReq) -> CallsQueryStatsRes: ...
def call_update(self, req: CallUpdateReq) -> CallUpdateRes: ...
def call_start(self, req: CallStartReq) -> CallStartRes:
...

def call_end(self, req: CallEndReq) -> CallEndRes:
...

def call_read(self, req: CallReadReq) -> CallReadRes:
...

def calls_query(self, req: CallsQueryReq) -> CallsQueryRes:
...

def calls_query_stream(self, req: CallsQueryReq) -> Iterator[CallSchema]:
...

def calls_delete(self, req: CallsDeleteReq) -> CallsDeleteRes:
...

def calls_query_stats(self, req: CallsQueryStatsReq) -> CallsQueryStatsRes:
...

def call_update(self, req: CallUpdateReq) -> CallUpdateRes:
...

# Op API
def op_create(self, req: OpCreateReq) -> OpCreateRes: ...
def op_read(self, req: OpReadReq) -> OpReadRes: ...
def ops_query(self, req: OpQueryReq) -> OpQueryRes: ...
def op_create(self, req: OpCreateReq) -> OpCreateRes:
...

def op_read(self, req: OpReadReq) -> OpReadRes:
...

def ops_query(self, req: OpQueryReq) -> OpQueryRes:
...

# Cost API
def cost_create(self, req: CostCreateReq) -> CostCreateRes: ...
def cost_query(self, req: CostQueryReq) -> CostQueryRes: ...
def cost_purge(self, req: CostPurgeReq) -> CostPurgeRes: ...
def cost_create(self, req: CostCreateReq) -> CostCreateRes:
...

def cost_query(self, req: CostQueryReq) -> CostQueryRes:
...

def cost_purge(self, req: CostPurgeReq) -> CostPurgeRes:
...

# Obj API
def obj_create(self, req: ObjCreateReq) -> ObjCreateRes: ...
def obj_read(self, req: ObjReadReq) -> ObjReadRes: ...
def objs_query(self, req: ObjQueryReq) -> ObjQueryRes: ...
def obj_delete(self, req: ObjDeleteReq) -> ObjDeleteRes: ...
def table_create(self, req: TableCreateReq) -> TableCreateRes: ...
def table_update(self, req: TableUpdateReq) -> TableUpdateRes: ...
def table_query(self, req: TableQueryReq) -> TableQueryRes: ...
def table_query_stream(self, req: TableQueryReq) -> Iterator[TableRowSchema]: ...
def table_query_stats(self, req: TableQueryStatsReq) -> TableQueryStatsRes: ...
def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes: ...
def file_create(self, req: FileCreateReq) -> FileCreateRes: ...
def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ...
def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes: ...
def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes: ...
def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: ...
def feedback_replace(self, req: FeedbackReplaceReq) -> FeedbackReplaceRes: ...
def obj_create(self, req: ObjCreateReq) -> ObjCreateRes:
...

def obj_read(self, req: ObjReadReq) -> ObjReadRes:
...

def objs_query(self, req: ObjQueryReq) -> ObjQueryRes:
...

def obj_delete(self, req: ObjDeleteReq) -> ObjDeleteRes:
...

def table_create(self, req: TableCreateReq) -> TableCreateRes:
...

def table_update(self, req: TableUpdateReq) -> TableUpdateRes:
...

def table_query(self, req: TableQueryReq) -> TableQueryRes:
...

def table_query_stream(self, req: TableQueryReq) -> Iterator[TableRowSchema]:
...

def table_query_stats(self, req: TableQueryStatsReq) -> TableQueryStatsRes:
...

def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes:
...

def file_create(self, req: FileCreateReq) -> FileCreateRes:
...

def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes:
...

def feedback_create(self, req: FeedbackCreateReq) -> FeedbackCreateRes:
...

def feedback_query(self, req: FeedbackQueryReq) -> FeedbackQueryRes:
...

def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes:
...

def feedback_replace(self, req: FeedbackReplaceReq) -> FeedbackReplaceRes:
...

# Action API
def actions_execute_batch(
self, req: ActionsExecuteBatchReq
) -> ActionsExecuteBatchRes: ...
) -> ActionsExecuteBatchRes:
...

# Execute LLM API
def completions_create(self, req: CompletionsCreateReq) -> CompletionsCreateRes: ...
def completions_create(self, req: CompletionsCreateReq) -> CompletionsCreateRes:
...

0 comments on commit 2c66ab7

Please sign in to comment.