diff --git a/packages/grid/_modules_/grid/hooks/features/rows/useGridRowsMeta.ts b/packages/grid/_modules_/grid/hooks/features/rows/useGridRowsMeta.ts index 9e4d38f36c567..3adce142b6ccb 100644 --- a/packages/grid/_modules_/grid/hooks/features/rows/useGridRowsMeta.ts +++ b/packages/grid/_modules_/grid/hooks/features/rows/useGridRowsMeta.ts @@ -27,7 +27,9 @@ export const useGridRowsMeta = ( props: Pick<DataGridProcessedProps, 'getRowHeight' | 'pagination' | 'paginationMode'>, ): void => { const { getRowHeight, pagination, paginationMode } = props; - const rowsHeightLookup = React.useRef<{ [key: GridRowId]: number }>({}); + const rowsHeightLookup = React.useRef<{ + [key: GridRowId]: { value: number; isResized: boolean }; + }>({}); const rowHeight = useGridSelector(apiRef, gridDensityRowHeightSelector); const filterState = useGridSelector(apiRef, gridFilterStateSelector); const paginationState = useGridSelector(apiRef, gridPaginationSelector); @@ -53,11 +55,21 @@ export const useGridRowsMeta = ( const currentRowHeight = gridDensityRowHeightSelector(state, apiRef.current.instanceId); const currentPageTotalHeight = rows.reduce((acc: number, row) => { positions.push(acc); - let baseRowHeight = currentRowHeight; + let baseRowHeight: number; - if (getRowHeight) { - // Default back to base rowHeight if getRowHeight returns null or undefined. - baseRowHeight = getRowHeight({ ...row, densityFactor }) ?? currentRowHeight; + const isResized = + (rowsHeightLookup.current[row.id] && rowsHeightLookup.current[row.id].isResized) || false; + + if (isResized) { + // do not recalculate resized row height and use the value from the lookup + baseRowHeight = rowsHeightLookup.current[row.id].value; + } else { + baseRowHeight = currentRowHeight; + + if (getRowHeight) { + // Default back to base rowHeight if getRowHeight returns null or undefined. + baseRowHeight = getRowHeight({ ...row, densityFactor }) ?? currentRowHeight; + } } const heights = apiRef.current.unstable_applyPreProcessors( @@ -68,7 +80,10 @@ export const useGridRowsMeta = ( const finalRowHeight = Object.values(heights).reduce((acc2, value) => acc2 + value, 0); - rowsHeightLookup.current[row.id] = baseRowHeight; + rowsHeightLookup.current[row.id] = { + value: baseRowHeight, + isResized, + }; return acc + finalRowHeight; }, 0); @@ -82,7 +97,18 @@ export const useGridRowsMeta = ( }, [apiRef, pagination, paginationMode, getRowHeight]); const getTargetRowHeight = (rowId: GridRowId): number => - rowsHeightLookup.current[rowId] || rowHeight; + rowsHeightLookup.current[rowId]?.value || rowHeight; + + const setRowHeight = React.useCallback<GridRowsMetaApi['unstable_setRowHeight']>( + (id: GridRowId, height: number) => { + rowsHeightLookup.current[id] = { + value: height, + isResized: true, + }; + hydrateRowsMeta(); + }, + [hydrateRowsMeta], + ); // The effect is used to build the rows meta data - currentPageTotalHeight and positions. // Because of variable row height this is needed for the virtualization @@ -106,6 +132,7 @@ export const useGridRowsMeta = ( const rowsMetaApi: GridRowsMetaApi = { unstable_getRowHeight: getTargetRowHeight, + unstable_setRowHeight: setRowHeight, }; useGridApiMethod(apiRef, rowsMetaApi, 'GridRowsMetaApi'); diff --git a/packages/grid/_modules_/grid/models/api/gridRowsMetaApi.ts b/packages/grid/_modules_/grid/models/api/gridRowsMetaApi.ts index 680e6c9214b25..082b14148d94f 100644 --- a/packages/grid/_modules_/grid/models/api/gridRowsMetaApi.ts +++ b/packages/grid/_modules_/grid/models/api/gridRowsMetaApi.ts @@ -11,4 +11,11 @@ export interface GridRowsMetaApi { * @ignore - do not document. */ unstable_getRowHeight: (id: GridRowId) => number; + /** + * Updates the base height of a row. + * @param {GridRowId} id The id of the row. + * @param {number} height The new height. + * @ignore - do not document. + */ + unstable_setRowHeight: (id: GridRowId, height: number) => void; } diff --git a/packages/grid/x-data-grid-pro/src/tests/rows.DataGridPro.test.tsx b/packages/grid/x-data-grid-pro/src/tests/rows.DataGridPro.test.tsx index 6f1c4ba472772..365e87b791e18 100644 --- a/packages/grid/x-data-grid-pro/src/tests/rows.DataGridPro.test.tsx +++ b/packages/grid/x-data-grid-pro/src/tests/rows.DataGridPro.test.tsx @@ -2,7 +2,13 @@ import * as React from 'react'; import { createRenderer, fireEvent } from '@mui/monorepo/test/utils'; import { spy } from 'sinon'; import { expect } from 'chai'; -import { getCell, getRow, getColumnValues, getRows } from 'test/utils/helperFn'; +import { + getCell, + getRow, + getColumnValues, + getRows, + getColumnHeaderCell, +} from 'test/utils/helperFn'; import { GridRowModel, useGridApiRef, @@ -792,4 +798,75 @@ describe('<DataGridPro /> - Rows', () => { }).not.to.throw(); }); }); + + describe('apiRef: setRowHeight', () => { + const ROW_HEIGHT = 52; + + before(function beforeHook() { + if (isJSDOM) { + // Need layouting + this.skip(); + } + }); + + beforeEach(() => { + baselineProps = { + rows: [ + { + id: 0, + brand: 'Nike', + }, + { + id: 1, + brand: 'Adidas', + }, + { + id: 2, + brand: 'Puma', + }, + ], + columns: [{ field: 'brand', headerName: 'Brand' }], + }; + }); + + let apiRef: React.MutableRefObject<GridApi>; + + const TestCase = (props: Partial<DataGridProProps>) => { + apiRef = useGridApiRef(); + return ( + <div style={{ width: 300, height: 300 }}> + <DataGridPro {...baselineProps} apiRef={apiRef} rowHeight={ROW_HEIGHT} {...props} /> + </div> + ); + }; + + it('should change row height', () => { + const resizedRowId = 1; + render(<TestCase />); + + expect(getRow(1).clientHeight).to.equal(ROW_HEIGHT); + + apiRef.current.unstable_setRowHeight(resizedRowId, 100); + expect(getRow(resizedRowId).clientHeight).to.equal(100); + }); + + it('should preserve changed row height after sorting', () => { + const resizedRowId = 0; + const getRowHeight = spy(); + render(<TestCase getRowHeight={getRowHeight} />); + + const row = getRow(resizedRowId); + expect(row.clientHeight).to.equal(ROW_HEIGHT); + + getRowHeight.resetHistory(); + apiRef.current.unstable_setRowHeight(resizedRowId, 100); + expect(row.clientHeight).to.equal(100); + + // sort + fireEvent.click(getColumnHeaderCell(resizedRowId)); + + expect(row.clientHeight).to.equal(100); + expect(getRowHeight.neverCalledWithMatch({ id: resizedRowId })).to.equal(true); + }); + }); }); diff --git a/packages/storybook/src/stories/grid-rows.stories.tsx b/packages/storybook/src/stories/grid-rows.stories.tsx index 8387a6bd91fd0..eee15783ee7ee 100644 --- a/packages/storybook/src/stories/grid-rows.stories.tsx +++ b/packages/storybook/src/stories/grid-rows.stories.tsx @@ -6,6 +6,7 @@ import Button from '@mui/material/Button'; import Popper from '@mui/material/Popper'; import Paper from '@mui/material/Paper'; import Box from '@mui/material/Box'; +import TextField from '@mui/material/TextField'; import { GridCellValue, GridCellParams, @@ -19,6 +20,7 @@ import { MuiEvent, GridEventListener, GridRenderCellParams, + GridSelectionModel, } from '@mui/x-data-grid-pro'; import { useDemoData } from '@mui/x-data-grid-generator'; import { action } from '@storybook/addon-actions'; @@ -1070,3 +1072,67 @@ export function VariableRowHeight() { </div> ); } + +export function SetRowHeight() { + const { data } = useDemoData({ + dataSet: 'Commodity', + rowLength: 1000, + }); + + const [selectionModel, setSelectionModel] = React.useState<GridSelectionModel>([]); + React.useEffect(() => { + if (data.rows.length > 0) { + setSelectionModel([data.rows[0].id]); + } + }, [data.rows]); + + const apiRef = useGridApiRef(); + + const handleSubmit = (event: React.SyntheticEvent) => { + event.preventDefault(); + const target = event.target as typeof event.target & { + height: { value: string }; + }; + + const height = Number(target.height.value); + + selectionModel.forEach((id) => { + apiRef.current.unstable_setRowHeight(id, height); + }); + }; + + return ( + <div style={{ height: 600 }}> + <form style={{ display: 'flex', margin: '16px 0' }} onSubmit={handleSubmit}> + <TextField + name="height" + label="Row height" + size="small" + defaultValue="120" + sx={{ mr: 1 }} + /> + <Button type="submit" variant="outlined"> + Set row height + </Button> + </form> + <DataGridPro + {...data} + apiRef={apiRef} + selectionModel={selectionModel} + onSelectionModelChange={(newModel) => setSelectionModel(newModel)} + getRowHeight={({ model }) => { + if ( + model.commodity.includes('Oats') || + model.commodity.includes('Milk') || + model.commodity.includes('Soybean') || + model.commodity.includes('Rice') + ) { + return 80; + } + + return null; + }} + /> + </div> + ); +}