Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] profiling and optimize the runtime #1705

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open

Conversation

liruilong940607
Copy link
Contributor

Goal is to identify the bottleneck in the codebase and try to optimize the code for efficiency, to solve #1638

@liruilong940607
Copy link
Contributor Author

liruilong940607 commented Apr 6, 2023

Profiling using line_profiler with commands:

CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=5 kernprof -l scripts/train.py nerfacto --data data/nerfstudio/poster --max-num-iterations 1000 --viewer.quit-on-train-completion True

Locates nerfstudio.cameras.radial_and_tangential_undistort to be the current bottleneck.

Timer unit: 1e-06 s

Total time: 11.8848 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/cameras/cameras.py
Function: generate_rays at line 313

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   313                                               @profiler.time_function
   314                                               @profile
   315                                               def generate_rays(  # pylint: disable=too-many-statements
   316                                                   self,
   317                                                   camera_indices: Union[TensorType["num_rays":..., "num_cameras_batch_dims"], int],
   318                                                   coords: Optional[TensorType["num_rays":..., 2]] = None,
   319                                                   camera_opt_to_camera: Optional[TensorType["num_rays":..., 3, 4]] = None,
   320                                                   distortion_params_delta: Optional[TensorType["num_rays":..., 6]] = None,
   321                                                   keep_shape: Optional[bool] = None,
   322                                                   disable_distortion: bool = False,
   323                                                   aabb_box: Optional[SceneBox] = None,
   324                                               ) -> RayBundle:
   325                                                   """Generates rays for the given camera indices.
   326                                           
   327                                                   This function will standardize the input arguments and then call the _generate_rays_from_coords function
   328                                                   to generate the rays. Our goal is to parse the arguments and then get them into the right shape:
   329                                                       - camera_indices: (num_rays:..., num_cameras_batch_dims)
   330                                                       - coords: (num_rays:..., 2)
   331                                                       - camera_opt_to_camera: (num_rays:..., 3, 4) or None
   332                                                       - distortion_params_delta: (num_rays:..., 6) or None
   333                                           
   334                                                   Read the docstring for _generate_rays_from_coords for more information on how we generate the rays
   335                                                   after we have standardized the arguments.
   336                                           
   337                                                   We are only concerned about different combinations of camera_indices and coords matrices, and the following
   338                                                   are the 4 cases we have to deal with:
   339                                                       1. isinstance(camera_indices, int) and coords == None
   340                                                           - In this case we broadcast our camera_indices / coords shape (h, w, 1 / 2 respectively)
   341                                                       2. isinstance(camera_indices, int) and coords != None
   342                                                           - In this case, we broadcast camera_indices to the same batch dim as coords
   343                                                       3. not isinstance(camera_indices, int) and coords == None
   344                                                           - In this case, we will need to set coords so that it is of shape (h, w, num_rays, 2), and broadcast
   345                                                               all our other args to match the new definition of num_rays := (h, w) + num_rays
   346                                                       4. not isinstance(camera_indices, int) and coords != None
   347                                                           - In this case, we have nothing to do, only check that the arguments are of the correct shape
   348                                           
   349                                                   There is one more edge case we need to be careful with: when we have "jagged cameras" (ie: different heights
   350                                                   and widths for each camera). This isn't problematic when we specify coords, since coords is already a tensor.
   351                                                   When coords == None (ie: when we render out the whole image associated with this camera), we run into problems
   352                                                   since there's no way to stack each coordinate map as all coordinate maps are all different shapes. In this case,
   353                                                   we will need to flatten each individual coordinate map and concatenate them, giving us only one batch dimension,
   354                                                   regardless of the number of prepended extra batch dimensions in the camera_indices tensor.
   355                                           
   356                                           
   357                                                   Args:
   358                                                       camera_indices: Camera indices of the flattened cameras object to generate rays for.
   359                                                       coords: Coordinates of the pixels to generate rays for. If None, the full image will be rendered.
   360                                                       camera_opt_to_camera: Optional transform for the camera to world matrices.
   361                                                       distortion_params_delta: Optional delta for the distortion parameters.
   362                                                       keep_shape: If None, then we default to the regular behavior of flattening if cameras is jagged, otherwise
   363                                                           keeping dimensions. If False, we flatten at the end. If True, then we keep the shape of the
   364                                                           camera_indices and coords tensors (if we can).
   365                                                       disable_distortion: If True, disables distortion.
   366                                                       aabb_box: if not None will calculate nears and fars of the ray according to aabb box intesection
   367                                           
   368                                                   Returns:
   369                                                       Rays for the given camera indices and coords.
   370                                                   """
   371                                                   # Check the argument types to make sure they're valid and all shaped correctly
   372      1000       2485.9      2.5      0.0          assert isinstance(camera_indices, (torch.Tensor, int)), "camera_indices must be a tensor or int"
   373      1000       1197.0      1.2      0.0          assert coords is None or isinstance(coords, torch.Tensor), "coords must be a tensor or None"
   374      1000        657.0      0.7      0.0          assert camera_opt_to_camera is None or isinstance(camera_opt_to_camera, torch.Tensor)
   375      1000        494.9      0.5      0.0          assert distortion_params_delta is None or isinstance(distortion_params_delta, torch.Tensor)
   376      1000        852.6      0.9      0.0          if isinstance(camera_indices, torch.Tensor) and isinstance(coords, torch.Tensor):
   377      1000       3659.1      3.7      0.0              num_rays_shape = camera_indices.shape[:-1]
   378      1000        532.8      0.5      0.0              errormsg = "Batch dims of inputs must match when inputs are all tensors"
   379      1000       2080.8      2.1      0.0              assert coords.shape[:-1] == num_rays_shape, errormsg
   380      1000       1793.1      1.8      0.0              assert camera_opt_to_camera is None or camera_opt_to_camera.shape[:-2] == num_rays_shape, errormsg
   381      1000        558.0      0.6      0.0              assert distortion_params_delta is None or distortion_params_delta.shape[:-1] == num_rays_shape, errormsg
   382                                           
   383                                                   # If zero dimensional, we need to unsqueeze to get a batch dimension and then squeeze later
   384      1000       3330.1      3.3      0.0          if not self.shape:
   385                                                       cameras = self.reshape((1,))
   386                                                       assert torch.all(
   387                                                           torch.tensor(camera_indices == 0) if isinstance(camera_indices, int) else camera_indices == 0
   388                                                       ), "Can only index into single camera with no batch dimensions if index is zero"
   389                                                   else:
   390      1000        413.3      0.4      0.0              cameras = self
   391                                           
   392                                                   # If the camera indices are an int, then we need to make sure that the camera batch is 1D
   393      1000       1104.2      1.1      0.0          if isinstance(camera_indices, int):
   394                                                       assert (
   395                                                           len(cameras.shape) == 1
   396                                                       ), "camera_indices must be a tensor if cameras are batched with more than 1 batch dimension"
   397                                                       camera_indices = torch.tensor([camera_indices], device=cameras.device)
   398                                           
   399      1000       1319.0      1.3      0.0          assert camera_indices.shape[-1] == len(
   400      1000       1008.1      1.0      0.0              cameras.shape
   401                                                   ), "camera_indices must have shape (num_rays:..., num_cameras_batch_dims)"
   402                                           
   403                                                   # If keep_shape is True, then we need to make sure that the camera indices in question
   404                                                   # are all the same height and width and can actually be batched while maintaining the image
   405                                                   # shape
   406      1000        634.1      0.6      0.0          if keep_shape is True:
   407                                                       assert torch.all(cameras.height[camera_indices] == cameras.height[camera_indices[0]]) and torch.all(
   408                                                           cameras.width[camera_indices] == cameras.width[camera_indices[0]]
   409                                                       ), "Can only keep shape if all cameras have the same height and width"
   410                                           
   411                                                   # If the cameras don't all have same height / width, if coords is not none, we will need to generate
   412                                                   # a flat list of coords for each camera and then concatenate otherwise our rays will be jagged.
   413                                                   # Camera indices, camera_opt, and distortion will also need to be broadcasted accordingly which is non-trivial
   414      1000     201546.2    201.5      1.7          if cameras.is_jagged and coords is None and (keep_shape is None or keep_shape is False):
   415                                                       index_dim = camera_indices.shape[-1]
   416                                                       camera_indices = camera_indices.reshape(-1, index_dim)
   417                                                       _coords = [cameras.get_image_coords(index=tuple(index)).reshape(-1, 2) for index in camera_indices]
   418                                                       camera_indices = torch.cat(
   419                                                           [index.unsqueeze(0).repeat(coords.shape[0], 1) for index, coords in zip(camera_indices, _coords)],
   420                                                       )
   421                                                       coords = torch.cat(_coords, dim=0)
   422                                                       assert coords.shape[0] == camera_indices.shape[0]
   423                                                       # Need to get the coords of each indexed camera and flatten all coordinate maps and concatenate them
   424                                           
   425                                                   # The case where we aren't jagged && keep_shape (since otherwise coords is already set) and coords
   426                                                   # is None. In this case we append (h, w) to the num_rays dimensions for all tensors. In this case,
   427                                                   # each image in camera_indices has to have the same shape since otherwise we would have error'd when
   428                                                   # we checked keep_shape is valid or we aren't jagged.
   429      1000       1015.9      1.0      0.0          if coords is None:
   430                                                       index_dim = camera_indices.shape[-1]
   431                                                       index = camera_indices.reshape(-1, index_dim)[0]
   432                                                       coords: torch.Tensor = cameras.get_image_coords(index=tuple(index))  # (h, w, 2)
   433                                                       coords = coords.reshape(coords.shape[:2] + (1,) * len(camera_indices.shape[:-1]) + (2,))  # (h, w, 1..., 2)
   434                                                       coords = coords.expand(coords.shape[:2] + camera_indices.shape[:-1] + (2,))  # (h, w, num_rays, 2)
   435                                                       camera_opt_to_camera = (  # (h, w, num_rays, 3, 4) or None
   436                                                           camera_opt_to_camera.broadcast_to(coords.shape[:-1] + (3, 4))
   437                                                           if camera_opt_to_camera is not None
   438                                                           else None
   439                                                       )
   440                                                       distortion_params_delta = (  # (h, w, num_rays, 6) or None
   441                                                           distortion_params_delta.broadcast_to(coords.shape[:-1] + (6,))
   442                                                           if distortion_params_delta is not None
   443                                                           else None
   444                                                       )
   445                                           
   446                                                   # If camera indices was an int or coords was none, we need to broadcast our indices along batch dims
   447      1000      19520.5     19.5      0.2          camera_indices = camera_indices.broadcast_to(coords.shape[:-1] + (len(cameras.shape),)).to(torch.long)
   448                                           
   449                                                   # Checking our tensors have been standardized
   450      1000       1517.0      1.5      0.0          assert isinstance(coords, torch.Tensor) and isinstance(camera_indices, torch.Tensor)
   451      1000       2415.0      2.4      0.0          assert camera_indices.shape[-1] == len(cameras.shape)
   452      1000       3257.2      3.3      0.0          assert camera_opt_to_camera is None or camera_opt_to_camera.shape[:-2] == coords.shape[:-1]
   453      1000        776.6      0.8      0.0          assert distortion_params_delta is None or distortion_params_delta.shape[:-1] == coords.shape[:-1]
   454                                           
   455                                                   # This will do the actual work of generating the rays now that we have standardized the inputs
   456                                                   # raybundle.shape == (num_rays) when done
   457                                                   # pylint: disable=protected-access
   458      1000   11629989.5  11630.0     97.9          raybundle = cameras._generate_rays_from_coords(
   459      1000        662.4      0.7      0.0              camera_indices, coords, camera_opt_to_camera, distortion_params_delta, disable_distortion=disable_distortion
   460                                                   )
   461                                           
   462                                                   # If we have mandated that we don't keep the shape, then we flatten
   463      1000        866.5      0.9      0.0          if keep_shape is False:
   464                                                       raybundle = raybundle.flatten()
   465                                           
   466      1000        746.4      0.7      0.0          if aabb_box:
   467                                                       with torch.no_grad():
   468                                                           tensor_aabb = Parameter(aabb_box.aabb.flatten(), requires_grad=False)
   469                                           
   470                                                           rays_o = raybundle.origins.contiguous()
   471                                                           rays_d = raybundle.directions.contiguous()
   472                                           
   473                                                           tensor_aabb = tensor_aabb.to(rays_o.device)
   474                                                           shape = rays_o.shape
   475                                           
   476                                                           rays_o = rays_o.reshape((-1, 3))
   477                                                           rays_d = rays_d.reshape((-1, 3))
   478                                           
   479                                                           t_min, t_max = nerfstudio.utils.math.intersect_aabb(rays_o, rays_d, tensor_aabb)
   480                                           
   481                                                           t_min = t_min.reshape([shape[0], shape[1], 1])
   482                                                           t_max = t_max.reshape([shape[0], shape[1], 1])
   483                                           
   484                                                           raybundle.nears = t_min
   485                                                           raybundle.fars = t_max
   486                                           
   487                                                   # TODO: We should have to squeeze the last dimension here if we started with zero batch dims, but never have to,
   488                                                   # so there might be a rogue squeeze happening somewhere, and this may cause some unintended behaviour
   489                                                   # that we haven't caught yet with tests
   490      1000        382.9      0.4      0.0          return raybundle

Total time: 11.3149 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/cameras/cameras.py
Function: _generate_rays_from_coords at line 493

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   493                                               @profile
   494                                               def _generate_rays_from_coords(
   495                                                   self,
   496                                                   camera_indices: TensorType["num_rays":..., "num_cameras_batch_dims"],
   497                                                   coords: TensorType["num_rays":..., 2],
   498                                                   camera_opt_to_camera: Optional[TensorType["num_rays":..., 3, 4]] = None,
   499                                                   distortion_params_delta: Optional[TensorType["num_rays":..., 6]] = None,
   500                                                   disable_distortion: bool = False,
   501                                               ) -> RayBundle:
   502                                                   """Generates rays for the given camera indices and coords where self isn't jagged
   503                                           
   504                                                   This is a fairly complex function, so let's break this down slowly.
   505                                           
   506                                                   Shapes involved:
   507                                                       - num_rays: This is your output raybundle shape. It dictates the number and shape of the rays generated
   508                                                       - num_cameras_batch_dims: This is the number of dimensions of our camera
   509                                           
   510                                                   Args:
   511                                                       camera_indices: Camera indices of the flattened cameras object to generate rays for.
   512                                                           The shape of this is such that indexing into camera_indices["num_rays":...] will return the
   513                                                           index into each batch dimension of the camera in order to get the correct camera specified by
   514                                                           "num_rays".
   515                                           
   516                                                           Example:
   517                                                               >>> cameras = Cameras(...)
   518                                                               >>> cameras.shape
   519                                                                   (2, 3, 4)
   520                                           
   521                                                               >>> camera_indices = torch.tensor([0, 0, 0]) # We need an axis of length 3 since cameras.ndim == 3
   522                                                               >>> camera_indices.shape
   523                                                                   (3,)
   524                                                               >>> coords = torch.tensor([1,1])
   525                                                               >>> coords.shape
   526                                                                   (2,)
   527                                                               >>> out_rays = cameras.generate_rays(camera_indices=camera_indices, coords = coords)
   528                                                                   # This will generate a RayBundle with a single ray for the
   529                                                                   # camera at cameras[0,0,0] at image coordinates (1,1), so out_rays.shape == ()
   530                                                               >>> out_rays.shape
   531                                                                   ()
   532                                           
   533                                                               >>> camera_indices = torch.tensor([[0,0,0]])
   534                                                               >>> camera_indices.shape
   535                                                                   (1, 3)
   536                                                               >>> coords = torch.tensor([[1,1]])
   537                                                               >>> coords.shape
   538                                                                   (1, 2)
   539                                                               >>> out_rays = cameras.generate_rays(camera_indices=camera_indices, coords = coords)
   540                                                                   # This will generate a RayBundle with a single ray for the
   541                                                                   # camera at cameras[0,0,0] at point (1,1), so out_rays.shape == (1,)
   542                                                                   # since we added an extra dimension in front of camera_indices
   543                                                               >>> out_rays.shape
   544                                                                   (1,)
   545                                           
   546                                                           If you want more examples, check tests/cameras/test_cameras and the function check_generate_rays_shape
   547                                           
   548                                                           The bottom line is that for camera_indices: (num_rays:..., num_cameras_batch_dims), num_rays is the
   549                                                           output shape and if you index into the output RayBundle with some indices [i:...], if you index into
   550                                                           camera_indices with camera_indices[i:...] as well, you will get a 1D tensor containing the batch
   551                                                           indices into the original cameras object corresponding to that ray (ie: you will get the camera
   552                                                           from our batched cameras corresponding to the ray at RayBundle[i:...]).
   553                                           
   554                                                       coords: Coordinates of the pixels to generate rays for. If None, the full image will be rendered, meaning
   555                                                           height and width get prepended to the num_rays dimensions. Indexing into coords with [i:...] will
   556                                                           get you the image coordinates [x, y] of that specific ray located at output RayBundle[i:...].
   557                                           
   558                                                       camera_opt_to_camera: Optional transform for the camera to world matrices.
   559                                                           In terms of shape, it follows the same rules as coords, but indexing into it with [i:...] gets you
   560                                                           the 2D camera to world transform matrix for the camera optimization at RayBundle[i:...].
   561                                           
   562                                                       distortion_params_delta: Optional delta for the distortion parameters.
   563                                                           In terms of shape, it follows the same rules as coords, but indexing into it with [i:...] gets you
   564                                                           the 1D tensor with the 6 distortion parameters for the camera optimization at RayBundle[i:...].
   565                                           
   566                                                       disable_distortion: If True, disables distortion.
   567                                           
   568                                                   Returns:
   569                                                       Rays for the given camera indices and coords. RayBundle.shape == num_rays
   570                                                   """
   571                                                   # Make sure we're on the right devices
   572      1000      74631.4     74.6      0.7          camera_indices = camera_indices.to(self.device)
   573      1000       3888.6      3.9      0.0          coords = coords.to(self.device)
   574                                           
   575                                                   # Checking to make sure everything is of the right shape and type
   576      1000       2337.8      2.3      0.0          num_rays_shape = camera_indices.shape[:-1]
   577      1000       4195.5      4.2      0.0          assert camera_indices.shape == num_rays_shape + (self.ndim,)
   578      1000       1698.2      1.7      0.0          assert coords.shape == num_rays_shape + (2,)
   579      1000       1052.9      1.1      0.0          assert coords.shape[-1] == 2
   580      1000       2006.7      2.0      0.0          assert camera_opt_to_camera is None or camera_opt_to_camera.shape == num_rays_shape + (3, 4)
   581      1000        433.5      0.4      0.0          assert distortion_params_delta is None or distortion_params_delta.shape == num_rays_shape + (6,)
   582                                           
   583                                                   # Here, we've broken our indices down along the num_cameras_batch_dims dimension allowing us to index by all
   584                                                   # of our output rays at each dimension of our cameras object
   585      1000      14985.0     15.0      0.1          true_indices = [camera_indices[..., i] for i in range(camera_indices.shape[-1])]
   586                                           
   587                                                   # Get all our focal lengths, principal points and make sure they are the right shapes
   588      1000       4790.1      4.8      0.0          y = coords[..., 0]  # (num_rays,) get rid of the last dimension
   589      1000       4490.4      4.5      0.0          x = coords[..., 1]  # (num_rays,) get rid of the last dimension
   590      1000      88567.4     88.6      0.8          fx, fy = self.fx[true_indices].squeeze(-1), self.fy[true_indices].squeeze(-1)  # (num_rays,)
   591      1000      68001.5     68.0      0.6          cx, cy = self.cx[true_indices].squeeze(-1), self.cy[true_indices].squeeze(-1)  # (num_rays,)
   592      1000        683.8      0.7      0.0          assert (
   593      1000       1584.5      1.6      0.0              y.shape == num_rays_shape
   594      1000        798.7      0.8      0.0              and x.shape == num_rays_shape
   595      1000        715.2      0.7      0.0              and fx.shape == num_rays_shape
   596      1000        695.5      0.7      0.0              and fy.shape == num_rays_shape
   597      1000        689.5      0.7      0.0              and cx.shape == num_rays_shape
   598      1000        682.2      0.7      0.0              and cy.shape == num_rays_shape
   599                                                   ), (
   600                                                       str(num_rays_shape)
   601                                                       + str(y.shape)
   602                                                       + str(x.shape)
   603                                                       + str(fx.shape)
   604                                                       + str(fy.shape)
   605                                                       + str(cx.shape)
   606                                                       + str(cy.shape)
   607                                                   )
   608                                           
   609                                                   # Get our image coordinates and image coordinates offset by 1 (offsets used for dx, dy calculations)
   610                                                   # Also make sure the shapes are correct
   611      1000     165038.0    165.0      1.5          coord = torch.stack([(x - cx) / fx, -(y - cy) / fy], -1)  # (num_rays, 2)
   612      1000     164762.2    164.8      1.5          coord_x_offset = torch.stack([(x - cx + 1) / fx, -(y - cy) / fy], -1)  # (num_rays, 2)
   613      1000     154118.6    154.1      1.4          coord_y_offset = torch.stack([(x - cx) / fx, -(y - cy + 1) / fy], -1)  # (num_rays, 2)
   614      1000        662.1      0.7      0.0          assert (
   615      1000       3212.4      3.2      0.0              coord.shape == num_rays_shape + (2,)
   616      1000       1375.7      1.4      0.0              and coord_x_offset.shape == num_rays_shape + (2,)
   617      1000       1231.9      1.2      0.0              and coord_y_offset.shape == num_rays_shape + (2,)
   618                                                   )
   619                                           
   620                                                   # Stack image coordinates and image coordinates offset by 1, check shapes too
   621      1000      30781.9     30.8      0.3          coord_stack = torch.stack([coord, coord_x_offset, coord_y_offset], dim=0)  # (3, num_rays, 2)
   622      1000       2180.3      2.2      0.0          assert coord_stack.shape == (3,) + num_rays_shape + (2,)
   623                                           
   624                                                   # Undistorts our images according to our distortion parameters
   625      1000        560.6      0.6      0.0          if not disable_distortion:
   626      1000        498.1      0.5      0.0              distortion_params = None
   627      1000       1077.8      1.1      0.0              if self.distortion_params is not None:
   628      1000      35734.2     35.7      0.3                  distortion_params = self.distortion_params[true_indices]
   629      1000        752.8      0.8      0.0                  if distortion_params_delta is not None:
   630                                                               distortion_params = distortion_params + distortion_params_delta
   631                                                       elif distortion_params_delta is not None:
   632                                                           distortion_params = distortion_params_delta
   633                                           
   634                                                       # Do not apply distortion for equirectangular images
   635      1000        513.0      0.5      0.0              if distortion_params is not None:
   636      1000      79204.1     79.2      0.7                  mask = (self.camera_type[true_indices] != CameraType.EQUIRECTANGULAR.value).squeeze(-1)  # (num_rays)
   637      1000      33447.0     33.4      0.3                  coord_mask = torch.stack([mask, mask, mask], dim=0)
   638      1000      60928.0     60.9      0.5                  if mask.any():
   639      1000    7464466.8   7464.5     66.0                      coord_stack[coord_mask, :] = camera_utils.radial_and_tangential_undistort(
   640      1000     175226.2    175.2      1.5                          coord_stack[coord_mask, :].reshape(3, -1, 2),
   641      1000     105816.8    105.8      0.9                          distortion_params[mask, :],
   642      1000        863.3      0.9      0.0                      ).reshape(-1, 2)
   643                                           
   644                                                   # Make sure after we have undistorted our images, the shapes are still correct
   645      1000       4192.1      4.2      0.0          assert coord_stack.shape == (3,) + num_rays_shape + (2,)
   646                                           
   647                                                   # Gets our directions for all our rays in camera coordinates and checks shapes at the end
   648                                                   # Here, directions_stack is of shape (3, num_rays, 3)
   649                                                   # directions_stack[0] is the direction for ray in camera coordinates
   650                                                   # directions_stack[1] is the direction for ray in camera coordinates offset by 1 in x
   651                                                   # directions_stack[2] is the direction for ray in camera coordinates offset by 1 in y
   652      1000     140261.8    140.3      1.2          cam_types = torch.unique(self.camera_type, sorted=False)
   653      1000      16649.5     16.6      0.1          directions_stack = torch.empty((3,) + num_rays_shape + (3,), device=self.device)
   654      1000     113806.5    113.8      1.0          if CameraType.PERSPECTIVE.value in cam_types:
   655      1000      80066.9     80.1      0.7              mask = (self.camera_type[true_indices] == CameraType.PERSPECTIVE.value).squeeze(-1)  # (num_rays)
   656      1000      40994.0     41.0      0.4              mask = torch.stack([mask, mask, mask], dim=0)
   657      1000     292195.9    292.2      2.6              directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0], mask).float()
   658      1000     260874.6    260.9      2.3              directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1], mask).float()
   659      1000      43487.6     43.5      0.4              directions_stack[..., 2][mask] = -1.0
   660                                           
   661      1000      95889.2     95.9      0.8          if CameraType.FISHEYE.value in cam_types:
   662                                                       mask = (self.camera_type[true_indices] == CameraType.FISHEYE.value).squeeze(-1)  # (num_rays)
   663                                                       mask = torch.stack([mask, mask, mask], dim=0)
   664                                           
   665                                                       theta = torch.sqrt(torch.sum(coord_stack**2, dim=-1))
   666                                                       theta = torch.clip(theta, 0.0, math.pi)
   667                                           
   668                                                       sin_theta = torch.sin(theta)
   669                                           
   670                                                       directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0] * sin_theta / theta, mask).float()
   671                                                       directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1] * sin_theta / theta, mask).float()
   672                                                       directions_stack[..., 2][mask] = -torch.masked_select(torch.cos(theta), mask).float()
   673                                           
   674      1000      81399.7     81.4      0.7          if CameraType.EQUIRECTANGULAR.value in cam_types:
   675                                                       mask = (self.camera_type[true_indices] == CameraType.EQUIRECTANGULAR.value).squeeze(-1)  # (num_rays)
   676                                                       mask = torch.stack([mask, mask, mask], dim=0)
   677                                           
   678                                                       # For equirect, fx = fy = height = width/2
   679                                                       # Then coord[..., 0] goes from -1 to 1 and coord[..., 1] goes from -1/2 to 1/2
   680                                                       theta = -torch.pi * coord_stack[..., 0]  # minus sign for right-handed
   681                                                       phi = torch.pi * (0.5 - coord_stack[..., 1])
   682                                                       # use spherical in local camera coordinates (+y up, x=0 and z<0 is theta=0)
   683                                                       directions_stack[..., 0][mask] = torch.masked_select(-torch.sin(theta) * torch.sin(phi), mask).float()
   684                                                       directions_stack[..., 1][mask] = torch.masked_select(torch.cos(phi), mask).float()
   685                                                       directions_stack[..., 2][mask] = torch.masked_select(-torch.cos(theta) * torch.sin(phi), mask).float()
   686                                           
   687      1000      22204.9     22.2      0.2          for value in cam_types:
   688      1000      50056.3     50.1      0.4              if value not in [CameraType.PERSPECTIVE.value, CameraType.FISHEYE.value, CameraType.EQUIRECTANGULAR.value]:
   689                                                           raise ValueError(f"Camera type {value} not supported.")
   690                                           
   691      1000       3335.4      3.3      0.0          assert directions_stack.shape == (3,) + num_rays_shape + (3,)
   692                                           
   693      1000      40697.1     40.7      0.4          c2w = self.camera_to_worlds[true_indices]
   694      1000       3709.2      3.7      0.0          assert c2w.shape == num_rays_shape + (3, 4)
   695                                           
   696      1000        834.7      0.8      0.0          if camera_opt_to_camera is not None:
   697      1000     415472.2    415.5      3.7              c2w = pose_utils.multiply(c2w, camera_opt_to_camera)
   698      1000      11014.6     11.0      0.1          rotation = c2w[..., :3, :3]  # (..., 3, 3)
   699      1000       4103.3      4.1      0.0          assert rotation.shape == num_rays_shape + (3, 3)
   700                                           
   701      1000      42496.4     42.5      0.4          directions_stack = torch.sum(
   702      1000      44016.9     44.0      0.4              directions_stack[..., None, :] * rotation, dim=-1
   703                                                   )  # (..., 1, 3) * (..., 3, 3) -> (..., 3)
   704      1000     194016.4    194.0      1.7          directions_stack, directions_norm = camera_utils.normalize_with_norm(directions_stack, -1)
   705      1000       3257.8      3.3      0.0          assert directions_stack.shape == (3,) + num_rays_shape + (3,)
   706                                           
   707      1000      11965.4     12.0      0.1          origins = c2w[..., :3, 3]  # (..., 3)
   708      1000       2850.4      2.9      0.0          assert origins.shape == num_rays_shape + (3,)
   709                                           
   710      1000       6009.2      6.0      0.1          directions = directions_stack[0]
   711      1000       1899.7      1.9      0.0          assert directions.shape == num_rays_shape + (3,)
   712                                           
   713                                                   # norms of the vector going between adjacent coords, giving us dx and dy per output ray
   714      1000     144113.2    144.1      1.3          dx = torch.sqrt(torch.sum((directions - directions_stack[1]) ** 2, dim=-1))  # ("num_rays":...,)
   715      1000     113392.9    113.4      1.0          dy = torch.sqrt(torch.sum((directions - directions_stack[2]) ** 2, dim=-1))  # ("num_rays":...,)
   716      1000       3008.5      3.0      0.0          assert dx.shape == num_rays_shape and dy.shape == num_rays_shape
   717                                           
   718      1000      33385.7     33.4      0.3          pixel_area = (dx * dy)[..., None]  # ("num_rays":..., 1)
   719      1000       2697.6      2.7      0.0          assert pixel_area.shape == num_rays_shape + (1,)
   720                                           
   721      1000       1534.2      1.5      0.0          times = self.times[camera_indices, 0] if self.times is not None else None
   722                                           
   723      1000     209191.4    209.2      1.8          return RayBundle(
   724      1000        660.2      0.7      0.0              origins=origins,
   725      1000        539.5      0.5      0.0              directions=directions,
   726      1000        578.5      0.6      0.0              pixel_area=pixel_area,
   727      1000        662.9      0.7      0.0              camera_indices=camera_indices,
   728      1000        622.1      0.6      0.0              times=times,
   729      1000      17423.7     17.4      0.2              metadata={"directions_norm": directions_norm[0].detach()},
   730                                                   )

Total time: 18.8698 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/data/datamanagers/base_datamanager.py
Function: next_train at line 514

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   514                                               @profiler.time_function
   515                                               @profile
   516                                               def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
   517                                                   """Returns the next batch of data from the train dataloader."""
   518      1000      14462.6     14.5      0.1          self.train_count += 1
   519      1000       6900.0      6.9      0.0          image_batch = next(self.iter_train_image_dataloader)
   520      1000        913.3      0.9      0.0          assert self.train_pixel_sampler is not None
   521      1000    4180161.5   4180.2     22.2          batch = self.train_pixel_sampler.sample(image_batch)
   522      1000        925.7      0.9      0.0          ray_indices = batch["indices"]
   523      1000   14665925.9  14665.9     77.7          ray_bundle = self.train_ray_generator(ray_indices)
   524      1000        467.8      0.5      0.0          return ray_bundle, batch

Total time: 14.6002 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/model_components/ray_generators.py
Function: forward at line 42

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    42                                               @profiler.time_function
    43                                               @profile
    44                                               def forward(self, ray_indices: TensorType["num_rays", 3]) -> RayBundle:
    45                                                   """Index into the cameras to generate the rays.
    46                                           
    47                                                   Args:
    48                                                       ray_indices: Contains camera, row, and col indices for target rays.
    49                                                   """
    50      1000      12212.8     12.2      0.1          c = ray_indices[:, 0]  # camera indices
    51      1000       6210.8      6.2      0.0          y = ray_indices[:, 1]  # row indices
    52      1000       5762.0      5.8      0.0          x = ray_indices[:, 2]  # col indices
    53      1000     156186.8    156.2      1.1          coords = self.image_coords[y, x]
    54                                           
    55      1000    2453782.2   2453.8     16.8          camera_opt_to_camera = self.pose_optimizer(c)
    56                                           
    57      1000   11956216.8  11956.2     81.9          ray_bundle = self.cameras.generate_rays(
    58      1000       8518.4      8.5      0.1              camera_indices=c.unsqueeze(-1),
    59      1000        680.0      0.7      0.0              coords=coords,
    60      1000        316.4      0.3      0.0              camera_opt_to_camera=camera_opt_to_camera,
    61                                                   )
    62      1000        341.1      0.3      0.0          return ray_bundle

Total time: 43.9297 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/pipelines/base_pipeline.py
Function: get_train_loss_dict at line 268

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   268                                               @profiler.time_function
   269                                               @profile
   270                                               def get_train_loss_dict(self, step: int):
   271                                                   """This function gets your training loss dict. This will be responsible for
   272                                                   getting the next batch of data from the DataManager and interfacing with the
   273                                                   Model class, feeding the data to the model's forward function.
   274                                           
   275                                                   Args:
   276                                                       step: current iteration step to update sampler if using DDP (distributed)
   277                                                   """
   278      1000   18911764.8  18911.8     43.1          ray_bundle, batch = self.datamanager.next_train(step)
   279      1000   20886524.8  20886.5     47.5          model_outputs = self.model(ray_bundle)
   280      1000    2058127.0   2058.1      4.7          metrics_dict = self.model.get_metrics_dict(model_outputs, batch)
   281                                           
   282      1000       3843.2      3.8      0.0          if self.config.datamanager.camera_optimizer is not None:
   283      1000       1244.6      1.2      0.0              camera_opt_param_group = self.config.datamanager.camera_optimizer.param_group
   284      1000      45209.3     45.2      0.1              if camera_opt_param_group in self.datamanager.get_param_groups():
   285                                                           # Report the camera optimization metrics
   286      1000        576.1      0.6      0.0                  metrics_dict["camera_opt_translation"] = (
   287      1000     121657.4    121.7      0.3                      self.datamanager.get_param_groups()[camera_opt_param_group][0].data[:, :3].norm()
   288                                                           )
   289      1000        588.1      0.6      0.0                  metrics_dict["camera_opt_rotation"] = (
   290      1000      92025.0     92.0      0.2                      self.datamanager.get_param_groups()[camera_opt_param_group][0].data[:, 3:].norm()
   291                                                           )
   292                                           
   293      1000    1807498.8   1807.5      4.1          loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict)
   294                                           
   295      1000        612.2      0.6      0.0          return model_outputs, loss_dict, metrics_dict

@liruilong940607
Copy link
Contributor Author

Nerfacc takes care of ray undistortion. data loading consumes 43% -> 34%

Timer unit: 1e-06 s

Total time: 5.89843 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/cameras/cameras.py
Function: generate_rays at line 312

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   312                                               @profiler.time_function
   313                                               @profile
   314                                               def generate_rays(  # pylint: disable=too-many-statements
   315                                                   self,
   316                                                   camera_indices: Union[TensorType["num_rays":..., "num_cameras_batch_dims"], int],
   317                                                   coords: Optional[TensorType["num_rays":..., 2]] = None,
   318                                                   camera_opt_to_camera: Optional[TensorType["num_rays":..., 3, 4]] = None,
   319                                                   distortion_params_delta: Optional[TensorType["num_rays":..., 6]] = None,
   320                                                   keep_shape: Optional[bool] = None,
   321                                                   disable_distortion: bool = False,
   322                                                   aabb_box: Optional[SceneBox] = None,
   323                                               ) -> RayBundle:
   370                                                   # Check the argument types to make sure they're valid and all shaped correctly
   371      1000       2542.3      2.5      0.0          assert isinstance(camera_indices, (torch.Tensor, int)), "camera_indices must be a tensor or int"
   372      1000       1301.2      1.3      0.0          assert coords is None or isinstance(coords, torch.Tensor), "coords must be a tensor or None"
   373      1000        797.3      0.8      0.0          assert camera_opt_to_camera is None or isinstance(camera_opt_to_camera, torch.Tensor)
   374      1000        544.7      0.5      0.0          assert distortion_params_delta is None or isinstance(distortion_params_delta, torch.Tensor)
   375      1000       1019.2      1.0      0.0          if isinstance(camera_indices, torch.Tensor) and isinstance(coords, torch.Tensor):
   376      1000       3863.7      3.9      0.1              num_rays_shape = camera_indices.shape[:-1]
   377      1000        463.0      0.5      0.0              errormsg = "Batch dims of inputs must match when inputs are all tensors"
   378      1000       1951.7      2.0      0.0              assert coords.shape[:-1] == num_rays_shape, errormsg
   379      1000       2065.4      2.1      0.0              assert camera_opt_to_camera is None or camera_opt_to_camera.shape[:-2] == num_rays_shape, errormsg
   380      1000        571.5      0.6      0.0              assert distortion_params_delta is None or distortion_params_delta.shape[:-1] == num_rays_shape, errormsg
   381                                           
   382                                                   # If zero dimensional, we need to unsqueeze to get a batch dimension and then squeeze later
   383      1000       3057.7      3.1      0.1          if not self.shape:
   384                                                       cameras = self.reshape((1,))
   385                                                       assert torch.all(
   386                                                           torch.tensor(camera_indices == 0) if isinstance(camera_indices, int) else camera_indices == 0
   387                                                       ), "Can only index into single camera with no batch dimensions if index is zero"
   388                                                   else:
   389      1000        528.9      0.5      0.0              cameras = self
   390                                           
   391                                                   # If the camera indices are an int, then we need to make sure that the camera batch is 1D
   392      1000       1165.2      1.2      0.0          if isinstance(camera_indices, int):
   393                                                       assert (
   394                                                           len(cameras.shape) == 1
   395                                                       ), "camera_indices must be a tensor if cameras are batched with more than 1 batch dimension"
   396                                                       camera_indices = torch.tensor([camera_indices], device=cameras.device)
   397                                           
   398      1000       1319.7      1.3      0.0          assert camera_indices.shape[-1] == len(
   399      1000       1130.7      1.1      0.0              cameras.shape
   400                                                   ), "camera_indices must have shape (num_rays:..., num_cameras_batch_dims)"
   401                                           
   402                                                   # If keep_shape is True, then we need to make sure that the camera indices in question
   403                                                   # are all the same height and width and can actually be batched while maintaining the image
   404                                                   # shape
   405      1000        698.3      0.7      0.0          if keep_shape is True:
   406                                                       assert torch.all(cameras.height[camera_indices] == cameras.height[camera_indices[0]]) and torch.all(
   407                                                           cameras.width[camera_indices] == cameras.width[camera_indices[0]]
   408                                                       ), "Can only keep shape if all cameras have the same height and width"
   409                                           
   410                                                   # If the cameras don't all have same height / width, if coords is not none, we will need to generate
   411                                                   # a flat list of coords for each camera and then concatenate otherwise our rays will be jagged.
   412                                                   # Camera indices, camera_opt, and distortion will also need to be broadcasted accordingly which is non-trivial
   413      1000     207106.9    207.1      3.5          if cameras.is_jagged and coords is None and (keep_shape is None or keep_shape is False):
   414                                                       index_dim = camera_indices.shape[-1]
   415                                                       camera_indices = camera_indices.reshape(-1, index_dim)
   416                                                       _coords = [cameras.get_image_coords(index=tuple(index)).reshape(-1, 2) for index in camera_indices]
   417                                                       camera_indices = torch.cat(
   418                                                           [index.unsqueeze(0).repeat(coords.shape[0], 1) for index, coords in zip(camera_indices, _coords)],
   419                                                       )
   420                                                       coords = torch.cat(_coords, dim=0)
   421                                                       assert coords.shape[0] == camera_indices.shape[0]
   422                                                       # Need to get the coords of each indexed camera and flatten all coordinate maps and concatenate them
   423                                           
   424                                                   # The case where we aren't jagged && keep_shape (since otherwise coords is already set) and coords
   425                                                   # is None. In this case we append (h, w) to the num_rays dimensions for all tensors. In this case,
   426                                                   # each image in camera_indices has to have the same shape since otherwise we would have error'd when
   427                                                   # we checked keep_shape is valid or we aren't jagged.
   428      1000        831.3      0.8      0.0          if coords is None:
   429                                                       index_dim = camera_indices.shape[-1]
   430                                                       index = camera_indices.reshape(-1, index_dim)[0]
   431                                                       coords: torch.Tensor = cameras.get_image_coords(index=tuple(index))  # (h, w, 2)
   432                                                       coords = coords.reshape(coords.shape[:2] + (1,) * len(camera_indices.shape[:-1]) + (2,))  # (h, w, 1..., 2)
   433                                                       coords = coords.expand(coords.shape[:2] + camera_indices.shape[:-1] + (2,))  # (h, w, num_rays, 2)
   434                                                       camera_opt_to_camera = (  # (h, w, num_rays, 3, 4) or None
   435                                                           camera_opt_to_camera.broadcast_to(coords.shape[:-1] + (3, 4))
   436                                                           if camera_opt_to_camera is not None
   437                                                           else None
   438                                                       )
   439                                                       distortion_params_delta = (  # (h, w, num_rays, 6) or None
   440                                                           distortion_params_delta.broadcast_to(coords.shape[:-1] + (6,))
   441                                                           if distortion_params_delta is not None
   442                                                           else None
   443                                                       )
   444                                           
   445                                                   # If camera indices was an int or coords was none, we need to broadcast our indices along batch dims
   446      1000      20603.4     20.6      0.3          camera_indices = camera_indices.broadcast_to(coords.shape[:-1] + (len(cameras.shape),)).to(torch.long)
   447                                           
   448                                                   # Checking our tensors have been standardized
   449      1000       1520.5      1.5      0.0          assert isinstance(coords, torch.Tensor) and isinstance(camera_indices, torch.Tensor)
   450      1000       2431.9      2.4      0.0          assert camera_indices.shape[-1] == len(cameras.shape)
   451      1000       3095.2      3.1      0.1          assert camera_opt_to_camera is None or camera_opt_to_camera.shape[:-2] == coords.shape[:-1]
   452      1000        622.4      0.6      0.0          assert distortion_params_delta is None or distortion_params_delta.shape[:-1] == coords.shape[:-1]
   453                                           
   454                                                   # This will do the actual work of generating the rays now that we have standardized the inputs
   455                                                   # raybundle.shape == (num_rays) when done
   456                                                   # pylint: disable=protected-access
   457      1000    5636394.8   5636.4     95.6          raybundle = cameras._generate_rays_from_coords(
   458      1000        628.3      0.6      0.0              camera_indices, coords, camera_opt_to_camera, distortion_params_delta, disable_distortion=disable_distortion
   459                                                   )
   460                                           
   461                                                   # If we have mandated that we don't keep the shape, then we flatten
   462      1000        968.3      1.0      0.0          if keep_shape is False:
   463                                                       raybundle = raybundle.flatten()
   464                                           
   465      1000        823.8      0.8      0.0          if aabb_box:
   466                                                       with torch.no_grad():
   467                                                           tensor_aabb = Parameter(aabb_box.aabb.flatten(), requires_grad=False)
   468                                           
   469                                                           rays_o = raybundle.origins.contiguous()
   470                                                           rays_d = raybundle.directions.contiguous()
   471                                           
   472                                                           tensor_aabb = tensor_aabb.to(rays_o.device)
   473                                                           shape = rays_o.shape
   474                                           
   475                                                           rays_o = rays_o.reshape((-1, 3))
   476                                                           rays_d = rays_d.reshape((-1, 3))
   477                                           
   478                                                           t_min, t_max = nerfstudio.utils.math.intersect_aabb(rays_o, rays_d, tensor_aabb)
   479                                           
   480                                                           t_min = t_min.reshape([shape[0], shape[1], 1])
   481                                                           t_max = t_max.reshape([shape[0], shape[1], 1])
   482                                           
   483                                                           raybundle.nears = t_min
   484                                                           raybundle.fars = t_max
   485                                           
   486                                                   # TODO: We should have to squeeze the last dimension here if we started with zero batch dims, but never have to,
   487                                                   # so there might be a rogue squeeze happening somewhere, and this may cause some unintended behaviour
   488                                                   # that we haven't caught yet with tests
   489      1000        386.2      0.4      0.0          return raybundle

Total time: 5.32988 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/cameras/cameras.py
Function: _generate_rays_from_coords at line 492

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   492                                               @profile
   493                                               def _generate_rays_from_coords(
   494                                                   self,
   495                                                   camera_indices: TensorType["num_rays":..., "num_cameras_batch_dims"],
   496                                                   coords: TensorType["num_rays":..., 2],
   497                                                   camera_opt_to_camera: Optional[TensorType["num_rays":..., 3, 4]] = None,
   498                                                   distortion_params_delta: Optional[TensorType["num_rays":..., 6]] = None,
   499                                                   disable_distortion: bool = False,
   500                                               ) -> RayBundle:
   570                                                   # Make sure we're on the right devices
   571      1000      75535.9     75.5      1.4          camera_indices = camera_indices.to(self.device)
   572      1000       4195.6      4.2      0.1          coords = coords.to(self.device)
   573                                           
   574                                                   # Checking to make sure everything is of the right shape and type
   575      1000       2382.9      2.4      0.0          num_rays_shape = camera_indices.shape[:-1]
   576      1000       5181.8      5.2      0.1          assert camera_indices.shape == num_rays_shape + (self.ndim,)
   577      1000       1762.2      1.8      0.0          assert coords.shape == num_rays_shape + (2,)
   578      1000       1123.7      1.1      0.0          assert coords.shape[-1] == 2
   579      1000       1942.3      1.9      0.0          assert camera_opt_to_camera is None or camera_opt_to_camera.shape == num_rays_shape + (3, 4)
   580      1000        633.4      0.6      0.0          assert distortion_params_delta is None or distortion_params_delta.shape == num_rays_shape + (6,)
   581                                           
   582                                                   # Here, we've broken our indices down along the num_cameras_batch_dims dimension allowing us to index by all
   583                                                   # of our output rays at each dimension of our cameras object
   584      1000      15879.8     15.9      0.3          true_indices = [camera_indices[..., i] for i in range(camera_indices.shape[-1])]
   585                                           
   586                                                   # Get all our focal lengths, principal points and make sure they are the right shapes
   587      1000       5208.6      5.2      0.1          y = coords[..., 0]  # (num_rays,) get rid of the last dimension
   588      1000       4509.0      4.5      0.1          x = coords[..., 1]  # (num_rays,) get rid of the last dimension
   589      1000      93689.1     93.7      1.8          fx, fy = self.fx[true_indices].squeeze(-1), self.fy[true_indices].squeeze(-1)  # (num_rays,)
   590      1000      72221.0     72.2      1.4          cx, cy = self.cx[true_indices].squeeze(-1), self.cy[true_indices].squeeze(-1)  # (num_rays,)
   591      1000        984.5      1.0      0.0          assert (
   592      1000       1795.6      1.8      0.0              y.shape == num_rays_shape
   593      1000        868.2      0.9      0.0              and x.shape == num_rays_shape
   594      1000        729.7      0.7      0.0              and fx.shape == num_rays_shape
   595      1000        722.9      0.7      0.0              and fy.shape == num_rays_shape
   596      1000        724.4      0.7      0.0              and cx.shape == num_rays_shape
   597      1000        720.0      0.7      0.0              and cy.shape == num_rays_shape
   598                                                   ), (
   599                                                       str(num_rays_shape)
   600                                                       + str(y.shape)
   601                                                       + str(x.shape)
   602                                                       + str(fx.shape)
   603                                                       + str(fy.shape)
   604                                                       + str(cx.shape)
   605                                                       + str(cy.shape)
   606                                                   )
   607                                           
   608                                                   # Get our image coordinates and image coordinates offset by 1 (offsets used for dx, dy calculations)
   609                                                   # Also make sure the shapes are correct
   610      1000     167456.0    167.5      3.1          coord = torch.stack([(x - cx) / fx, -(y - cy) / fy], -1)  # (num_rays, 2)
   611      1000     166370.8    166.4      3.1          coord_x_offset = torch.stack([(x - cx + 1) / fx, -(y - cy) / fy], -1)  # (num_rays, 2)
   612      1000     158160.7    158.2      3.0          coord_y_offset = torch.stack([(x - cx) / fx, -(y - cy + 1) / fy], -1)  # (num_rays, 2)
   613      1000        920.1      0.9      0.0          assert (
   614      1000       3512.5      3.5      0.1              coord.shape == num_rays_shape + (2,)
   615      1000       1457.5      1.5      0.0              and coord_x_offset.shape == num_rays_shape + (2,)
   616      1000       1287.6      1.3      0.0              and coord_y_offset.shape == num_rays_shape + (2,)
   617                                                   )
   618                                           
   619                                                   # Stack image coordinates and image coordinates offset by 1, check shapes too
   620      1000      32300.6     32.3      0.6          coord_stack = torch.stack([coord, coord_x_offset, coord_y_offset], dim=0)  # (3, num_rays, 2)
   621      1000       2218.1      2.2      0.0          assert coord_stack.shape == (3,) + num_rays_shape + (2,)
   622                                           
   623                                                   # Undistorts our images according to our distortion parameters
   624      1000        568.8      0.6      0.0          if not disable_distortion:
   625      1000        510.2      0.5      0.0              distortion_params = None
   626      1000       1505.4      1.5      0.0              if self.distortion_params is not None:
   627      1000      38145.8     38.1      0.7                  distortion_params = self.distortion_params[true_indices]
   628      1000        714.6      0.7      0.0                  if distortion_params_delta is not None:
   629                                                               distortion_params = distortion_params + distortion_params_delta
   630                                                       elif distortion_params_delta is not None:
   631                                                           distortion_params = distortion_params_delta
   632                                           
   633                                                       # Do not apply distortion for equirectangular images
   634      1000        553.5      0.6      0.0              if distortion_params is not None:
   635      1000      83187.3     83.2      1.6                  mask = (self.camera_type[true_indices] != CameraType.EQUIRECTANGULAR.value).squeeze(-1)  # (num_rays)
   636      1000      35133.6     35.1      0.7                  coord_mask = torch.stack([mask, mask, mask], dim=0)
   637      1000      63283.1     63.3      1.2                  if mask.any():
   638                                                               if distortion_params_delta is not None:
   639                                                                   coord_stack[coord_mask, :] = camera_utils.radial_and_tangential_undistort(
   640      1000    1329823.8   1329.8     25.0                              coord_stack[coord_mask, :].reshape(3, -1, 2),
   641      1000     177964.6    178.0      3.3                              distortion_params[mask, :],
   642      1000     111606.8    111.6      2.1                          ).reshape(-1, 2)
   643      1000        717.2      0.7      0.0                      else:
   644                                                                   # try to use nerfacc to accelerate if we don't need any gradient to optimize distortion params
   645                                                                   try:
   646                                                                       coord_stack[coord_mask, :] = opencv_lens_undistortion(
   647                                                                           coord_stack[coord_mask, :].reshape(3, -1, 2),
   648                                                                           distortion_params[mask, :],
   649                                                                       ).reshape(-1, 2)
   650                                                                   except:
   651                                                                       coord_stack[coord_mask, :] = camera_utils.radial_and_tangential_undistort(
   652      1000       3336.3      3.3      0.1                                  coord_stack[coord_mask, :].reshape(3, -1, 2),
   653                                                                           distortion_params[mask, :],
   654                                                                       ).reshape(-1, 2)
   655                                           
   656                                                   # Make sure after we have undistorted our images, the shapes are still correct
   657                                                   assert coord_stack.shape == (3,) + num_rays_shape + (2,)
   658                                           
   659      1000     141462.6    141.5      2.7          # Gets our directions for all our rays in camera coordinates and checks shapes at the end
   660      1000      17009.2     17.0      0.3          # Here, directions_stack is of shape (3, num_rays, 3)
   661      1000     119748.4    119.7      2.2          # directions_stack[0] is the direction for ray in camera coordinates
   662      1000      84201.4     84.2      1.6          # directions_stack[1] is the direction for ray in camera coordinates offset by 1 in x
   663      1000      42304.3     42.3      0.8          # directions_stack[2] is the direction for ray in camera coordinates offset by 1 in y
   664      1000     302633.4    302.6      5.7          cam_types = torch.unique(self.camera_type, sorted=False)
   665      1000     275161.2    275.2      5.2          directions_stack = torch.empty((3,) + num_rays_shape + (3,), device=self.device)
   666      1000      45150.0     45.1      0.8          if CameraType.PERSPECTIVE.value in cam_types:
   667                                                       mask = (self.camera_type[true_indices] == CameraType.PERSPECTIVE.value).squeeze(-1)  # (num_rays)
   668      1000     101875.8    101.9      1.9              mask = torch.stack([mask, mask, mask], dim=0)
   669                                                       directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0], mask).float()
   670                                                       directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1], mask).float()
   671                                                       directions_stack[..., 2][mask] = -1.0
   672                                           
   673                                                   if CameraType.FISHEYE.value in cam_types:
   674                                                       mask = (self.camera_type[true_indices] == CameraType.FISHEYE.value).squeeze(-1)  # (num_rays)
   675                                                       mask = torch.stack([mask, mask, mask], dim=0)
   676                                           
   677                                                       theta = torch.sqrt(torch.sum(coord_stack**2, dim=-1))
   678                                                       theta = torch.clip(theta, 0.0, math.pi)
   679                                           
   680                                                       sin_theta = torch.sin(theta)
   681      1000      86364.5     86.4      1.6  
   682                                                       directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0] * sin_theta / theta, mask).float()
   683                                                       directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1] * sin_theta / theta, mask).float()
   684                                                       directions_stack[..., 2][mask] = -torch.masked_select(torch.cos(theta), mask).float()
   685                                           
   686                                                   if CameraType.EQUIRECTANGULAR.value in cam_types:
   687                                                       mask = (self.camera_type[true_indices] == CameraType.EQUIRECTANGULAR.value).squeeze(-1)  # (num_rays)
   688                                                       mask = torch.stack([mask, mask, mask], dim=0)
   689                                           
   690                                                       # For equirect, fx = fy = height = width/2
   691                                                       # Then coord[..., 0] goes from -1 to 1 and coord[..., 1] goes from -1/2 to 1/2
   692                                                       theta = -torch.pi * coord_stack[..., 0]  # minus sign for right-handed
   693                                                       phi = torch.pi * (0.5 - coord_stack[..., 1])
   694      1000      22240.3     22.2      0.4              # use spherical in local camera coordinates (+y up, x=0 and z<0 is theta=0)
   695      1000      53516.0     53.5      1.0              directions_stack[..., 0][mask] = torch.masked_select(-torch.sin(theta) * torch.sin(phi), mask).float()
   696                                                       directions_stack[..., 1][mask] = torch.masked_select(torch.cos(phi), mask).float()
   697                                                       directions_stack[..., 2][mask] = torch.masked_select(-torch.cos(theta) * torch.sin(phi), mask).float()
   698      1000       3154.5      3.2      0.1  
   699                                                   for value in cam_types:
   700      1000      43538.8     43.5      0.8              if value not in [CameraType.PERSPECTIVE.value, CameraType.FISHEYE.value, CameraType.EQUIRECTANGULAR.value]:
   701      1000       3587.5      3.6      0.1                  raise ValueError(f"Camera type {value} not supported.")
   702                                           
   703      1000        780.1      0.8      0.0          assert directions_stack.shape == (3,) + num_rays_shape + (3,)
   704      1000     435270.7    435.3      8.2  
   705      1000      12334.7     12.3      0.2          c2w = self.camera_to_worlds[true_indices]
   706      1000       4026.8      4.0      0.1          assert c2w.shape == num_rays_shape + (3, 4)
   707                                           
   708      1000      44612.8     44.6      0.8          if camera_opt_to_camera is not None:
   709      1000      47645.5     47.6      0.9              c2w = pose_utils.multiply(c2w, camera_opt_to_camera)
   710                                                   rotation = c2w[..., :3, :3]  # (..., 3, 3)
   711      1000     195234.7    195.2      3.7          assert rotation.shape == num_rays_shape + (3, 3)
   712      1000       3355.5      3.4      0.1  
   713                                                   directions_stack = torch.sum(
   714      1000      13279.8     13.3      0.2              directions_stack[..., None, :] * rotation, dim=-1
   715      1000       2942.9      2.9      0.1          )  # (..., 1, 3) * (..., 3, 3) -> (..., 3)
   716                                                   directions_stack, directions_norm = camera_utils.normalize_with_norm(directions_stack, -1)
   717      1000       6779.1      6.8      0.1          assert directions_stack.shape == (3,) + num_rays_shape + (3,)
   718      1000       1977.3      2.0      0.0  
   719                                                   origins = c2w[..., :3, 3]  # (..., 3)
   720                                                   assert origins.shape == num_rays_shape + (3,)
   721      1000     145841.0    145.8      2.7  
   722      1000     122159.3    122.2      2.3          directions = directions_stack[0]
   723      1000       2913.3      2.9      0.1          assert directions.shape == num_rays_shape + (3,)
   724                                           
   725      1000      36298.9     36.3      0.7          # norms of the vector going between adjacent coords, giving us dx and dy per output ray
   726      1000       2834.7      2.8      0.1          dx = torch.sqrt(torch.sum((directions - directions_stack[1]) ** 2, dim=-1))  # ("num_rays":...,)
   727                                                   dy = torch.sqrt(torch.sum((directions - directions_stack[2]) ** 2, dim=-1))  # ("num_rays":...,)
   728      1000       2051.9      2.1      0.0          assert dx.shape == num_rays_shape and dy.shape == num_rays_shape
   729                                           
   730      1000     212742.9    212.7      4.0          pixel_area = (dx * dy)[..., None]  # ("num_rays":..., 1)
   731      1000        590.7      0.6      0.0          assert pixel_area.shape == num_rays_shape + (1,)
   732      1000        539.9      0.5      0.0  
   733      1000        621.8      0.6      0.0          times = self.times[camera_indices, 0] if self.times is not None else None
   734      1000        775.3      0.8      0.0  
   735      1000        659.3      0.7      0.0          return RayBundle(
   736      1000      20120.0     20.1      0.4              origins=origins,
   737                                                       directions=directions,
   738                                                       pixel_area=pixel_area,
   739                                                       camera_indices=camera_indices,
   740                                                       times=times,
   741                                                       metadata={"directions_norm": directions_norm[0].detach()},
   742                                                   )

Total time: 13.158 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/data/datamanagers/base_datamanager.py
Function: next_train at line 520

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   520                                               @profiler.time_function
   521                                               @profile
   522                                               def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
   523                                                   """Returns the next batch of data from the train dataloader."""
   524      1000      13811.7     13.8      0.1          self.train_count += 1
   525      1000       4640.8      4.6      0.0          image_batch = next(self.iter_train_image_dataloader)
   526      1000       1072.5      1.1      0.0          assert self.train_pixel_sampler is not None
   527      1000    4149263.4   4149.3     31.5          batch = self.train_pixel_sampler.sample(image_batch)
   528      1000        813.9      0.8      0.0          ray_indices = batch["indices"]
   529      1000    8987908.2   8987.9     68.3          ray_bundle = self.train_ray_generator(ray_indices)
   530      1000        500.8      0.5      0.0          return ray_bundle, batch

Total time: 8.9214 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/model_components/ray_generators.py
Function: forward at line 49

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    49                                               @profiler.time_function
    50                                               @profile
    51                                               def forward(self, ray_indices: TensorType["num_rays", 3]) -> RayBundle:
    57      1000      13560.9     13.6      0.2          c = ray_indices[:, 0]  # camera indices
    58      1000       6918.7      6.9      0.1          y = ray_indices[:, 1]  # row indices
    59      1000       5989.9      6.0      0.1          x = ray_indices[:, 2]  # col indices
    60      1000     162339.4    162.3      1.8          coords = self.image_coords[y, x]
    61                                           
    62      1000       6903.5      6.9      0.1          if self.pose_optimizer is not None:
    63      1000    2747593.1   2747.6     30.8              camera_opt_to_camera = self.pose_optimizer(c)
    64                                                   else:
    65                                                       camera_opt_to_camera = None
    66                                           
    67      1000    5967316.6   5967.3     66.9          ray_bundle = self.cameras.generate_rays(
    68      1000       9650.1      9.7      0.1              camera_indices=c.unsqueeze(-1),
    69      1000        430.7      0.4      0.0              coords=coords,
    70      1000        360.8      0.4      0.0              camera_opt_to_camera=camera_opt_to_camera,
    71                                                   )
    72      1000        331.9      0.3      0.0          return ray_bundle

Total time: 38.7776 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/pipelines/base_pipeline.py
Function: get_train_loss_dict at line 268

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   268                                               @profiler.time_function
   269                                               @profile
   270                                               def get_train_loss_dict(self, step: int):
   278      1000   13193678.2  13193.7     34.0          ray_bundle, batch = self.datamanager.next_train(step)
   279      1000   21417043.0  21417.0     55.2          model_outputs = self.model(ray_bundle)
   280      1000    2052852.5   2052.9      5.3          metrics_dict = self.model.get_metrics_dict(model_outputs, batch)
   281                                           
   282      1000       3657.3      3.7      0.0          if self.config.datamanager.camera_optimizer is not None:
   283      1000       1711.4      1.7      0.0              camera_opt_param_group = self.config.datamanager.camera_optimizer.param_group
   284      1000      41581.3     41.6      0.1              if camera_opt_param_group in self.datamanager.get_param_groups():
   285                                                           # Report the camera optimization metrics
   286      1000        671.0      0.7      0.0                  metrics_dict["camera_opt_translation"] = (
   287      1000     119678.1    119.7      0.3                      self.datamanager.get_param_groups()[camera_opt_param_group][0].data[:, :3].norm()
   288                                                           )
   289      1000        576.5      0.6      0.0                  metrics_dict["camera_opt_rotation"] = (
   290      1000      92974.8     93.0      0.2                      self.datamanager.get_param_groups()[camera_opt_param_group][0].data[:, 3:].norm()
   291                                                           )
   292                                           
   293      1000    1852501.7   1852.5      4.8          loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict)
   294                                           
   295      1000        646.7      0.6      0.0          return model_outputs, loss_dict, metrics_dict

@liruilong940607
Copy link
Contributor Author

liruilong940607 commented Apr 9, 2023

PR Summary:

1/ Update Instant-NGP (tested) and Nerfplayer-NGP (untested) model with much better performance:

  • Switch from TCNNInstantNGPField to TCNNNerfactoField, which is actively maintained. In this commit the TCNNInstantNGPField is no longer used anymore. Shall we remove it?
  • Now by default it uses multi-level occ grid (4), instead of the single level occ grid with contraction. This is to align the setting in the nerfacc's example, which should be much better for both runtime and quality.
  • Some code cleanup and upgrade such as using lr schedular.

2/ Fix and speedup code for camera undistortion.

  • The original code for camera undistortion (camera_utils._compute_residual_and_jacobian) was uncorrect for both the perspective cameras and fisheye cameras. Specifically, fisheye cameras has a different formulation for distortion & undistortion (OpenCV references: pinhole, fisheye).
  • Also to speedup the iterative undistortion process, the implementations for both perspective camera and fisheye camera are now supported via nerfacc in CUDA (see here and here.

With these two changes, instant-ngp model is now more stable with much better performance than before. The nerfacto model gets not-too-much but noticeable test-time speedup.

This PR is tested with:

CUDA_VISIBLE_DEVICES=3 ns-train instant-ngp --data data/nerfstudio/poster --vis wandb --max-num-iterations 10000
CUDA_VISIBLE_DEVICES=3 ns-train nerfacto --data data/nerfstudio/poster --vis wandb --max-num-iterations 10000

Screen Shot 2023-04-09 at 3 15 33 AM

Screen Shot 2023-04-09 at 3 17 02 AM

Screen Shot 2023-04-09 at 3 17 17 AM

Screen Shot 2023-04-09 at 3 17 34 AM

Screen Shot 2023-04-09 at 3 17 42 AM

@liruilong940607
Copy link
Contributor Author

Not sure what to go with further speedup, because:

  • Precompute all rays would be beneficial only if there is long enough training (so that every ray is trained more than once).
  • Cache rays on-the-fly would easily lead to OOM, partially because there are too many things in RayBundle.

I'll pause the speedup in this PR until we have an ideal of what to do next. Or maybe we merge this one first and leave the speedup to next PR?

Comment on lines +238 to +241
# if self.training:
# loss_dict["distortion_loss"] = self.config.distortion_loss_mult * flatten_eff_distloss(
# outputs["weights"], outputs["steps"], outputs["intervals"], outputs["ray_indices"]
# )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove

Comment on lines +218 to +224
# if self.training:
# steps = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2
# intervals = ray_samples.frustums.ends - ray_samples.frustums.starts
# outputs["weights"] = weights.flatten()
# outputs["steps"] = steps.flatten()
# outputs["intervals"] = intervals.flatten()
# outputs["ray_indices"] = ray_indices.flatten()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove

@tancik
Copy link
Contributor

tancik commented Apr 9, 2023

In regards to the Memory Allocation plot, does the nerfacto model now use more GPU memory (the colors are greyed of I can't tell for sure)? If so, why?

@liruilong940607
Copy link
Contributor Author

In regards to the Memory Allocation plot, does the nerfacto model now use more GPU memory (the colors are greyed of I can't tell for sure)? If so, why?

Yeah it shows increased memory. It is very strange so I traced it down and leads to this PR: #1715

@liruilong940607 liruilong940607 mentioned this pull request Apr 20, 2023
@machenmusik
Copy link
Contributor

I'll pause the speedup in this PR until we have an ideal of what to do next. Or maybe we merge this one first and leave the speedup to next PR?

Some thoughts:

  • This seems worthwhile: With these two changes, instant-ngp model is now more stable with much better performance than before. The nerfacto model gets not-too-much but noticeable test-time speedup.
  • Merging may allow other methods to consider similar optimizations.

@SauravMaheshkar SauravMaheshkar added enhancement New feature or request speedup python Pull requests that update Python code labels May 15, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request python Pull requests that update Python code speedup
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants