Commit
Some updates for reconstruction
sgasioro committed Sep 23, 2024
1 parent 20cece8 commit e6fd4be
Showing 15 changed files with 482 additions and 60 deletions.
1 change: 1 addition & 0 deletions src/gradoptics/__init__.py
@@ -6,6 +6,7 @@
from gradoptics.distributions.atom_cloud import AtomCloud
from gradoptics.distributions.simple_atom_cloud import SimpleAtomCloud
from gradoptics.distributions.atom_cloud_donut import AtomCloudDonut
from gradoptics.distributions.atom_cloud_spike import AtomCloudSpike

from gradoptics.inference.rejection_sampling import rejection_sampling

5 changes: 4 additions & 1 deletion src/gradoptics/distributions/atom_cloud.py
@@ -31,7 +31,10 @@ def __init__(self, n=int(1e6), f=2, position=[0.31, 0., 0.], w0=0.0005, h_bar=1.
super().__init__()
self.n = n
self.f = f
self.position = torch.tensor(position)
if isinstance(position, torch.Tensor):
self.position = position
else:
self.position = torch.tensor(position)
self.w0 = w0
self.h_bar = h_bar
self.m = m
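The change above lets ``position`` be passed either as a plain list or as a pre-built ``torch.Tensor`` (for example one with ``requires_grad=True``, so the cloud position can be optimized). A minimal sketch of both call styles, assuming the remaining constructor defaults are acceptable:

import torch
from gradoptics import AtomCloud

# List input: converted internally with torch.tensor(...)
cloud_a = AtomCloud(position=[0.31, 0., 0.])

# Tensor input: kept as-is, so gradients can flow back to the position
pos = torch.tensor([0.31, 0., 0.], requires_grad=True)
cloud_b = AtomCloud(position=pos)
assert cloud_b.position is pos  # no copy is made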
156 changes: 156 additions & 0 deletions src/gradoptics/distributions/atom_cloud_spike.py
@@ -0,0 +1,156 @@
import torch
import math

from gradoptics.distributions.base_distribution import BaseDistribution
from gradoptics.distributions.gaussian_distribution import GaussianDistribution
from gradoptics.inference.rejection_sampling import rejection_sampling

class AtomCloudSpike(BaseDistribution):
"""
Atom cloud "donut" with a hole in the middle. Motivated by discussion with Jason Hogan, et al. and implemented by Sanha.
2D Gaussian in the transverse plane with a central hole "blown away". Cylindrically symmetric. See atom_cloud.py for base
atom cloud definition.
"""

def __init__(self, n=int(1e6), position=[0.31, 0., 0.],
sigma_r_bulk = 0.0005, sigma_r_spike = 0.0001,
sigma_z_bulk = 0.0002, sigma_z_spike = 0.001,
r_mean_bulk = 0, r_mean_spike = 0,
z_mean_bulk = 0, z_mean_spike = 0.001,
mixture_bulk = 0.7,
transverse_proposal=None, longitudinal_proposal=None):
"""
        :param n: Number of atoms (:obj:`int`)
        :param position: Position of the center of the atom cloud [m] (:obj:`list`)
        :param sigma_r_bulk: Radial width (Gaussian std) of the bulk component [m] (:obj:`float`)
        :param sigma_r_spike: Radial width (Gaussian std) of the spike component [m] (:obj:`float`)
        :param sigma_z_bulk: Longitudinal width (Gaussian std) of the bulk component [m] (:obj:`float`)
        :param sigma_z_spike: Longitudinal width (Gaussian std) of the spike component [m] (:obj:`float`)
        :param r_mean_bulk: Radial mean of the bulk component [m] (:obj:`float`)
        :param r_mean_spike: Radial mean of the spike component [m] (:obj:`float`)
        :param z_mean_bulk: Longitudinal mean of the bulk component [m] (:obj:`float`)
        :param z_mean_spike: Longitudinal mean of the spike component [m] (:obj:`float`)
        :param mixture_bulk: Mixture weight of the bulk component, between 0 and 1 (:obj:`float`)
        :param transverse_proposal: Proposal distribution used in rejection sampling
                                    for sampling radial position along the transverse axis
                                    (:py:class:`~gradoptics.distributions.base_distribution.BaseDistribution`)
        :param longitudinal_proposal: Proposal distribution used in rejection sampling
                                      for sampling position along the longitudinal axis
                                      (:py:class:`~gradoptics.distributions.base_distribution.BaseDistribution`)
"""
super().__init__()
self.n = n
self.position = torch.tensor(position)
self.sigma_r_bulk = sigma_r_bulk
self.sigma_r_spike = sigma_r_spike

self.sigma_z_bulk = sigma_z_bulk
self.sigma_z_spike = sigma_z_spike


self.r_mean_bulk = r_mean_bulk
self.r_mean_spike = r_mean_spike

self.z_mean_bulk = z_mean_bulk
self.z_mean_spike = z_mean_spike

self.mixture_bulk = mixture_bulk

if transverse_proposal:
self.transverse_proposal = transverse_proposal
else:
self.transverse_proposal = GaussianDistribution(mean=0, std=sigma_r_bulk)

if longitudinal_proposal:
self.longitudinal_proposal = longitudinal_proposal
else:
self.longitudinal_proposal = GaussianDistribution(mean=(z_mean_bulk+z_mean_spike)/2, std=sigma_z_spike+sigma_z_bulk)

# Define a sampler to sample from the cloud density (using rejection sampling)
# self.density_samplers[0] is the transverse, radial sampler
# self.density_samplers[1] is the longitudinal sampler
self.density_samplers = [lambda pdf, nb_point, device: rejection_sampling(pdf, nb_point, self.transverse_proposal,
m=None, device=device),
lambda pdf, nb_point, device: rejection_sampling(pdf, nb_point, self.longitudinal_proposal,
m=None, device=device)
]

def marginal_cloud_density_r(self, r):
"""
Returns the marginal pdf function along the radial axis, evaluated at ``r``
.. warning::
The pdf is unnormalized
        :param r: Value where the pdf should be evaluated, in meters (:obj:`torch.tensor`)
:return: The marginal pdf function evaluated at ``r`` (:obj:`torch.tensor`)
"""
r = r.clone().type(torch.float64)

u = self.mixture_bulk * GaussianDistribution(mean=self.r_mean_bulk, std=self.sigma_r_bulk).pdf(r)
v = (1-self.mixture_bulk) * GaussianDistribution(mean=self.r_mean_spike, std=self.sigma_r_spike).pdf(r)

return u + v

def marginal_cloud_density_phi(self, phi):
"""
Returns the marginal pdf function along the azimuthal axis, evaluated at ``phi``
.. warning::
The pdf is unnormalized
        :param phi: Value where the pdf should be evaluated, in radians (:obj:`torch.tensor`)
        :return: The marginal pdf function evaluated at ``phi`` (:obj:`torch.tensor`)
"""
return 1 / (2 * math.pi)

def marginal_cloud_density_z(self, z, z_mean=0):
"""
Returns the marginal pdf function along the longitudinal axis, evaluated at ``z``
.. warning::
The pdf is unnormalized
        :param z: Value where the pdf should be evaluated, in meters (:obj:`torch.tensor`)
:return: The marginal pdf function evaluated at ``z`` (:obj:`torch.tensor`)
"""

z = z.clone().type(torch.float64)

u = self.mixture_bulk * GaussianDistribution(mean=self.z_mean_bulk, std=self.sigma_z_bulk).pdf(z)
v = (1-self.mixture_bulk) * GaussianDistribution(mean=self.z_mean_spike, std=self.sigma_z_spike).pdf(z)

return u + v

    def pdf(self, x):  # @Todo, refactor. x,y,z -> x
"""
Returns the pdf function evaluated at ``x``
.. warning::
The pdf is unnormalized
:param x: Value where the pdf should be evaluated (:obj:`torch.tensor`)
:return: The pdf function evaluated at ``x`` (:obj:`torch.tensor`)
"""
r = torch.sqrt((x[:, 0] - self.position[0]) ** 2 + (x[:, 1] - self.position[1]) ** 2)
phi = torch.atan2(x[:, 1] - self.position[1], x[:, 0] - self.position[0])
return self.marginal_cloud_density_r(r) * \
self.marginal_cloud_density_phi(phi) * \
self.marginal_cloud_density_z(x[:, 2]-self.position[2])

def sample(self, nb_points, device='cpu'):
pass
'''
Currently broken -- should be something like this, though:
atoms = torch.empty((nb_points, 3))
# Sample in the transverse plane
r_tmp = self.density_samplers[0](self.marginal_cloud_density_r, nb_points, device)
phi_tmp = torch.rand(nb_points, device=device) * math.pi
atoms[:, 0] = r_tmp * torch.cos(phi_tmp)
atoms[:, 1] = r_tmp * torch.sin(phi_tmp)
# Sample in the longitudinal axis
tmp = self.density_samplers[1](self.marginal_cloud_density_z, nb_points, device)
atoms[:, 2] = tmp
# Translate the cloud to its expected position
ray_origins = atoms + self.position
del atoms
return ray_origins
'''

def plot(self, ax, **kwargs):
"""
Plots the center of the atom cloud on the provided axes.
:param ax: 3d axes (:py:class:`mpl_toolkits.mplot3d.axes3d.Axes3D`)
"""
ax.scatter(self.position[0], self.position[1], self.position[2], **kwargs)
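A short usage sketch for the new class: build a cloud with the default bulk/spike parameters and evaluate the (unnormalized) pdf on a batch of points near the cloud center. The point values below are illustrative only.

import torch
from gradoptics import AtomCloudSpike

cloud = AtomCloudSpike(position=[0.31, 0., 0.], mixture_bulk=0.7)

# Points offset from the cloud center by up to +/- 0.5 mm in each direction
x = torch.tensor([0.31, 0., 0.]) + 1e-3 * (torch.rand(128, 3) - 0.5)

densities = cloud.pdf(x)  # unnormalized density, shape (128,)
print(densities.shape, float(densities.max()))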
19 changes: 15 additions & 4 deletions src/gradoptics/integrator/hierarchical_sampling_integrator.py
@@ -6,13 +6,16 @@ class HierarchicalSamplingIntegrator(BaseIntegrator):
Computes line integrals using hierarchical sampling
"""

def __init__(self, nb_mc_steps, nb_importance_samples, stratify=True):
def __init__(self, nb_mc_steps, nb_importance_samples, stratify=True,
with_color=False, with_var=False):
"""
        :param nb_mc_steps: Number of Monte Carlo integration steps used for approximating the integral (:obj:`int`)
        :param nb_importance_samples: Number of additional samples drawn from the coarse density estimate (:obj:`int`)
        :param stratify: Whether to stratify the initial samples along each ray (:obj:`bool`)
        :param with_color: Whether the pdf is also queried with the ray directions and returns a
                           direction-dependent (color) factor (:obj:`bool`)
        :param with_var: Whether to also return a variance estimate of the integral, obtained by
                         querying the pdf with ``return_var=True`` (:obj:`bool`)
        """
self.nb_mc_steps = nb_mc_steps
self.nb_importance_samples = nb_importance_samples
self.stratify = stratify
self.with_color = with_color
self.with_var = with_var

def sample_pdf(self, bins, weights, n_samples, det=False):
# This implementation is from NeRF
@@ -86,6 +89,14 @@ def compute_integral(self, incident_rays, pdf, t_min, t_max):
x = incident_rays.origins.expand(z_vals_mid.shape[-1], -1, -1).transpose(0, 1) + z_vals_mid.unsqueeze(
-1) * incident_rays.directions.expand(z_vals_mid.shape[-1], -1, -1).transpose(0, 1)

densities = pdf(x.reshape(-1, 3)).reshape((x.shape[:2]))

return (densities * deltas).sum(dim=1)
if self.with_color:
directions = incident_rays.directions.expand(z_vals_mid.shape[-1], -1, -1).transpose(0, 1)
densities = pdf(x.reshape(-1, 3), directions.reshape(-1, 3)).reshape((x.shape[:2]))
else:
densities = pdf(x.reshape(-1, 3)).reshape((x.shape[:2]))

if self.with_var:
variances = pdf(x.reshape(-1, 3), return_var=True).reshape((x.shape[:2]))
return (densities * deltas).sum(dim=1), (variances * (deltas**2)).sum(dim=1)
else:
return (densities * deltas).sum(dim=1)
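With ``with_var=True``, ``compute_integral`` returns a ``(integral, variance)`` pair: densities are accumulated as a Riemann sum weighted by the step sizes ``deltas``, and the per-sample variances are accumulated with ``deltas**2``. The standalone sketch below reproduces just that accumulation step with a toy pdf that follows the ``pdf(x, return_var=...)`` convention; in the integrator the sample positions would instead come from the hierarchical sampling above.

import torch

# Toy stand-in for a light-source pdf with a variance channel
def toy_pdf(x, return_var=False):
    vals = torch.exp(-(x ** 2).sum(dim=1) / (2 * 0.01 ** 2))
    return 0.1 * vals if return_var else vals

# Sample midpoints and step sizes along a bundle of 8 rays (shape: rays x samples)
z_vals_mid = torch.linspace(-0.05, 0.05, 64).expand(8, -1)
deltas = torch.full_like(z_vals_mid, 0.1 / 64)
origins = torch.zeros(8, 3)
directions = torch.tensor([1., 0., 0.]).expand(8, -1)
x = origins[:, None, :] + z_vals_mid[..., None] * directions[:, None, :]

densities = toy_pdf(x.reshape(-1, 3)).reshape(x.shape[:2])
variances = toy_pdf(x.reshape(-1, 3), return_var=True).reshape(x.shape[:2])
integrals = (densities * deltas).sum(dim=1)           # line-integral estimate per ray
integral_vars = (variances * deltas ** 2).sum(dim=1)  # propagated variance per ray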
55 changes: 48 additions & 7 deletions src/gradoptics/light_sources/light_source_from_neural_net.py
@@ -11,7 +11,9 @@ class LightSourceFromNeuralNet(BaseLightSource):
Models a light source from a neural network.
"""

def __init__(self, network, bounding_shape=None, rad=1., x_pos=0., y_pos=0., z_pos=0.):
def __init__(self, network, bounding_shape=None,
rad=1., x_pos=0., y_pos=0., z_pos=0., logpdf=False,
network_var=None):
"""
:param network: Neural network representation of density
:param bounding_shape: A bounding shape that bounds the light source
@@ -20,6 +22,7 @@ def __init__(self, network, bounding_shape=None, rad=1., x_pos=0., y_pos=0., z_p
A bounding shape is required if this light source is used with backward ray tracing
"""
self.network = network
self.network_var = network_var
self.bounding_shape = bounding_shape

self.rad = rad
@@ -37,27 +40,65 @@ def __init__(self, network, bounding_shape=None, rad=1., x_pos=0., y_pos=0., z_p
scale_mat[1][1] = 1/self.rad
scale_mat[2][2] = 1/self.rad

self.full_scale_mat = torch.matmul(trans_mat, scale_mat)[:-1]
full_mat = torch.matmul(trans_mat, scale_mat)
self.full_scale_mat = full_mat[:-1]

self.inv_scale_mat = torch.inverse(full_mat)

self.logpdf = logpdf

def sample_rays(self, nb_rays, device='cpu', sample_in_2pi=False):
pass

def plot(self, ax, **kwargs):
pass

def pdf(self, x):
def pdf(self, x, viewdir=None, return_var=False):
"""
Returns the pdf function of the distribution evaluated at ``x``
.. warning::
The pdf may be unnormalized
        :param x: Value where the pdf should be evaluated (:obj:`torch.tensor`)
        :param viewdir: Optional viewing directions; when given, the network is also queried for a
                        direction-dependent (color) factor that multiplies the density (:obj:`torch.tensor`)
        :param return_var: If ``True``, evaluate ``network_var`` instead of ``network`` (:obj:`bool`)
        :return: The pdf function evaluated at ``x`` (:obj:`torch.tensor`)
"""

if return_var:
network_here = self.network_var
else:
network_here = self.network

x_scale = torch.matmul(self.full_scale_mat.to(x.device).type(x.dtype),
torch.cat((x, torch.ones((x.shape[0],1),
device=x.device, dtype=x.dtype)), dim=1)[:, :, None]).squeeze(dim=-1)
pdf_val, coords = self.network(x_scale)
self.pdf_val = pdf_val
self.pdf_aux = coords
if viewdir is not None:
viewdir_scale = torch.matmul(self.full_scale_mat.to(viewdir.device).type(viewdir.dtype),
torch.cat((viewdir, torch.zeros((viewdir.shape[0],1),
device=viewdir.device, dtype=viewdir.dtype)), dim=1)[:, :, None]).squeeze(dim=-1)

pdf_val, color, coords = network_here(x_scale, viewdir_scale)
if self.logpdf:
self.pdf_val = torch.exp(pdf_val)
else:
self.pdf_val = pdf_val
self.color = color

if return_var:
self.pdf_aux_var = coords
return self.pdf_val*self.color**2
else:
self.pdf_aux = coords
return self.pdf_val*self.color

return pdf_val
else:
pdf_val, coords = network_here(x_scale)
if self.logpdf:
self.pdf_val = torch.exp(pdf_val)
else:
self.pdf_val = pdf_val

if return_var:
self.pdf_aux_var = coords
else:
self.pdf_aux = coords

return self.pdf_val
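The reworked ``pdf`` supports three call patterns: density only, density times a view-dependent color (when ``viewdir`` is given), and a variance query routed to ``network_var`` (when ``return_var=True``). The sketch below uses toy callables that follow the ``(values, aux)`` / ``(values, color, aux)`` return convention visible above; the toy networks, shapes, and parameter values are assumptions for illustration, not part of this commit, and the default constructor arguments are assumed to be sufficient.

import torch
from gradoptics.light_sources.light_source_from_neural_net import LightSourceFromNeuralNet

def toy_net(x, viewdir=None):
    vals = torch.exp(-(x ** 2).sum(dim=1))
    if viewdir is None:
        return vals, x                       # (density, aux coordinates)
    color = 0.5 * torch.ones_like(vals)      # flat view-dependent response
    return vals, color, x                    # (density, color, aux coordinates)

def toy_var_net(x, viewdir=None):
    out = toy_net(x, viewdir)
    return (0.1 * out[0],) + out[1:]         # smaller values as a stand-in variance

source = LightSourceFromNeuralNet(toy_net, rad=0.005, network_var=toy_var_net)

pts = 1e-3 * torch.randn(32, 3, dtype=torch.float64)
dirs = torch.nn.functional.normalize(torch.randn(32, 3, dtype=torch.float64), dim=1)

density = source.pdf(pts)                      # density only
density_color = source.pdf(pts, viewdir=dirs)  # density * color
variance = source.pdf(pts, return_var=True)    # evaluated with network_var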
2 changes: 2 additions & 0 deletions src/gradoptics/optics/__init__.py
@@ -3,6 +3,8 @@
from .lens import PerfectLens, ThickLens
from .mirror import FlatMirror, CurvedMirror
from .bounding_sphere import BoundingSphere
from .bounding_box import BoundingBox
from .bounding_rect import BoundingRect
from .sensor import Sensor
from .vector import vector3d, batch_vector, normalize_batch_vector, normalize_vector
from .window import Window
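With these exports, the new bounding shapes are importable directly from ``gradoptics.optics``. A quick sanity check, assuming the defaults shown in bounding_box.py below construct cleanly:

from gradoptics.optics import BoundingBox, BoundingRect

box = BoundingBox(width=1e-3, xc=0.2, yc=0.0, zc=0.0)
print(type(box).__name__)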
10 changes: 8 additions & 2 deletions src/gradoptics/optics/bounding_box.py
@@ -1,10 +1,15 @@
from gradoptics.optics.ray import Rays
from gradoptics.optics import BaseOptics
import torch
import numpy as np

from scipy.spatial import transform



class BoundingBox(BaseOptics):

def __init__(self, width=1e-3, xc=0.2, yc=0.0, zc=0.0, roll=torch.tensor(np.pi/4), pitch=torch.tensor(0.), yaw=torch.tensor(0.)):
def __init__(self, width=1e-3, xc=0.2, yc=0.0, zc=0.0, roll=np.pi/4, pitch=0., yaw=0.):
super().__init__()
self.width = width
self.xc = xc
@@ -27,7 +32,8 @@ def get_ray_intersection(self, incident_rays, eps=1e-15):
orig_directions = incident_rays.directions

# Instead of rotating the cube, apply the inverse rotation to the rays
inv_rot_mat = rot_mat_3d(self.roll, self.pitch, self.yaw).T
inv_rot_mat = transform.Rotation.from_euler('XYZ', [self.roll, self.pitch, self.yaw]).as_matrix().T
inv_rot_mat = torch.tensor(inv_rot_mat)
expanded_rot = torch.eye(4,dtype=torch.float64)
expanded_rot[:3, :3] = inv_rot_mat

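The intersection test now builds the inverse rotation with ``scipy.spatial.transform.Rotation`` instead of a hand-rolled ``rot_mat_3d``, using the fact that a rotation matrix's transpose is its inverse. A small standalone check of that construction (the angles are illustrative):

import numpy as np
import torch
from scipy.spatial import transform

roll, pitch, yaw = np.pi / 4, 0.1, -0.3
rot = transform.Rotation.from_euler('XYZ', [roll, pitch, yaw]).as_matrix()
inv_rot_mat = torch.tensor(rot.T)

# Transpose equals inverse for a rotation matrix: rotating and then applying
# inv_rot_mat recovers the original vector (up to floating-point error).
v = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
assert torch.allclose(inv_rot_mat @ (torch.tensor(rot) @ v), v)

# As in get_ray_intersection above, the 3x3 block is embedded in a 4x4 matrix so that
# homogeneous ray coordinates can be transformed with a single matmul.
expanded_rot = torch.eye(4, dtype=torch.float64)
expanded_rot[:3, :3] = inv_rot_mat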