From 631dafa96edc8853aad35883b36a3917e31707be Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Mon, 30 Sep 2024 11:43:16 -0700 Subject: [PATCH] Pulled UI changes from tim/support_large_datasets --- .../Browse3/pages/CallPage/DataTableView.tsx | 346 +++++++++++++----- .../Home/Browse3/pages/CallPage/ValueView.tsx | 2 +- .../Home/Browse3/pages/ObjectVersionPage.tsx | 6 +- .../traceServerClientTypes.ts | 11 + .../traceServerDirectClient.ts | 11 + .../wfReactInterface/tsDataModelHooks.ts | 133 +++++++ .../wfDataModelHooksInterface.ts | 16 + 7 files changed, 428 insertions(+), 97 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/DataTableView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/DataTableView.tsx index e05df4f1b2ba..4f6d6340dac6 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/DataTableView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/DataTableView.tsx @@ -1,6 +1,12 @@ import LinkIcon from '@mui/icons-material/Link'; -import {Alert, Box} from '@mui/material'; -import {GridColDef, useGridApiRef} from '@mui/x-data-grid-pro'; +import {Box} from '@mui/material'; +import { + GridColDef, + GridEventListener, + GridPaginationModel, + GridSortModel, + useGridApiRef, +} from '@mui/x-data-grid-pro'; import { isAssignableTo, list, @@ -9,29 +15,33 @@ import { typedDict, typedDictPropertyTypes, } from '@wandb/weave/core'; +import {useDeepMemo} from '@wandb/weave/hookUtils'; import _ from 'lodash'; -import React, {FC, useCallback, useContext, useEffect, useMemo} from 'react'; +import React, { + FC, + useCallback, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from 'react'; import {useHistory} from 'react-router-dom'; -import { - isWeaveObjectRef, - parseRef, - WeaveObjectRef, -} from '../../../../../../react'; +import {isWeaveObjectRef, parseRef} from '../../../../../../react'; import {flattenObjectPreservingWeaveTypes} from '../../../Browse2/browse2Util'; import {CellValue} from '../../../Browse2/CellValue'; +import {parseRefMaybe} from '../../../Browse2/SmallRef'; import { useWeaveflowCurrentRouteContext, WeaveflowPeekContext, } from '../../context'; +import {DEFAULT_PAGE_SIZE} from '../../grid/pagination'; import {StyledDataGrid} from '../../StyledDataGrid'; import {CustomWeaveTypeProjectContext} from '../../typeViews/CustomWeaveTypeDispatcher'; import {TABLE_ID_EDGE_NAME} from '../wfReactInterface/constants'; import {useWFHooks} from '../wfReactInterface/context'; -import {TableQuery} from '../wfReactInterface/wfDataModelHooksInterface'; - -// Controls the maximum number of rows to display in the table -const MAX_ROWS = 10_000; +import {SortBy} from '../wfReactInterface/traceServerClientTypes'; // Controls whether to use a table for arrays or not. export const USE_TABLE_FOR_ARRAYS = false; @@ -53,26 +63,99 @@ export const WeaveCHTable: FC<{ // Gets the source of this Table (set by a few levels up) const sourceRef = useContext(WeaveCHTableSourceRefContext); - // Retrieves the data for the table, with a limit of MAX_ROWS + 1 - const fetchQuery = useValueOfRefUri(props.tableRefUri, { - limit: MAX_ROWS + 1, - }); + const {useTableQueryStats, useTableRowsQuery} = useWFHooks(); const parsedRef = useMemo( - () => parseRef(props.tableRefUri) as WeaveObjectRef, + () => parseRefMaybe(props.tableRefUri), [props.tableRefUri] ); - // Determines if the table itself is truncated - const isTruncated = useMemo(() => { - return (fetchQuery.result ?? []).length > MAX_ROWS; - }, [fetchQuery.result]); + const lookupKey = useMemo(() => { + if ( + parsedRef == null || + !isWeaveObjectRef(parsedRef) || + parsedRef.weaveKind !== 'table' + ) { + return null; + } + return { + entity: parsedRef.entityName, + project: parsedRef.projectName, + digest: parsedRef.artifactVersion, + }; + }, [parsedRef]); + + const numRowsQuery = useTableQueryStats( + lookupKey?.entity ?? '', + lookupKey?.project ?? '', + lookupKey?.digest ?? '', + {skip: lookupKey == null} + ); + + const [limit, setLimit] = useState(DEFAULT_PAGE_SIZE); + const [offset, setOffset] = useState(0); + const [sortBy, setSortBy] = useState([]); + const [sortModel, setSortModel] = useState([]); + + const onSortModelChange = useCallback( + (model: GridSortModel) => { + setSortModel(model); + }, + [setSortModel] + ); + + const [paginationModel, setPaginationModel] = useState({ + page: 0, + pageSize: DEFAULT_PAGE_SIZE, + }); + + const onPaginationModelChange = useCallback( + (model: GridPaginationModel) => { + setPaginationModel(model); + }, + [setPaginationModel] + ); + + useEffect(() => { + setOffset(paginationModel.page * paginationModel.pageSize); + setLimit(paginationModel.pageSize); + }, [paginationModel]); + + useEffect(() => { + setSortBy( + sortModel.map(sort => ({ + field: sort.field, + direction: sort.sort === 'asc' ? 'asc' : 'desc', + })) + ); + }, [sortModel]); + + const fetchQuery = useTableRowsQuery( + lookupKey?.entity ?? '', + lookupKey?.project ?? '', + lookupKey?.digest ?? '', + undefined, + limit, + offset, + sortBy, + {skip: lookupKey == null} + ); + + const [loadedRows, setLoadedRows] = useState>([]); + + useEffect(() => { + if (!fetchQuery.loading && fetchQuery.result) { + setLoadedRows(fetchQuery.result.rows); + } + }, [fetchQuery.loading, fetchQuery.result]); + + const pagedRows = useMemo(() => { + return loadedRows ?? []; + }, [loadedRows]); - // `sourceRows` are the effective rows to display. If the table is truncated, - // we only display the first MAX_ROWS rows. - const sourceRows = useMemo(() => { - return (fetchQuery.result ?? []).slice(0, MAX_ROWS); - }, [fetchQuery.result]); + const totalRows = useMemo(() => { + return numRowsQuery.result?.count ?? pagedRows.length; + }, [numRowsQuery.result, pagedRows]); // In this block, we setup a click handler. The underlying datatable is more general // and not aware of the nuances of our links and ref model. Therefore, we handle @@ -105,31 +188,60 @@ export const WeaveCHTable: FC<{ [history, sourceRef, router] ); + const pageControl: DataTableServerSidePaginationControls = useMemo( + () => ({ + paginationModel, + onPaginationModelChange, + totalRows, + pageSizeOptions: [DEFAULT_PAGE_SIZE], + sortModel, + onSortModelChange, + }), + [ + paginationModel, + onPaginationModelChange, + totalRows, + sortModel, + onSortModelChange, + ] + ); + return ( + value={{ + entity: lookupKey?.entity ?? '', + project: lookupKey?.project ?? '', + }}> ); }; +type DataTableServerSidePaginationControls = { + paginationModel: GridPaginationModel; + onPaginationModelChange: (model: GridPaginationModel) => void; + totalRows: number; + pageSizeOptions: number[]; + sortModel: GridSortModel; + onSortModelChange: (model: GridSortModel) => void; +}; + // This is a general purpose table view that can be used to render any data. export const DataTableView: FC<{ data: Array<{[key: string]: any}>; fullHeight?: boolean; loading?: boolean; displayKey?: string; - isTruncated?: boolean; onLinkClick?: (row: any) => void; + pageControl?: DataTableServerSidePaginationControls; + autoPageSize?: boolean; }> = props => { const apiRef = useGridApiRef(); const {isPeeking} = useContext(WeaveflowPeekContext); @@ -157,31 +269,11 @@ export const DataTableView: FC<{ () => (dataAsListOfDict ?? []).map((row, i) => ({ id: i, - ...row, + data: row, })), [dataAsListOfDict] ); - // This effect will resize the columns after the table is rendered. We use a - // timeout to ensure that the table has been rendered before we resize the - // columns. - useEffect(() => { - let mounted = true; - const timeoutId = setTimeout(() => { - if (!mounted) { - return; - } - apiRef.current.autosizeColumns({ - includeHeaders: true, - includeOutliers: true, - }); - }, 0); - return () => { - mounted = false; - clearInterval(timeoutId); - }; - }, [gridRows, apiRef]); - // Next, we determine the type of the data. Previously, we used the WeaveJS // `Type` system to determine the type of the data. However, this is way to // slow for big tables and too detailed for our purposes. We just need to know @@ -242,9 +334,16 @@ export const DataTableView: FC<{ return typedDict(propertyTypes); }, [dataAsListOfDict]); + const propsDataRef = useRef(props.data); + useEffect(() => { + propsDataRef.current = props.data; + }, [props.data]); + + const objectTypeDeepMemo = useDeepMemo(objectType); + // Here we define the column spec for the table. It is based on // the type of the data and if we have a link or not. - const columnSpec: GridColDef[] = useMemo(() => { + const dataInitializedColumnSpec: GridColDef[] = useMemo(() => { const res: GridColDef[] = []; if (props.onLinkClick) { res.push({ @@ -256,21 +355,26 @@ export const DataTableView: FC<{ style={{ cursor: 'pointer', }} - onClick={() => props.onLinkClick!(props.data[params.id as number])} + onClick={() => + props.onLinkClick!(propsDataRef.current[params.id as number]) + } /> ), }); } - return [...res, ...typeToDataGridColumnSpec(objectType, isPeeking, true)]; - }, [props.onLinkClick, props.data, objectType, isPeeking]); + return [ + ...res, + ...typeToDataGridColumnSpec(objectTypeDeepMemo, isPeeking, true), + ]; + }, [props.onLinkClick, objectTypeDeepMemo, isPeeking]); // Finally, we do some math to determine the height of the table. const isSingleColumn = USE_TABLE_FOR_ARRAYS && - columnSpec.length === 1 && - columnSpec[0].field === ''; + dataInitializedColumnSpec.length === 1 && + dataInitializedColumnSpec[0].field === ''; if (isSingleColumn) { - columnSpec[0].flex = 1; + dataInitializedColumnSpec[0].flex = 1; } const hideHeader = isSingleColumn; const displayRows = 10; @@ -284,6 +388,74 @@ export const DataTableView: FC<{ (hideHeader ? 0 : headerHeight) + (hideFooter ? 0 : footerHeight) + (props.loading ? loadingHeight : contentHeight); + + const [columnSpec, setColumnSpec] = useState([]); + + // This effect will resize the columns after the table is rendered. We use a + // timeout to ensure that the table has been rendered before we resize the + // columns. + const hasLinkClick = props.onLinkClick != null; + useEffect(() => { + let mounted = true; + + // Update the column set if the column spec changes (ignore empty columns + // which can occur during loading) + setColumnSpec(curr => { + const dataFieldSet = new Set( + dataInitializedColumnSpec.map(col => col.field) + ); + const currFieldSet = new Set(curr.map(col => col.field)); + if (dataFieldSet.size > (hasLinkClick ? 1 : 0)) { + // Update if they are different + if (!_.isEqual(dataFieldSet, currFieldSet)) { + return dataInitializedColumnSpec; + } + } + return curr; + }); + + const timeoutId = setTimeout(() => { + if (!mounted) { + return; + } + apiRef.current.autosizeColumns({ + includeHeaders: true, + includeOutliers: true, + }); + // apiRef.current.forceUpdate() + }, 0); + return () => { + mounted = false; + clearInterval(timeoutId); + }; + }, [dataInitializedColumnSpec, apiRef, hasLinkClick]); + + const onColumnOrderChange: GridEventListener<'columnOrderChange'> = + useCallback(params => { + const oldIndex = params.oldIndex; + const newIndex = params.targetIndex; + setColumnSpec(currSpec => { + const col = currSpec[oldIndex]; + currSpec.splice(oldIndex, 1); + currSpec.splice(newIndex, 0, col); + return currSpec; + }); + }, []); + + const onColumnWidthChange: GridEventListener<'columnWidthChange'> = + useCallback(params => { + const field = params.colDef.field; + const newWidth = params.width; + setColumnSpec(currSpec => { + for (const col of currSpec) { + if (col.field === field) { + col.width = newWidth; + } + } + return currSpec; + }); + }, []); + return (
- {props.isTruncated && ( - - Showing {dataAsListOfDict.length.toLocaleString()} rows only. - - )}
@@ -356,6 +532,7 @@ export const typeToDataGridColumnSpec = ( ): GridColDef[] => { if (isAssignableTo(type, {type: 'typedDict', propertyTypes: {}})) { const maxWidth = window.innerWidth * (isPeeking ? 0.5 : 0.75); + const minWidth = 100; const propertyTypes = typedDictPropertyTypes(type); return Object.entries(propertyTypes).flatMap(([key, valueType]) => { const innerKey = parentKey ? `${parentKey}.${key}` : key; @@ -382,12 +559,14 @@ export const typeToDataGridColumnSpec = ( return [ { maxWidth, + minWidth, + flex: 1, type: 'string' as const, editable: false, field: innerKey, headerName: innerKey, renderCell: params => { - const listValue = params.row[innerKey]; + const listValue = params.row.data[innerKey]; if (listValue == null) { return '-'; } @@ -400,12 +579,14 @@ export const typeToDataGridColumnSpec = ( return [ { maxWidth, + minWidth, + flex: 1, type: colType, editable: editable && !disableEdits, field: innerKey, headerName: innerKey, renderCell: params => { - const data = params.row[innerKey]; + const data = params.row.data[innerKey]; return ; }, }, @@ -419,26 +600,3 @@ export const typeToDataGridColumnSpec = ( } return []; }; - -const useValueOfRefUri = (refUriStr: string, tableQuery?: TableQuery) => { - const {useRefsData} = useWFHooks(); - const data = useRefsData([refUriStr], tableQuery); - return useMemo(() => { - if (data.loading) { - return { - loading: true, - result: undefined, - }; - } - if (data.result == null || data.result.length === 0) { - return { - loading: true, - result: undefined, - }; - } - return { - loading: false, - result: data.result[0], - }; - }, [data]); -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx index 5f4b02ef63c4..7db8e1e9c1bd 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/ValueView.tsx @@ -42,7 +42,7 @@ export const ValueView = ({data, isExpanded}: ValueViewProps) => { } } if (USE_TABLE_FOR_ARRAYS && data.valueType === 'array') { - return ; + return ; } if (data.valueType === 'array' && data.value.length === 0) { return Empty List; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index 0b581e464a0f..28002ee9bde3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -180,6 +180,8 @@ const ObjectVersionPageInner: React.FC<{ return viewerData; }, [viewerData]); + const isDataset = baseObjectClass === 'Dataset' && refExtra == null; + return ( + { + return this.makeRequest( + '/table/query_stats', + req + ); + } + public feedbackCreate(req: FeedbackCreateReq): Promise { return this.makeRequest( '/feedback/create', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts index 12fa9872465e..515969919dd6 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooks.ts @@ -963,6 +963,8 @@ const useTableQuery = makeTraceServerEndpointHook< string, traceServerTypes.TraceTableQueryReq['filter'], traceServerTypes.TraceTableQueryReq['limit'], + traceServerTypes.TraceTableQueryReq['offset'], + traceServerTypes.TraceTableQueryReq['sort_by'], {skip?: boolean}? ], any[] @@ -973,6 +975,8 @@ const useTableQuery = makeTraceServerEndpointHook< digest: traceServerTypes.TraceTableQueryReq['digest'], filter: traceServerTypes.TraceTableQueryReq['filter'], limit: traceServerTypes.TraceTableQueryReq['limit'], + offset: traceServerTypes.TraceTableQueryReq['offset'], + sortBy: traceServerTypes.TraceTableQueryReq['sort_by'], opts?: {skip?: boolean} ) => ({ params: { @@ -980,6 +984,8 @@ const useTableQuery = makeTraceServerEndpointHook< digest, filter, limit, + offset, + sort_by: sortBy, }, skip: opts?.skip, }), @@ -1136,6 +1142,8 @@ const useRefsData = ( tableUriDigest, tableQueryFilter, tableQuery?.limit, + undefined, + undefined, {skip: tableRefUris.length === 0 || cachedTableResult != null} ); @@ -1192,6 +1200,129 @@ const useRefsData = ( ]); }; +const useTableRowsQuery = ( + entity: string, + project: string, + digest: string, + filter?: traceServerTypes.TraceTableQueryReq['filter'], + limit?: traceServerTypes.TraceTableQueryReq['limit'], + offset?: traceServerTypes.TraceTableQueryReq['offset'], + sortBy?: traceServerTypes.TraceTableQueryReq['sort_by'], + opts?: {skip?: boolean} +): Loadable => { + const getTsClient = useGetTraceServerClientContext(); + const [queryRes, setQueryRes] = + useState(null); + const loadingRef = useRef(false); + + const projectId = projectIdFromParts({entity, project}); + + const doFetch = useCallback(() => { + if (opts?.skip) { + return; + } + setQueryRes(null); + loadingRef.current = true; + + const req: traceServerTypes.TraceTableQueryReq = { + project_id: projectId, + digest, + filter, + limit, + offset, + sort_by: sortBy, + }; + + getTsClient() + .tableQuery(req) + .then(res => { + loadingRef.current = false; + setQueryRes(res); + }) + .catch(err => { + loadingRef.current = false; + console.error('Error fetching table rows:', err); + setQueryRes(null); + }); + }, [ + getTsClient, + projectId, + digest, + filter, + limit, + offset, + sortBy, + opts?.skip, + ]); + + useEffect(() => { + doFetch(); + }, [doFetch]); + + return useMemo(() => { + if (opts?.skip) { + return {loading: false, result: null}; + } + if (queryRes == null || loadingRef.current) { + return {loading: true, result: null}; + } + return {loading: false, result: queryRes}; + }, [queryRes, opts?.skip]); +}; + +const useTableQueryStats = ( + entity: string, + project: string, + digest: string, + opts?: {skip?: boolean} +): Loadable => { + const getTsClient = useGetTraceServerClientContext(); + const [statsRes, setStatsRes] = + useState(null); + const loadingRef = useRef(false); + + const projectId = projectIdFromParts({entity, project}); + + const doFetch = useCallback(() => { + if (opts?.skip) { + return; + } + setStatsRes(null); + loadingRef.current = true; + + const req: traceServerTypes.TraceTableQueryStatsReq = { + project_id: projectId, + digest, + }; + + getTsClient() + .tableQueryStats(req) + .then(res => { + loadingRef.current = false; + setStatsRes(res); + }) + .catch(err => { + loadingRef.current = false; + console.error('Error fetching table query stats:', err); + setStatsRes(null); + }); + }, [getTsClient, projectId, digest, opts?.skip]); + + useEffect(() => { + doFetch(); + }, [doFetch]); + + return useMemo(() => { + if (opts?.skip) { + return {loading: false, result: null}; + } + if (statsRes == null || loadingRef.current) { + return {loading: true, result: null}; + } + return {loading: false, result: statsRes}; + }, [statsRes, opts?.skip]); +}; + const useApplyMutationsToRef = (): (( refUri: string, edits: RefMutation[] @@ -1534,6 +1665,8 @@ export const tsWFDataModelHooks: WFDataModelHooksInterface = { useApplyMutationsToRef, useFeedback, useFileContent, + useTableRowsQuery, + useTableQueryStats, derived: { useChildCallsForCompare, useGetRefsType, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts index a760cbfb2074..1dba29c06db0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/wfDataModelHooksInterface.ts @@ -218,6 +218,22 @@ export type WFDataModelHooksInterface = { useObjectVersion: ( key: ObjectVersionKey | null ) => Loadable; + useTableRowsQuery: ( + entity: string, + project: string, + digest: string, + filter?: traceServerClientTypes.TraceTableQueryReq['filter'], + limit?: traceServerClientTypes.TraceTableQueryReq['limit'], + offset?: traceServerClientTypes.TraceTableQueryReq['offset'], + sortBy?: traceServerClientTypes.TraceTableQueryReq['sort_by'], + opts?: {skip?: boolean} + ) => Loadable; + useTableQueryStats: ( + entity: string, + project: string, + digest: string, + opts?: {skip?: boolean} + ) => Loadable; useRootObjectVersions: ( entity: string, project: string,