[WIP] profiling and optimize the runtime #1705
base: main
Conversation
Profiling using line_profiler with the command:
CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=5 kernprof -l scripts/train.py nerfacto --data data/nerfstudio/poster --max-num-iterations 1000 --viewer.quit-on-train-completion True
This locates the hot spots shown in the profiles below.
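For context, a minimal sketch of the line_profiler workflow used here (the @profile decorator is injected by kernprof at run time, and CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous so GPU time is charged to the Python line that launched it); the script and function names are hypothetical:

```python
# demo.py -- hypothetical stand-in for scripts/train.py
try:
    profile  # kernprof injects this builtin when run via `kernprof -l`
except NameError:
    def profile(func):  # no-op fallback so the script also runs standalone
        return func

import torch


@profile
def generate_directions(coords: torch.Tensor) -> torch.Tensor:
    # Every line of a @profile-decorated function gets hits/time in the report.
    dirs = torch.cat([coords, -torch.ones_like(coords[..., :1])], dim=-1)
    return torch.nn.functional.normalize(dirs, dim=-1)


if __name__ == "__main__":
    generate_directions(torch.rand(1024, 2))

# Run:   CUDA_LAUNCH_BLOCKING=1 kernprof -l demo.py
# View:  python -m line_profiler demo.py.lprof
```

Profiler output for the training run (Timer unit: 1e-06 s):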
Total time: 11.8848 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/cameras/cameras.py
Function: generate_rays at line 313
Line # Hits Time Per Hit % Time Line Contents
==============================================================
313 @profiler.time_function
314 @profile
315 def generate_rays( # pylint: disable=too-many-statements
316 self,
317 camera_indices: Union[TensorType["num_rays":..., "num_cameras_batch_dims"], int],
318 coords: Optional[TensorType["num_rays":..., 2]] = None,
319 camera_opt_to_camera: Optional[TensorType["num_rays":..., 3, 4]] = None,
320 distortion_params_delta: Optional[TensorType["num_rays":..., 6]] = None,
321 keep_shape: Optional[bool] = None,
322 disable_distortion: bool = False,
323 aabb_box: Optional[SceneBox] = None,
324 ) -> RayBundle:
325 """Generates rays for the given camera indices.
326
327 This function will standardize the input arguments and then call the _generate_rays_from_coords function
328 to generate the rays. Our goal is to parse the arguments and then get them into the right shape:
329 - camera_indices: (num_rays:..., num_cameras_batch_dims)
330 - coords: (num_rays:..., 2)
331 - camera_opt_to_camera: (num_rays:..., 3, 4) or None
332 - distortion_params_delta: (num_rays:..., 6) or None
333
334 Read the docstring for _generate_rays_from_coords for more information on how we generate the rays
335 after we have standardized the arguments.
336
337 We are only concerned about different combinations of camera_indices and coords matrices, and the following
338 are the 4 cases we have to deal with:
339 1. isinstance(camera_indices, int) and coords == None
340 - In this case we broadcast our camera_indices / coords shape (h, w, 1 / 2 respectively)
341 2. isinstance(camera_indices, int) and coords != None
342 - In this case, we broadcast camera_indices to the same batch dim as coords
343 3. not isinstance(camera_indices, int) and coords == None
344 - In this case, we will need to set coords so that it is of shape (h, w, num_rays, 2), and broadcast
345 all our other args to match the new definition of num_rays := (h, w) + num_rays
346 4. not isinstance(camera_indices, int) and coords != None
347 - In this case, we have nothing to do, only check that the arguments are of the correct shape
348
349 There is one more edge case we need to be careful with: when we have "jagged cameras" (ie: different heights
350 and widths for each camera). This isn't problematic when we specify coords, since coords is already a tensor.
351 When coords == None (ie: when we render out the whole image associated with this camera), we run into problems
352 since there's no way to stack each coordinate map as all coordinate maps are all different shapes. In this case,
353 we will need to flatten each individual coordinate map and concatenate them, giving us only one batch dimension,
354 regardless of the number of prepended extra batch dimensions in the camera_indices tensor.
355
356
357 Args:
358 camera_indices: Camera indices of the flattened cameras object to generate rays for.
359 coords: Coordinates of the pixels to generate rays for. If None, the full image will be rendered.
360 camera_opt_to_camera: Optional transform for the camera to world matrices.
361 distortion_params_delta: Optional delta for the distortion parameters.
362 keep_shape: If None, then we default to the regular behavior of flattening if cameras is jagged, otherwise
363 keeping dimensions. If False, we flatten at the end. If True, then we keep the shape of the
364 camera_indices and coords tensors (if we can).
365 disable_distortion: If True, disables distortion.
366 aabb_box: if not None will calculate nears and fars of the ray according to aabb box intesection
367
368 Returns:
369 Rays for the given camera indices and coords.
370 """
371 # Check the argument types to make sure they're valid and all shaped correctly
372 1000 2485.9 2.5 0.0 assert isinstance(camera_indices, (torch.Tensor, int)), "camera_indices must be a tensor or int"
373 1000 1197.0 1.2 0.0 assert coords is None or isinstance(coords, torch.Tensor), "coords must be a tensor or None"
374 1000 657.0 0.7 0.0 assert camera_opt_to_camera is None or isinstance(camera_opt_to_camera, torch.Tensor)
375 1000 494.9 0.5 0.0 assert distortion_params_delta is None or isinstance(distortion_params_delta, torch.Tensor)
376 1000 852.6 0.9 0.0 if isinstance(camera_indices, torch.Tensor) and isinstance(coords, torch.Tensor):
377 1000 3659.1 3.7 0.0 num_rays_shape = camera_indices.shape[:-1]
378 1000 532.8 0.5 0.0 errormsg = "Batch dims of inputs must match when inputs are all tensors"
379 1000 2080.8 2.1 0.0 assert coords.shape[:-1] == num_rays_shape, errormsg
380 1000 1793.1 1.8 0.0 assert camera_opt_to_camera is None or camera_opt_to_camera.shape[:-2] == num_rays_shape, errormsg
381 1000 558.0 0.6 0.0 assert distortion_params_delta is None or distortion_params_delta.shape[:-1] == num_rays_shape, errormsg
382
383 # If zero dimensional, we need to unsqueeze to get a batch dimension and then squeeze later
384 1000 3330.1 3.3 0.0 if not self.shape:
385 cameras = self.reshape((1,))
386 assert torch.all(
387 torch.tensor(camera_indices == 0) if isinstance(camera_indices, int) else camera_indices == 0
388 ), "Can only index into single camera with no batch dimensions if index is zero"
389 else:
390 1000 413.3 0.4 0.0 cameras = self
391
392 # If the camera indices are an int, then we need to make sure that the camera batch is 1D
393 1000 1104.2 1.1 0.0 if isinstance(camera_indices, int):
394 assert (
395 len(cameras.shape) == 1
396 ), "camera_indices must be a tensor if cameras are batched with more than 1 batch dimension"
397 camera_indices = torch.tensor([camera_indices], device=cameras.device)
398
399 1000 1319.0 1.3 0.0 assert camera_indices.shape[-1] == len(
400 1000 1008.1 1.0 0.0 cameras.shape
401 ), "camera_indices must have shape (num_rays:..., num_cameras_batch_dims)"
402
403 # If keep_shape is True, then we need to make sure that the camera indices in question
404 # are all the same height and width and can actually be batched while maintaining the image
405 # shape
406 1000 634.1 0.6 0.0 if keep_shape is True:
407 assert torch.all(cameras.height[camera_indices] == cameras.height[camera_indices[0]]) and torch.all(
408 cameras.width[camera_indices] == cameras.width[camera_indices[0]]
409 ), "Can only keep shape if all cameras have the same height and width"
410
411 # If the cameras don't all have same height / width, if coords is not none, we will need to generate
412 # a flat list of coords for each camera and then concatenate otherwise our rays will be jagged.
413 # Camera indices, camera_opt, and distortion will also need to be broadcasted accordingly which is non-trivial
414 1000 201546.2 201.5 1.7 if cameras.is_jagged and coords is None and (keep_shape is None or keep_shape is False):
415 index_dim = camera_indices.shape[-1]
416 camera_indices = camera_indices.reshape(-1, index_dim)
417 _coords = [cameras.get_image_coords(index=tuple(index)).reshape(-1, 2) for index in camera_indices]
418 camera_indices = torch.cat(
419 [index.unsqueeze(0).repeat(coords.shape[0], 1) for index, coords in zip(camera_indices, _coords)],
420 )
421 coords = torch.cat(_coords, dim=0)
422 assert coords.shape[0] == camera_indices.shape[0]
423 # Need to get the coords of each indexed camera and flatten all coordinate maps and concatenate them
424
425 # The case where we aren't jagged && keep_shape (since otherwise coords is already set) and coords
426 # is None. In this case we append (h, w) to the num_rays dimensions for all tensors. In this case,
427 # each image in camera_indices has to have the same shape since otherwise we would have error'd when
428 # we checked keep_shape is valid or we aren't jagged.
429 1000 1015.9 1.0 0.0 if coords is None:
430 index_dim = camera_indices.shape[-1]
431 index = camera_indices.reshape(-1, index_dim)[0]
432 coords: torch.Tensor = cameras.get_image_coords(index=tuple(index)) # (h, w, 2)
433 coords = coords.reshape(coords.shape[:2] + (1,) * len(camera_indices.shape[:-1]) + (2,)) # (h, w, 1..., 2)
434 coords = coords.expand(coords.shape[:2] + camera_indices.shape[:-1] + (2,)) # (h, w, num_rays, 2)
435 camera_opt_to_camera = ( # (h, w, num_rays, 3, 4) or None
436 camera_opt_to_camera.broadcast_to(coords.shape[:-1] + (3, 4))
437 if camera_opt_to_camera is not None
438 else None
439 )
440 distortion_params_delta = ( # (h, w, num_rays, 6) or None
441 distortion_params_delta.broadcast_to(coords.shape[:-1] + (6,))
442 if distortion_params_delta is not None
443 else None
444 )
445
446 # If camera indices was an int or coords was none, we need to broadcast our indices along batch dims
447 1000 19520.5 19.5 0.2 camera_indices = camera_indices.broadcast_to(coords.shape[:-1] + (len(cameras.shape),)).to(torch.long)
448
449 # Checking our tensors have been standardized
450 1000 1517.0 1.5 0.0 assert isinstance(coords, torch.Tensor) and isinstance(camera_indices, torch.Tensor)
451 1000 2415.0 2.4 0.0 assert camera_indices.shape[-1] == len(cameras.shape)
452 1000 3257.2 3.3 0.0 assert camera_opt_to_camera is None or camera_opt_to_camera.shape[:-2] == coords.shape[:-1]
453 1000 776.6 0.8 0.0 assert distortion_params_delta is None or distortion_params_delta.shape[:-1] == coords.shape[:-1]
454
455 # This will do the actual work of generating the rays now that we have standardized the inputs
456 # raybundle.shape == (num_rays) when done
457 # pylint: disable=protected-access
458 1000 11629989.5 11630.0 97.9 raybundle = cameras._generate_rays_from_coords(
459 1000 662.4 0.7 0.0 camera_indices, coords, camera_opt_to_camera, distortion_params_delta, disable_distortion=disable_distortion
460 )
461
462 # If we have mandated that we don't keep the shape, then we flatten
463 1000 866.5 0.9 0.0 if keep_shape is False:
464 raybundle = raybundle.flatten()
465
466 1000 746.4 0.7 0.0 if aabb_box:
467 with torch.no_grad():
468 tensor_aabb = Parameter(aabb_box.aabb.flatten(), requires_grad=False)
469
470 rays_o = raybundle.origins.contiguous()
471 rays_d = raybundle.directions.contiguous()
472
473 tensor_aabb = tensor_aabb.to(rays_o.device)
474 shape = rays_o.shape
475
476 rays_o = rays_o.reshape((-1, 3))
477 rays_d = rays_d.reshape((-1, 3))
478
479 t_min, t_max = nerfstudio.utils.math.intersect_aabb(rays_o, rays_d, tensor_aabb)
480
481 t_min = t_min.reshape([shape[0], shape[1], 1])
482 t_max = t_max.reshape([shape[0], shape[1], 1])
483
484 raybundle.nears = t_min
485 raybundle.fars = t_max
486
487 # TODO: We should have to squeeze the last dimension here if we started with zero batch dims, but never have to,
488 # so there might be a rogue squeeze happening somewhere, and this may cause some unintended behaviour
489 # that we haven't caught yet with tests
490 1000 382.9 0.4 0.0 return raybundle
Total time: 11.3149 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/cameras/cameras.py
Function: _generate_rays_from_coords at line 493
Line # Hits Time Per Hit % Time Line Contents
==============================================================
493 @profile
494 def _generate_rays_from_coords(
495 self,
496 camera_indices: TensorType["num_rays":..., "num_cameras_batch_dims"],
497 coords: TensorType["num_rays":..., 2],
498 camera_opt_to_camera: Optional[TensorType["num_rays":..., 3, 4]] = None,
499 distortion_params_delta: Optional[TensorType["num_rays":..., 6]] = None,
500 disable_distortion: bool = False,
501 ) -> RayBundle:
502 """Generates rays for the given camera indices and coords where self isn't jagged
503
504 This is a fairly complex function, so let's break this down slowly.
505
506 Shapes involved:
507 - num_rays: This is your output raybundle shape. It dictates the number and shape of the rays generated
508 - num_cameras_batch_dims: This is the number of dimensions of our camera
509
510 Args:
511 camera_indices: Camera indices of the flattened cameras object to generate rays for.
512 The shape of this is such that indexing into camera_indices["num_rays":...] will return the
513 index into each batch dimension of the camera in order to get the correct camera specified by
514 "num_rays".
515
516 Example:
517 >>> cameras = Cameras(...)
518 >>> cameras.shape
519 (2, 3, 4)
520
521 >>> camera_indices = torch.tensor([0, 0, 0]) # We need an axis of length 3 since cameras.ndim == 3
522 >>> camera_indices.shape
523 (3,)
524 >>> coords = torch.tensor([1,1])
525 >>> coords.shape
526 (2,)
527 >>> out_rays = cameras.generate_rays(camera_indices=camera_indices, coords = coords)
528 # This will generate a RayBundle with a single ray for the
529 # camera at cameras[0,0,0] at image coordinates (1,1), so out_rays.shape == ()
530 >>> out_rays.shape
531 ()
532
533 >>> camera_indices = torch.tensor([[0,0,0]])
534 >>> camera_indices.shape
535 (1, 3)
536 >>> coords = torch.tensor([[1,1]])
537 >>> coords.shape
538 (1, 2)
539 >>> out_rays = cameras.generate_rays(camera_indices=camera_indices, coords = coords)
540 # This will generate a RayBundle with a single ray for the
541 # camera at cameras[0,0,0] at point (1,1), so out_rays.shape == (1,)
542 # since we added an extra dimension in front of camera_indices
543 >>> out_rays.shape
544 (1,)
545
546 If you want more examples, check tests/cameras/test_cameras and the function check_generate_rays_shape
547
548 The bottom line is that for camera_indices: (num_rays:..., num_cameras_batch_dims), num_rays is the
549 output shape and if you index into the output RayBundle with some indices [i:...], if you index into
550 camera_indices with camera_indices[i:...] as well, you will get a 1D tensor containing the batch
551 indices into the original cameras object corresponding to that ray (ie: you will get the camera
552 from our batched cameras corresponding to the ray at RayBundle[i:...]).
553
554 coords: Coordinates of the pixels to generate rays for. If None, the full image will be rendered, meaning
555 height and width get prepended to the num_rays dimensions. Indexing into coords with [i:...] will
556 get you the image coordinates [x, y] of that specific ray located at output RayBundle[i:...].
557
558 camera_opt_to_camera: Optional transform for the camera to world matrices.
559 In terms of shape, it follows the same rules as coords, but indexing into it with [i:...] gets you
560 the 2D camera to world transform matrix for the camera optimization at RayBundle[i:...].
561
562 distortion_params_delta: Optional delta for the distortion parameters.
563 In terms of shape, it follows the same rules as coords, but indexing into it with [i:...] gets you
564 the 1D tensor with the 6 distortion parameters for the camera optimization at RayBundle[i:...].
565
566 disable_distortion: If True, disables distortion.
567
568 Returns:
569 Rays for the given camera indices and coords. RayBundle.shape == num_rays
570 """
571 # Make sure we're on the right devices
572 1000 74631.4 74.6 0.7 camera_indices = camera_indices.to(self.device)
573 1000 3888.6 3.9 0.0 coords = coords.to(self.device)
574
575 # Checking to make sure everything is of the right shape and type
576 1000 2337.8 2.3 0.0 num_rays_shape = camera_indices.shape[:-1]
577 1000 4195.5 4.2 0.0 assert camera_indices.shape == num_rays_shape + (self.ndim,)
578 1000 1698.2 1.7 0.0 assert coords.shape == num_rays_shape + (2,)
579 1000 1052.9 1.1 0.0 assert coords.shape[-1] == 2
580 1000 2006.7 2.0 0.0 assert camera_opt_to_camera is None or camera_opt_to_camera.shape == num_rays_shape + (3, 4)
581 1000 433.5 0.4 0.0 assert distortion_params_delta is None or distortion_params_delta.shape == num_rays_shape + (6,)
582
583 # Here, we've broken our indices down along the num_cameras_batch_dims dimension allowing us to index by all
584 # of our output rays at each dimension of our cameras object
585 1000 14985.0 15.0 0.1 true_indices = [camera_indices[..., i] for i in range(camera_indices.shape[-1])]
586
587 # Get all our focal lengths, principal points and make sure they are the right shapes
588 1000 4790.1 4.8 0.0 y = coords[..., 0] # (num_rays,) get rid of the last dimension
589 1000 4490.4 4.5 0.0 x = coords[..., 1] # (num_rays,) get rid of the last dimension
590 1000 88567.4 88.6 0.8 fx, fy = self.fx[true_indices].squeeze(-1), self.fy[true_indices].squeeze(-1) # (num_rays,)
591 1000 68001.5 68.0 0.6 cx, cy = self.cx[true_indices].squeeze(-1), self.cy[true_indices].squeeze(-1) # (num_rays,)
592 1000 683.8 0.7 0.0 assert (
593 1000 1584.5 1.6 0.0 y.shape == num_rays_shape
594 1000 798.7 0.8 0.0 and x.shape == num_rays_shape
595 1000 715.2 0.7 0.0 and fx.shape == num_rays_shape
596 1000 695.5 0.7 0.0 and fy.shape == num_rays_shape
597 1000 689.5 0.7 0.0 and cx.shape == num_rays_shape
598 1000 682.2 0.7 0.0 and cy.shape == num_rays_shape
599 ), (
600 str(num_rays_shape)
601 + str(y.shape)
602 + str(x.shape)
603 + str(fx.shape)
604 + str(fy.shape)
605 + str(cx.shape)
606 + str(cy.shape)
607 )
608
609 # Get our image coordinates and image coordinates offset by 1 (offsets used for dx, dy calculations)
610 # Also make sure the shapes are correct
611 1000 165038.0 165.0 1.5 coord = torch.stack([(x - cx) / fx, -(y - cy) / fy], -1) # (num_rays, 2)
612 1000 164762.2 164.8 1.5 coord_x_offset = torch.stack([(x - cx + 1) / fx, -(y - cy) / fy], -1) # (num_rays, 2)
613 1000 154118.6 154.1 1.4 coord_y_offset = torch.stack([(x - cx) / fx, -(y - cy + 1) / fy], -1) # (num_rays, 2)
614 1000 662.1 0.7 0.0 assert (
615 1000 3212.4 3.2 0.0 coord.shape == num_rays_shape + (2,)
616 1000 1375.7 1.4 0.0 and coord_x_offset.shape == num_rays_shape + (2,)
617 1000 1231.9 1.2 0.0 and coord_y_offset.shape == num_rays_shape + (2,)
618 )
619
620 # Stack image coordinates and image coordinates offset by 1, check shapes too
621 1000 30781.9 30.8 0.3 coord_stack = torch.stack([coord, coord_x_offset, coord_y_offset], dim=0) # (3, num_rays, 2)
622 1000 2180.3 2.2 0.0 assert coord_stack.shape == (3,) + num_rays_shape + (2,)
623
624 # Undistorts our images according to our distortion parameters
625 1000 560.6 0.6 0.0 if not disable_distortion:
626 1000 498.1 0.5 0.0 distortion_params = None
627 1000 1077.8 1.1 0.0 if self.distortion_params is not None:
628 1000 35734.2 35.7 0.3 distortion_params = self.distortion_params[true_indices]
629 1000 752.8 0.8 0.0 if distortion_params_delta is not None:
630 distortion_params = distortion_params + distortion_params_delta
631 elif distortion_params_delta is not None:
632 distortion_params = distortion_params_delta
633
634 # Do not apply distortion for equirectangular images
635 1000 513.0 0.5 0.0 if distortion_params is not None:
636 1000 79204.1 79.2 0.7 mask = (self.camera_type[true_indices] != CameraType.EQUIRECTANGULAR.value).squeeze(-1) # (num_rays)
637 1000 33447.0 33.4 0.3 coord_mask = torch.stack([mask, mask, mask], dim=0)
638 1000 60928.0 60.9 0.5 if mask.any():
639 1000 7464466.8 7464.5 66.0 coord_stack[coord_mask, :] = camera_utils.radial_and_tangential_undistort(
640 1000 175226.2 175.2 1.5 coord_stack[coord_mask, :].reshape(3, -1, 2),
641 1000 105816.8 105.8 0.9 distortion_params[mask, :],
642 1000 863.3 0.9 0.0 ).reshape(-1, 2)
643
644 # Make sure after we have undistorted our images, the shapes are still correct
645 1000 4192.1 4.2 0.0 assert coord_stack.shape == (3,) + num_rays_shape + (2,)
646
647 # Gets our directions for all our rays in camera coordinates and checks shapes at the end
648 # Here, directions_stack is of shape (3, num_rays, 3)
649 # directions_stack[0] is the direction for ray in camera coordinates
650 # directions_stack[1] is the direction for ray in camera coordinates offset by 1 in x
651 # directions_stack[2] is the direction for ray in camera coordinates offset by 1 in y
652 1000 140261.8 140.3 1.2 cam_types = torch.unique(self.camera_type, sorted=False)
653 1000 16649.5 16.6 0.1 directions_stack = torch.empty((3,) + num_rays_shape + (3,), device=self.device)
654 1000 113806.5 113.8 1.0 if CameraType.PERSPECTIVE.value in cam_types:
655 1000 80066.9 80.1 0.7 mask = (self.camera_type[true_indices] == CameraType.PERSPECTIVE.value).squeeze(-1) # (num_rays)
656 1000 40994.0 41.0 0.4 mask = torch.stack([mask, mask, mask], dim=0)
657 1000 292195.9 292.2 2.6 directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0], mask).float()
658 1000 260874.6 260.9 2.3 directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1], mask).float()
659 1000 43487.6 43.5 0.4 directions_stack[..., 2][mask] = -1.0
660
661 1000 95889.2 95.9 0.8 if CameraType.FISHEYE.value in cam_types:
662 mask = (self.camera_type[true_indices] == CameraType.FISHEYE.value).squeeze(-1) # (num_rays)
663 mask = torch.stack([mask, mask, mask], dim=0)
664
665 theta = torch.sqrt(torch.sum(coord_stack**2, dim=-1))
666 theta = torch.clip(theta, 0.0, math.pi)
667
668 sin_theta = torch.sin(theta)
669
670 directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0] * sin_theta / theta, mask).float()
671 directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1] * sin_theta / theta, mask).float()
672 directions_stack[..., 2][mask] = -torch.masked_select(torch.cos(theta), mask).float()
673
674 1000 81399.7 81.4 0.7 if CameraType.EQUIRECTANGULAR.value in cam_types:
675 mask = (self.camera_type[true_indices] == CameraType.EQUIRECTANGULAR.value).squeeze(-1) # (num_rays)
676 mask = torch.stack([mask, mask, mask], dim=0)
677
678 # For equirect, fx = fy = height = width/2
679 # Then coord[..., 0] goes from -1 to 1 and coord[..., 1] goes from -1/2 to 1/2
680 theta = -torch.pi * coord_stack[..., 0] # minus sign for right-handed
681 phi = torch.pi * (0.5 - coord_stack[..., 1])
682 # use spherical in local camera coordinates (+y up, x=0 and z<0 is theta=0)
683 directions_stack[..., 0][mask] = torch.masked_select(-torch.sin(theta) * torch.sin(phi), mask).float()
684 directions_stack[..., 1][mask] = torch.masked_select(torch.cos(phi), mask).float()
685 directions_stack[..., 2][mask] = torch.masked_select(-torch.cos(theta) * torch.sin(phi), mask).float()
686
687 1000 22204.9 22.2 0.2 for value in cam_types:
688 1000 50056.3 50.1 0.4 if value not in [CameraType.PERSPECTIVE.value, CameraType.FISHEYE.value, CameraType.EQUIRECTANGULAR.value]:
689 raise ValueError(f"Camera type {value} not supported.")
690
691 1000 3335.4 3.3 0.0 assert directions_stack.shape == (3,) + num_rays_shape + (3,)
692
693 1000 40697.1 40.7 0.4 c2w = self.camera_to_worlds[true_indices]
694 1000 3709.2 3.7 0.0 assert c2w.shape == num_rays_shape + (3, 4)
695
696 1000 834.7 0.8 0.0 if camera_opt_to_camera is not None:
697 1000 415472.2 415.5 3.7 c2w = pose_utils.multiply(c2w, camera_opt_to_camera)
698 1000 11014.6 11.0 0.1 rotation = c2w[..., :3, :3] # (..., 3, 3)
699 1000 4103.3 4.1 0.0 assert rotation.shape == num_rays_shape + (3, 3)
700
701 1000 42496.4 42.5 0.4 directions_stack = torch.sum(
702 1000 44016.9 44.0 0.4 directions_stack[..., None, :] * rotation, dim=-1
703 ) # (..., 1, 3) * (..., 3, 3) -> (..., 3)
704 1000 194016.4 194.0 1.7 directions_stack, directions_norm = camera_utils.normalize_with_norm(directions_stack, -1)
705 1000 3257.8 3.3 0.0 assert directions_stack.shape == (3,) + num_rays_shape + (3,)
706
707 1000 11965.4 12.0 0.1 origins = c2w[..., :3, 3] # (..., 3)
708 1000 2850.4 2.9 0.0 assert origins.shape == num_rays_shape + (3,)
709
710 1000 6009.2 6.0 0.1 directions = directions_stack[0]
711 1000 1899.7 1.9 0.0 assert directions.shape == num_rays_shape + (3,)
712
713 # norms of the vector going between adjacent coords, giving us dx and dy per output ray
714 1000 144113.2 144.1 1.3 dx = torch.sqrt(torch.sum((directions - directions_stack[1]) ** 2, dim=-1)) # ("num_rays":...,)
715 1000 113392.9 113.4 1.0 dy = torch.sqrt(torch.sum((directions - directions_stack[2]) ** 2, dim=-1)) # ("num_rays":...,)
716 1000 3008.5 3.0 0.0 assert dx.shape == num_rays_shape and dy.shape == num_rays_shape
717
718 1000 33385.7 33.4 0.3 pixel_area = (dx * dy)[..., None] # ("num_rays":..., 1)
719 1000 2697.6 2.7 0.0 assert pixel_area.shape == num_rays_shape + (1,)
720
721 1000 1534.2 1.5 0.0 times = self.times[camera_indices, 0] if self.times is not None else None
722
723 1000 209191.4 209.2 1.8 return RayBundle(
724 1000 660.2 0.7 0.0 origins=origins,
725 1000 539.5 0.5 0.0 directions=directions,
726 1000 578.5 0.6 0.0 pixel_area=pixel_area,
727 1000 662.9 0.7 0.0 camera_indices=camera_indices,
728 1000 622.1 0.6 0.0 times=times,
729 1000 17423.7 17.4 0.2 metadata={"directions_norm": directions_norm[0].detach()},
730 )
Total time: 18.8698 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/data/datamanagers/base_datamanager.py
Function: next_train at line 514
Line # Hits Time Per Hit % Time Line Contents
==============================================================
514 @profiler.time_function
515 @profile
516 def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
517 """Returns the next batch of data from the train dataloader."""
518 1000 14462.6 14.5 0.1 self.train_count += 1
519 1000 6900.0 6.9 0.0 image_batch = next(self.iter_train_image_dataloader)
520 1000 913.3 0.9 0.0 assert self.train_pixel_sampler is not None
521 1000 4180161.5 4180.2 22.2 batch = self.train_pixel_sampler.sample(image_batch)
522 1000 925.7 0.9 0.0 ray_indices = batch["indices"]
523 1000 14665925.9 14665.9 77.7 ray_bundle = self.train_ray_generator(ray_indices)
524 1000 467.8 0.5 0.0 return ray_bundle, batch
Total time: 14.6002 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/model_components/ray_generators.py
Function: forward at line 42
Line # Hits Time Per Hit % Time Line Contents
==============================================================
42 @profiler.time_function
43 @profile
44 def forward(self, ray_indices: TensorType["num_rays", 3]) -> RayBundle:
45 """Index into the cameras to generate the rays.
46
47 Args:
48 ray_indices: Contains camera, row, and col indices for target rays.
49 """
50 1000 12212.8 12.2 0.1 c = ray_indices[:, 0] # camera indices
51 1000 6210.8 6.2 0.0 y = ray_indices[:, 1] # row indices
52 1000 5762.0 5.8 0.0 x = ray_indices[:, 2] # col indices
53 1000 156186.8 156.2 1.1 coords = self.image_coords[y, x]
54
55 1000 2453782.2 2453.8 16.8 camera_opt_to_camera = self.pose_optimizer(c)
56
57 1000 11956216.8 11956.2 81.9 ray_bundle = self.cameras.generate_rays(
58 1000 8518.4 8.5 0.1 camera_indices=c.unsqueeze(-1),
59 1000 680.0 0.7 0.0 coords=coords,
60 1000 316.4 0.3 0.0 camera_opt_to_camera=camera_opt_to_camera,
61 )
62 1000 341.1 0.3 0.0 return ray_bundle
Total time: 43.9297 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/pipelines/base_pipeline.py
Function: get_train_loss_dict at line 268
Line # Hits Time Per Hit % Time Line Contents
==============================================================
268 @profiler.time_function
269 @profile
270 def get_train_loss_dict(self, step: int):
271 """This function gets your training loss dict. This will be responsible for
272 getting the next batch of data from the DataManager and interfacing with the
273 Model class, feeding the data to the model's forward function.
274
275 Args:
276 step: current iteration step to update sampler if using DDP (distributed)
277 """
278 1000 18911764.8 18911.8 43.1 ray_bundle, batch = self.datamanager.next_train(step)
279 1000 20886524.8 20886.5 47.5 model_outputs = self.model(ray_bundle)
280 1000 2058127.0 2058.1 4.7 metrics_dict = self.model.get_metrics_dict(model_outputs, batch)
281
282 1000 3843.2 3.8 0.0 if self.config.datamanager.camera_optimizer is not None:
283 1000 1244.6 1.2 0.0 camera_opt_param_group = self.config.datamanager.camera_optimizer.param_group
284 1000 45209.3 45.2 0.1 if camera_opt_param_group in self.datamanager.get_param_groups():
285 # Report the camera optimization metrics
286 1000 576.1 0.6 0.0 metrics_dict["camera_opt_translation"] = (
287 1000 121657.4 121.7 0.3 self.datamanager.get_param_groups()[camera_opt_param_group][0].data[:, :3].norm()
288 )
289 1000 588.1 0.6 0.0 metrics_dict["camera_opt_rotation"] = (
290 1000 92025.0 92.0 0.2 self.datamanager.get_param_groups()[camera_opt_param_group][0].data[:, 3:].norm()
291 )
292
293 1000 1807498.8 1807.5 4.1 loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict)
294
295 1000 612.2 0.6 0.0 return model_outputs, loss_dict, metrics_dict
Nerfacc takes care of ray undistortion; data loading drops from 43% to 34% of the training step.
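Restated outside the profiler dump, the change tries nerfacc's fused opencv_lens_undistortion kernel when the distortion parameters need no gradient and falls back to the original PyTorch routine otherwise. A sketch of the pattern visible in the diff below (the exact nerfacc import path may vary by version):

```python
import torch
from nerfstudio.cameras import camera_utils

try:
    from nerfacc import opencv_lens_undistortion  # import path may vary by nerfacc version
except ImportError:
    opencv_lens_undistortion = None


def undistort(coords: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
    """coords: (..., 2) distorted coords; params: (..., 6) OpenCV-style coefficients."""
    if opencv_lens_undistortion is not None:
        try:
            # One fused CUDA kernel instead of a Python-side iteration loop.
            return opencv_lens_undistortion(coords, params)
        except Exception:
            pass  # e.g. CPU tensors or an incompatible nerfacc build
    return camera_utils.radial_and_tangential_undistort(coords, params)
```

Profiler output after the change (Timer unit: 1e-06 s):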
Total time: 5.89843 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/cameras/cameras.py
Function: generate_rays at line 312
Line # Hits Time Per Hit % Time Line Contents
==============================================================
312 @profiler.time_function
313 @profile
314 def generate_rays( # pylint: disable=too-many-statements
315 self,
316 camera_indices: Union[TensorType["num_rays":..., "num_cameras_batch_dims"], int],
317 coords: Optional[TensorType["num_rays":..., 2]] = None,
318 camera_opt_to_camera: Optional[TensorType["num_rays":..., 3, 4]] = None,
319 distortion_params_delta: Optional[TensorType["num_rays":..., 6]] = None,
320 keep_shape: Optional[bool] = None,
321 disable_distortion: bool = False,
322 aabb_box: Optional[SceneBox] = None,
323 ) -> RayBundle:
370 # Check the argument types to make sure they're valid and all shaped correctly
371 1000 2542.3 2.5 0.0 assert isinstance(camera_indices, (torch.Tensor, int)), "camera_indices must be a tensor or int"
372 1000 1301.2 1.3 0.0 assert coords is None or isinstance(coords, torch.Tensor), "coords must be a tensor or None"
373 1000 797.3 0.8 0.0 assert camera_opt_to_camera is None or isinstance(camera_opt_to_camera, torch.Tensor)
374 1000 544.7 0.5 0.0 assert distortion_params_delta is None or isinstance(distortion_params_delta, torch.Tensor)
375 1000 1019.2 1.0 0.0 if isinstance(camera_indices, torch.Tensor) and isinstance(coords, torch.Tensor):
376 1000 3863.7 3.9 0.1 num_rays_shape = camera_indices.shape[:-1]
377 1000 463.0 0.5 0.0 errormsg = "Batch dims of inputs must match when inputs are all tensors"
378 1000 1951.7 2.0 0.0 assert coords.shape[:-1] == num_rays_shape, errormsg
379 1000 2065.4 2.1 0.0 assert camera_opt_to_camera is None or camera_opt_to_camera.shape[:-2] == num_rays_shape, errormsg
380 1000 571.5 0.6 0.0 assert distortion_params_delta is None or distortion_params_delta.shape[:-1] == num_rays_shape, errormsg
381
382 # If zero dimensional, we need to unsqueeze to get a batch dimension and then squeeze later
383 1000 3057.7 3.1 0.1 if not self.shape:
384 cameras = self.reshape((1,))
385 assert torch.all(
386 torch.tensor(camera_indices == 0) if isinstance(camera_indices, int) else camera_indices == 0
387 ), "Can only index into single camera with no batch dimensions if index is zero"
388 else:
389 1000 528.9 0.5 0.0 cameras = self
390
391 # If the camera indices are an int, then we need to make sure that the camera batch is 1D
392 1000 1165.2 1.2 0.0 if isinstance(camera_indices, int):
393 assert (
394 len(cameras.shape) == 1
395 ), "camera_indices must be a tensor if cameras are batched with more than 1 batch dimension"
396 camera_indices = torch.tensor([camera_indices], device=cameras.device)
397
398 1000 1319.7 1.3 0.0 assert camera_indices.shape[-1] == len(
399 1000 1130.7 1.1 0.0 cameras.shape
400 ), "camera_indices must have shape (num_rays:..., num_cameras_batch_dims)"
401
402 # If keep_shape is True, then we need to make sure that the camera indices in question
403 # are all the same height and width and can actually be batched while maintaining the image
404 # shape
405 1000 698.3 0.7 0.0 if keep_shape is True:
406 assert torch.all(cameras.height[camera_indices] == cameras.height[camera_indices[0]]) and torch.all(
407 cameras.width[camera_indices] == cameras.width[camera_indices[0]]
408 ), "Can only keep shape if all cameras have the same height and width"
409
410 # If the cameras don't all have same height / width, if coords is not none, we will need to generate
411 # a flat list of coords for each camera and then concatenate otherwise our rays will be jagged.
412 # Camera indices, camera_opt, and distortion will also need to be broadcasted accordingly which is non-trivial
413 1000 207106.9 207.1 3.5 if cameras.is_jagged and coords is None and (keep_shape is None or keep_shape is False):
414 index_dim = camera_indices.shape[-1]
415 camera_indices = camera_indices.reshape(-1, index_dim)
416 _coords = [cameras.get_image_coords(index=tuple(index)).reshape(-1, 2) for index in camera_indices]
417 camera_indices = torch.cat(
418 [index.unsqueeze(0).repeat(coords.shape[0], 1) for index, coords in zip(camera_indices, _coords)],
419 )
420 coords = torch.cat(_coords, dim=0)
421 assert coords.shape[0] == camera_indices.shape[0]
422 # Need to get the coords of each indexed camera and flatten all coordinate maps and concatenate them
423
424 # The case where we aren't jagged && keep_shape (since otherwise coords is already set) and coords
425 # is None. In this case we append (h, w) to the num_rays dimensions for all tensors. In this case,
426 # each image in camera_indices has to have the same shape since otherwise we would have error'd when
427 # we checked keep_shape is valid or we aren't jagged.
428 1000 831.3 0.8 0.0 if coords is None:
429 index_dim = camera_indices.shape[-1]
430 index = camera_indices.reshape(-1, index_dim)[0]
431 coords: torch.Tensor = cameras.get_image_coords(index=tuple(index)) # (h, w, 2)
432 coords = coords.reshape(coords.shape[:2] + (1,) * len(camera_indices.shape[:-1]) + (2,)) # (h, w, 1..., 2)
433 coords = coords.expand(coords.shape[:2] + camera_indices.shape[:-1] + (2,)) # (h, w, num_rays, 2)
434 camera_opt_to_camera = ( # (h, w, num_rays, 3, 4) or None
435 camera_opt_to_camera.broadcast_to(coords.shape[:-1] + (3, 4))
436 if camera_opt_to_camera is not None
437 else None
438 )
439 distortion_params_delta = ( # (h, w, num_rays, 6) or None
440 distortion_params_delta.broadcast_to(coords.shape[:-1] + (6,))
441 if distortion_params_delta is not None
442 else None
443 )
444
445 # If camera indices was an int or coords was none, we need to broadcast our indices along batch dims
446 1000 20603.4 20.6 0.3 camera_indices = camera_indices.broadcast_to(coords.shape[:-1] + (len(cameras.shape),)).to(torch.long)
447
448 # Checking our tensors have been standardized
449 1000 1520.5 1.5 0.0 assert isinstance(coords, torch.Tensor) and isinstance(camera_indices, torch.Tensor)
450 1000 2431.9 2.4 0.0 assert camera_indices.shape[-1] == len(cameras.shape)
451 1000 3095.2 3.1 0.1 assert camera_opt_to_camera is None or camera_opt_to_camera.shape[:-2] == coords.shape[:-1]
452 1000 622.4 0.6 0.0 assert distortion_params_delta is None or distortion_params_delta.shape[:-1] == coords.shape[:-1]
453
454 # This will do the actual work of generating the rays now that we have standardized the inputs
455 # raybundle.shape == (num_rays) when done
456 # pylint: disable=protected-access
457 1000 5636394.8 5636.4 95.6 raybundle = cameras._generate_rays_from_coords(
458 1000 628.3 0.6 0.0 camera_indices, coords, camera_opt_to_camera, distortion_params_delta, disable_distortion=disable_distortion
459 )
460
461 # If we have mandated that we don't keep the shape, then we flatten
462 1000 968.3 1.0 0.0 if keep_shape is False:
463 raybundle = raybundle.flatten()
464
465 1000 823.8 0.8 0.0 if aabb_box:
466 with torch.no_grad():
467 tensor_aabb = Parameter(aabb_box.aabb.flatten(), requires_grad=False)
468
469 rays_o = raybundle.origins.contiguous()
470 rays_d = raybundle.directions.contiguous()
471
472 tensor_aabb = tensor_aabb.to(rays_o.device)
473 shape = rays_o.shape
474
475 rays_o = rays_o.reshape((-1, 3))
476 rays_d = rays_d.reshape((-1, 3))
477
478 t_min, t_max = nerfstudio.utils.math.intersect_aabb(rays_o, rays_d, tensor_aabb)
479
480 t_min = t_min.reshape([shape[0], shape[1], 1])
481 t_max = t_max.reshape([shape[0], shape[1], 1])
482
483 raybundle.nears = t_min
484 raybundle.fars = t_max
485
486 # TODO: We should have to squeeze the last dimension here if we started with zero batch dims, but never have to,
487 # so there might be a rogue squeeze happening somewhere, and this may cause some unintended behaviour
488 # that we haven't caught yet with tests
489 1000 386.2 0.4 0.0 return raybundle
Total time: 5.32988 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/cameras/cameras.py
Function: _generate_rays_from_coords at line 492
Line # Hits Time Per Hit % Time Line Contents
==============================================================
492 @profile
493 def _generate_rays_from_coords(
494 self,
495 camera_indices: TensorType["num_rays":..., "num_cameras_batch_dims"],
496 coords: TensorType["num_rays":..., 2],
497 camera_opt_to_camera: Optional[TensorType["num_rays":..., 3, 4]] = None,
498 distortion_params_delta: Optional[TensorType["num_rays":..., 6]] = None,
499 disable_distortion: bool = False,
500 ) -> RayBundle:
570 # Make sure we're on the right devices
571 1000 75535.9 75.5 1.4 camera_indices = camera_indices.to(self.device)
572 1000 4195.6 4.2 0.1 coords = coords.to(self.device)
573
574 # Checking to make sure everything is of the right shape and type
575 1000 2382.9 2.4 0.0 num_rays_shape = camera_indices.shape[:-1]
576 1000 5181.8 5.2 0.1 assert camera_indices.shape == num_rays_shape + (self.ndim,)
577 1000 1762.2 1.8 0.0 assert coords.shape == num_rays_shape + (2,)
578 1000 1123.7 1.1 0.0 assert coords.shape[-1] == 2
579 1000 1942.3 1.9 0.0 assert camera_opt_to_camera is None or camera_opt_to_camera.shape == num_rays_shape + (3, 4)
580 1000 633.4 0.6 0.0 assert distortion_params_delta is None or distortion_params_delta.shape == num_rays_shape + (6,)
581
582 # Here, we've broken our indices down along the num_cameras_batch_dims dimension allowing us to index by all
583 # of our output rays at each dimension of our cameras object
584 1000 15879.8 15.9 0.3 true_indices = [camera_indices[..., i] for i in range(camera_indices.shape[-1])]
585
586 # Get all our focal lengths, principal points and make sure they are the right shapes
587 1000 5208.6 5.2 0.1 y = coords[..., 0] # (num_rays,) get rid of the last dimension
588 1000 4509.0 4.5 0.1 x = coords[..., 1] # (num_rays,) get rid of the last dimension
589 1000 93689.1 93.7 1.8 fx, fy = self.fx[true_indices].squeeze(-1), self.fy[true_indices].squeeze(-1) # (num_rays,)
590 1000 72221.0 72.2 1.4 cx, cy = self.cx[true_indices].squeeze(-1), self.cy[true_indices].squeeze(-1) # (num_rays,)
591 1000 984.5 1.0 0.0 assert (
592 1000 1795.6 1.8 0.0 y.shape == num_rays_shape
593 1000 868.2 0.9 0.0 and x.shape == num_rays_shape
594 1000 729.7 0.7 0.0 and fx.shape == num_rays_shape
595 1000 722.9 0.7 0.0 and fy.shape == num_rays_shape
596 1000 724.4 0.7 0.0 and cx.shape == num_rays_shape
597 1000 720.0 0.7 0.0 and cy.shape == num_rays_shape
598 ), (
599 str(num_rays_shape)
600 + str(y.shape)
601 + str(x.shape)
602 + str(fx.shape)
603 + str(fy.shape)
604 + str(cx.shape)
605 + str(cy.shape)
606 )
607
608 # Get our image coordinates and image coordinates offset by 1 (offsets used for dx, dy calculations)
609 # Also make sure the shapes are correct
610 1000 167456.0 167.5 3.1 coord = torch.stack([(x - cx) / fx, -(y - cy) / fy], -1) # (num_rays, 2)
611 1000 166370.8 166.4 3.1 coord_x_offset = torch.stack([(x - cx + 1) / fx, -(y - cy) / fy], -1) # (num_rays, 2)
612 1000 158160.7 158.2 3.0 coord_y_offset = torch.stack([(x - cx) / fx, -(y - cy + 1) / fy], -1) # (num_rays, 2)
613 1000 920.1 0.9 0.0 assert (
614 1000 3512.5 3.5 0.1 coord.shape == num_rays_shape + (2,)
615 1000 1457.5 1.5 0.0 and coord_x_offset.shape == num_rays_shape + (2,)
616 1000 1287.6 1.3 0.0 and coord_y_offset.shape == num_rays_shape + (2,)
617 )
618
619 # Stack image coordinates and image coordinates offset by 1, check shapes too
620 1000 32300.6 32.3 0.6 coord_stack = torch.stack([coord, coord_x_offset, coord_y_offset], dim=0) # (3, num_rays, 2)
621 1000 2218.1 2.2 0.0 assert coord_stack.shape == (3,) + num_rays_shape + (2,)
622
623 # Undistorts our images according to our distortion parameters
624 1000 568.8 0.6 0.0 if not disable_distortion:
625 1000 510.2 0.5 0.0 distortion_params = None
626 1000 1505.4 1.5 0.0 if self.distortion_params is not None:
627 1000 38145.8 38.1 0.7 distortion_params = self.distortion_params[true_indices]
628 1000 714.6 0.7 0.0 if distortion_params_delta is not None:
629 distortion_params = distortion_params + distortion_params_delta
630 elif distortion_params_delta is not None:
631 distortion_params = distortion_params_delta
632
633 # Do not apply distortion for equirectangular images
634 1000 553.5 0.6 0.0 if distortion_params is not None:
635 1000 83187.3 83.2 1.6 mask = (self.camera_type[true_indices] != CameraType.EQUIRECTANGULAR.value).squeeze(-1) # (num_rays)
636 1000 35133.6 35.1 0.7 coord_mask = torch.stack([mask, mask, mask], dim=0)
637 1000 63283.1 63.3 1.2 if mask.any():
638 if distortion_params_delta is not None:
639 coord_stack[coord_mask, :] = camera_utils.radial_and_tangential_undistort(
640 1000 1329823.8 1329.8 25.0 coord_stack[coord_mask, :].reshape(3, -1, 2),
641 1000 177964.6 178.0 3.3 distortion_params[mask, :],
642 1000 111606.8 111.6 2.1 ).reshape(-1, 2)
643 1000 717.2 0.7 0.0 else:
644 # try to use nerfacc to accelerate if we don't need any gradient to optimize distortion params
645 try:
646 coord_stack[coord_mask, :] = opencv_lens_undistortion(
647 coord_stack[coord_mask, :].reshape(3, -1, 2),
648 distortion_params[mask, :],
649 ).reshape(-1, 2)
650 except:
651 coord_stack[coord_mask, :] = camera_utils.radial_and_tangential_undistort(
652 1000 3336.3 3.3 0.1 coord_stack[coord_mask, :].reshape(3, -1, 2),
653 distortion_params[mask, :],
654 ).reshape(-1, 2)
655
656 # Make sure after we have undistorted our images, the shapes are still correct
657 assert coord_stack.shape == (3,) + num_rays_shape + (2,)
658
659 1000 141462.6 141.5 2.7 # Gets our directions for all our rays in camera coordinates and checks shapes at the end
660 1000 17009.2 17.0 0.3 # Here, directions_stack is of shape (3, num_rays, 3)
661 1000 119748.4 119.7 2.2 # directions_stack[0] is the direction for ray in camera coordinates
662 1000 84201.4 84.2 1.6 # directions_stack[1] is the direction for ray in camera coordinates offset by 1 in x
663 1000 42304.3 42.3 0.8 # directions_stack[2] is the direction for ray in camera coordinates offset by 1 in y
664 1000 302633.4 302.6 5.7 cam_types = torch.unique(self.camera_type, sorted=False)
665 1000 275161.2 275.2 5.2 directions_stack = torch.empty((3,) + num_rays_shape + (3,), device=self.device)
666 1000 45150.0 45.1 0.8 if CameraType.PERSPECTIVE.value in cam_types:
667 mask = (self.camera_type[true_indices] == CameraType.PERSPECTIVE.value).squeeze(-1) # (num_rays)
668 1000 101875.8 101.9 1.9 mask = torch.stack([mask, mask, mask], dim=0)
669 directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0], mask).float()
670 directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1], mask).float()
671 directions_stack[..., 2][mask] = -1.0
672
673 if CameraType.FISHEYE.value in cam_types:
674 mask = (self.camera_type[true_indices] == CameraType.FISHEYE.value).squeeze(-1) # (num_rays)
675 mask = torch.stack([mask, mask, mask], dim=0)
676
677 theta = torch.sqrt(torch.sum(coord_stack**2, dim=-1))
678 theta = torch.clip(theta, 0.0, math.pi)
679
680 sin_theta = torch.sin(theta)
681 1000 86364.5 86.4 1.6
682 directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0] * sin_theta / theta, mask).float()
683 directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1] * sin_theta / theta, mask).float()
684 directions_stack[..., 2][mask] = -torch.masked_select(torch.cos(theta), mask).float()
685
686 if CameraType.EQUIRECTANGULAR.value in cam_types:
687 mask = (self.camera_type[true_indices] == CameraType.EQUIRECTANGULAR.value).squeeze(-1) # (num_rays)
688 mask = torch.stack([mask, mask, mask], dim=0)
689
690 # For equirect, fx = fy = height = width/2
691 # Then coord[..., 0] goes from -1 to 1 and coord[..., 1] goes from -1/2 to 1/2
692 theta = -torch.pi * coord_stack[..., 0] # minus sign for right-handed
693 phi = torch.pi * (0.5 - coord_stack[..., 1])
694 1000 22240.3 22.2 0.4 # use spherical in local camera coordinates (+y up, x=0 and z<0 is theta=0)
695 1000 53516.0 53.5 1.0 directions_stack[..., 0][mask] = torch.masked_select(-torch.sin(theta) * torch.sin(phi), mask).float()
696 directions_stack[..., 1][mask] = torch.masked_select(torch.cos(phi), mask).float()
697 directions_stack[..., 2][mask] = torch.masked_select(-torch.cos(theta) * torch.sin(phi), mask).float()
698 1000 3154.5 3.2 0.1
699 for value in cam_types:
700 1000 43538.8 43.5 0.8 if value not in [CameraType.PERSPECTIVE.value, CameraType.FISHEYE.value, CameraType.EQUIRECTANGULAR.value]:
701 1000 3587.5 3.6 0.1 raise ValueError(f"Camera type {value} not supported.")
702
703 1000 780.1 0.8 0.0 assert directions_stack.shape == (3,) + num_rays_shape + (3,)
704 1000 435270.7 435.3 8.2
705 1000 12334.7 12.3 0.2 c2w = self.camera_to_worlds[true_indices]
706 1000 4026.8 4.0 0.1 assert c2w.shape == num_rays_shape + (3, 4)
707
708 1000 44612.8 44.6 0.8 if camera_opt_to_camera is not None:
709 1000 47645.5 47.6 0.9 c2w = pose_utils.multiply(c2w, camera_opt_to_camera)
710 rotation = c2w[..., :3, :3] # (..., 3, 3)
711 1000 195234.7 195.2 3.7 assert rotation.shape == num_rays_shape + (3, 3)
712 1000 3355.5 3.4 0.1
713 directions_stack = torch.sum(
714 1000 13279.8 13.3 0.2 directions_stack[..., None, :] * rotation, dim=-1
715 1000 2942.9 2.9 0.1 ) # (..., 1, 3) * (..., 3, 3) -> (..., 3)
716 directions_stack, directions_norm = camera_utils.normalize_with_norm(directions_stack, -1)
717 1000 6779.1 6.8 0.1 assert directions_stack.shape == (3,) + num_rays_shape + (3,)
718 1000 1977.3 2.0 0.0
719 origins = c2w[..., :3, 3] # (..., 3)
720 assert origins.shape == num_rays_shape + (3,)
721 1000 145841.0 145.8 2.7
722 1000 122159.3 122.2 2.3 directions = directions_stack[0]
723 1000 2913.3 2.9 0.1 assert directions.shape == num_rays_shape + (3,)
724
725 1000 36298.9 36.3 0.7 # norms of the vector going between adjacent coords, giving us dx and dy per output ray
726 1000 2834.7 2.8 0.1 dx = torch.sqrt(torch.sum((directions - directions_stack[1]) ** 2, dim=-1)) # ("num_rays":...,)
727 dy = torch.sqrt(torch.sum((directions - directions_stack[2]) ** 2, dim=-1)) # ("num_rays":...,)
728 1000 2051.9 2.1 0.0 assert dx.shape == num_rays_shape and dy.shape == num_rays_shape
729
730 1000 212742.9 212.7 4.0 pixel_area = (dx * dy)[..., None] # ("num_rays":..., 1)
731 1000 590.7 0.6 0.0 assert pixel_area.shape == num_rays_shape + (1,)
732 1000 539.9 0.5 0.0
733 1000 621.8 0.6 0.0 times = self.times[camera_indices, 0] if self.times is not None else None
734 1000 775.3 0.8 0.0
735 1000 659.3 0.7 0.0 return RayBundle(
736 1000 20120.0 20.1 0.4 origins=origins,
737 directions=directions,
738 pixel_area=pixel_area,
739 camera_indices=camera_indices,
740 times=times,
741 metadata={"directions_norm": directions_norm[0].detach()},
742 )
Total time: 13.158 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/data/datamanagers/base_datamanager.py
Function: next_train at line 520
Line # Hits Time Per Hit % Time Line Contents
==============================================================
520 @profiler.time_function
521 @profile
522 def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
523 """Returns the next batch of data from the train dataloader."""
524 1000 13811.7 13.8 0.1 self.train_count += 1
525 1000 4640.8 4.6 0.0 image_batch = next(self.iter_train_image_dataloader)
526 1000 1072.5 1.1 0.0 assert self.train_pixel_sampler is not None
527 1000 4149263.4 4149.3 31.5 batch = self.train_pixel_sampler.sample(image_batch)
528 1000 813.9 0.8 0.0 ray_indices = batch["indices"]
529 1000 8987908.2 8987.9 68.3 ray_bundle = self.train_ray_generator(ray_indices)
530 1000 500.8 0.5 0.0 return ray_bundle, batch
Total time: 8.9214 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/model_components/ray_generators.py
Function: forward at line 49
Line # Hits Time Per Hit % Time Line Contents
==============================================================
49 @profiler.time_function
50 @profile
51 def forward(self, ray_indices: TensorType["num_rays", 3]) -> RayBundle:
57 1000 13560.9 13.6 0.2 c = ray_indices[:, 0] # camera indices
58 1000 6918.7 6.9 0.1 y = ray_indices[:, 1] # row indices
59 1000 5989.9 6.0 0.1 x = ray_indices[:, 2] # col indices
60 1000 162339.4 162.3 1.8 coords = self.image_coords[y, x]
61
62 1000 6903.5 6.9 0.1 if self.pose_optimizer is not None:
63 1000 2747593.1 2747.6 30.8 camera_opt_to_camera = self.pose_optimizer(c)
64 else:
65 camera_opt_to_camera = None
66
67 1000 5967316.6 5967.3 66.9 ray_bundle = self.cameras.generate_rays(
68 1000 9650.1 9.7 0.1 camera_indices=c.unsqueeze(-1),
69 1000 430.7 0.4 0.0 coords=coords,
70 1000 360.8 0.4 0.0 camera_opt_to_camera=camera_opt_to_camera,
71 )
72 1000 331.9 0.3 0.0 return ray_bundle
Total time: 38.7776 s
File: /home/ruilongli/workspace/nerfstudio/nerfstudio/pipelines/base_pipeline.py
Function: get_train_loss_dict at line 268
Line # Hits Time Per Hit % Time Line Contents
==============================================================
268 @profiler.time_function
269 @profile
270 def get_train_loss_dict(self, step: int):
278 1000 13193678.2 13193.7 34.0 ray_bundle, batch = self.datamanager.next_train(step)
279 1000 21417043.0 21417.0 55.2 model_outputs = self.model(ray_bundle)
280 1000 2052852.5 2052.9 5.3 metrics_dict = self.model.get_metrics_dict(model_outputs, batch)
281
282 1000 3657.3 3.7 0.0 if self.config.datamanager.camera_optimizer is not None:
283 1000 1711.4 1.7 0.0 camera_opt_param_group = self.config.datamanager.camera_optimizer.param_group
284 1000 41581.3 41.6 0.1 if camera_opt_param_group in self.datamanager.get_param_groups():
285 # Report the camera optimization metrics
286 1000 671.0 0.7 0.0 metrics_dict["camera_opt_translation"] = (
287 1000 119678.1 119.7 0.3 self.datamanager.get_param_groups()[camera_opt_param_group][0].data[:, :3].norm()
288 )
289 1000 576.5 0.6 0.0 metrics_dict["camera_opt_rotation"] = (
290 1000 92974.8 93.0 0.2 self.datamanager.get_param_groups()[camera_opt_param_group][0].data[:, 3:].norm()
291 )
292
293 1000 1852501.7 1852.5 4.8 loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict)
294
295 1000 646.7 0.6 0.0 return model_outputs, loss_dict, metrics_dict
PR Summary:
1/ Update the Instant-NGP (tested) and Nerfplayer-NGP (untested) models for much better performance.
2/ Fix and speed up the camera undistortion code.
With these two changes, the per-1000-iteration total for get_train_loss_dict drops from 43.9 s to 38.8 s, and for next_train from 18.9 s to 13.2 s (see the profiles above). This PR is tested with:
CUDA_VISIBLE_DEVICES=3 ns-train instant-ngp --data data/nerfstudio/poster --vis wandb --max-num-iterations 10000
CUDA_VISIBLE_DEVICES=3 ns-train nerfacto --data data/nerfstudio/poster --vis wandb --max-num-iterations 10000
Not sure where to go for further speedup, because:
I'll pause the speedup work in this PR until we have an idea of what to do next. Or maybe we merge this one first and leave further speedup to the next PR?
# if self.training:
#     loss_dict["distortion_loss"] = self.config.distortion_loss_mult * flatten_eff_distloss(
#         outputs["weights"], outputs["steps"], outputs["intervals"], outputs["ray_indices"]
#     )
Remove
# if self.training:
#     steps = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2
#     intervals = ray_samples.frustums.ends - ray_samples.frustums.starts
#     outputs["weights"] = weights.flatten()
#     outputs["steps"] = steps.flatten()
#     outputs["intervals"] = intervals.flatten()
#     outputs["ray_indices"] = ray_indices.flatten()
Remove
In regards to the
Yeah, it shows increased memory. It is very strange, so I traced it down and it leads to this PR: #1715
Some thoughts:
The goal is to identify the bottlenecks in the codebase and optimize the code for efficiency, to solve #1638.