Skip to content

Commit

Permalink
dev
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhassell committed Sep 14, 2023
1 parent 26ac71a commit 3f0ad2a
Showing 1 changed file with 93 additions and 92 deletions.
185 changes: 93 additions & 92 deletions cf/regrid/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,8 +806,9 @@ def spherical_grid(f, name=None, method=None, cyclic=None, axes=None):
dim_coords_1d = False
aux_coords_2d = False
aux_coords_1d = False
domain_topology = None

domain_topology= None
mesh_location = None

# Look for 1-d X and Y dimension coordinates
lon_key_1d, lon_1d = f.dimension_coordinate(
"X", item=True, default=(None, None)
Expand Down Expand Up @@ -874,7 +875,7 @@ def spherical_grid(f, name=None, method=None, cyclic=None, axes=None):
"coordinates"
)
else:
domain_topology, mesh_axis = ffff(f, name)
domain_topology, mesh_axis, mesh_location = ffff(f, name)

# Look for 1-d X and Y auxiliary coordinates
auxs = f.auxiliary_coordinates(
Expand Down Expand Up @@ -963,13 +964,14 @@ def spherical_grid(f, name=None, method=None, cyclic=None, axes=None):
axes={"X": x_axis, "Y": y_axis},
shape=axis_sizes,
coords=coords,
bounds=get_bounds(method, coords, mesh),
bounds=get_bounds(method, coords, mesh_location),
cyclic=bool(cyclic),
coord_sys="spherical",
method=method,
name=name,
mesh=mesh,
domain_topology=domain_topology
domain_topology=domain_topology,
mesh_location=mesh_location
)

#def spherical_mesh(f, name=None, method=None):
Expand Down Expand Up @@ -1140,7 +1142,7 @@ def Cartesian_grid(f, name=None, method=None, axes=None, mesh=False):

coords.append(coord)

bounds = get_bounds(method, coords)
bounds = get_bounds(method, coords, mesh_location)

dummy_size_2_dimension = False
if len(coords) == 1:
Expand Down Expand Up @@ -1592,7 +1594,7 @@ def create_esmpy_grid(grid=None, mask=None):
return esmpy_grid


def create_esmpy_mesh(mesh=None, mask=None):
def create_esmpy_mesh(grid=None, mask=None):
"""Create an `esmpy` Mesh.
.. versionadded:: UGRIDVER
Expand All @@ -1601,7 +1603,7 @@ def create_esmpy_mesh(mesh=None, mask=None):
:Parameters:
mesh: `Mesh`
grid: `Grid`
mask: array_like, optional
The mesh mask. If `None` (the default) then there are no
Expand All @@ -1613,16 +1615,7 @@ def create_esmpy_mesh(mesh=None, mask=None):
`esmpy.Mesh`
"""
# shape = grid.domain_topology.data.shape
# if shape[1] > 9:
# # This restriction is because esmpy.api.constants.MeshElemType
# # values of 10 or more are used for 3-d volume mesh elements.
# raise ValueError(f"TODOUGRID: Can't regrid mesh elements with more than 9 nodes: got {grid.domain_topology.data.shape[1]}")
#
spherical = False
if grid.coord_sys == "spherical":
spherical = True
x, y = 0, 1
coord_sys = esmpy.CoordSys.SPH_DEG
else:
# Cartesian
Expand All @@ -1634,68 +1627,69 @@ def create_esmpy_mesh(mesh=None, mask=None):
spatial_dim=2,
coord_sys=coord_sys
)

element_conn = grid.domain_topology.array
element_count = element_conn.shape[0]
element_types = np.ma.count(element_conn, axis=1)
element_conn = np.ma.compressed(element_conn)

node_ids, index = np.unique(element_conn , return_index=True)

node_coords = [np.ma.compressed(b)[index] for b in grid.bounds]
node_coords = np.stack(node_coords, axis=-1)

for aux in node_coordinates:
if aux.X:
x = aux
elif aux.Y:
y = aux

node_coords = np.stack((x.array, y.array), axis=-1)
node_count = node_coords.shape[0]
node_count = node_ids.size
node_owners = np.zeros(node_count)

# This must be done before add_elements
# This must be done before `add_elements`
esmpy_mesh.add_nodes(
node_count =node_count,
node_ids =np.arange(node_count)
node_ids =node_ids ,
node_coords= node_coords,
node_owners= node_owners
)

element_conn = grid.domain_topology.array
element_count = element_conn.shape[0]
element_types = np.ma.count(element_conn, axis=1)
element_conn = np.ma.compressed(element_conn)

# ddd = np.unique(element_types ).tolist():
# tmin = ddd.
# tmax = ddd.max()
# if ddd[0] < 3:
# raise ValueError("TODOUGRID")
#
# if ddd[-1] > 9:
# raise ValueError("TODOUGRID")
#
# for MeshElemType in ddd:
# element_types = np.where(
# element_types == 3, MeshElemType, element_types
# )
#
#
# if 3 in (tmin, tmax):
# element_types = np.where(
# element_types == 3, esmpy.MeshElemType.TRI, element_types
# )
#
# if 4 in (tmin, tmax):
# element_types = np.where(
# element_types == 4, esmpy.MeshElemType.QUAD, element_types
# )
# Element coordinates
if grid.coords:
try:
element_coords = [np.asanyarray(c) for c in grid.coords]
except ValueError:
# The coordinate constructs have no data
element_coords = None
else:
element_coords = np.stack(element_coords, axis=-1)
else:
element_coords =None

x, y = grid.coords
element_coords = np.stack((x.array, y.array), axis=-1)

# This must be done after add_nodes
mesh.add_elements(
element_count=element_count
element_ids=np.arange(elem_count)
# Mask
if mask is not None:
if mask.dtype != bool:
raise ValueError(
"'mask' must be None or a Boolean array. "
f"Got: dtype={mask.dtype}"
)

# Note: 'mask' has True/False for masked/unmasked
# elements, but the esmpy mask requires 0/1 for
# masked/unmasked elements.
mask = np.invert(mask).astype("int32")
if mask.all():
# There are no masked elements
mask = None

# This must be done after `add_nodes`
esmpy_mesh.add_elements(
element_count=element_count,
element_ids=np.arange(elem_count),
element_types=element_types,
element_conn=element_conn,
element_mask=mask,
element_area=None,
element_coords=element_coords,
element_mask=None,
)


return esmpy_mesh


def create_esmpy_weights(
method,
Expand Down Expand Up @@ -1804,8 +1798,22 @@ def create_esmpy_weights(
# Create the weights using ESMF
from_file = False

src_esmpy_field = esmpy.Field(src_esmpy_grid, "src")
dst_esmpy_field = esmpy.Field(dst_esmpy_grid, "dst")
if src_grid.mesh_location =="face":
src_meshloc = esmpy.api.constants.MeshLoc.ELEMENT
elif src_grid.mesh_location =="node":
src_meshloc = esmpy.api.constants.MeshLoc.NODE
else:
src_meshloc = None

if dst_grid.mesh_location =="face":
dst_meshloc = esmpy.api.constants.MeshLoc.ELEMENT
elif dst_grid.mesh_location =="node":
dst_meshloc = esmpy.api.constants.MeshLoc.NODE
else:
dst_meshloc = None

src_esmpy_field = esmpy.Field(src_esmpy_grid, "src", meshloc=src_meshloc)
dst_esmpy_field = esmpy.Field(dst_esmpy_grid, "dst", meshloc=dst_meshloc)

mask_values = np.array([0], dtype="int32")

Expand Down Expand Up @@ -1984,7 +1992,7 @@ def contiguous_bounds(b, cyclic=False, period=None):
return True


def get_bounds(method, coords):
def get_bounds(method, coords, mesh_location):
"""Get coordinate bounds needed for defining an `esmpy.Grid`.
.. versionadded:: 3.14.0
Expand All @@ -1997,7 +2005,7 @@ def get_bounds(method, coords):
coords: sequence of `Coordinate`
The coordinates that define an `esmpy.Grid`.
mesh: `bool`
mesh_location: `str` or `None`
TODOUGRID
.. versionadded:: UGRIDVER
Expand All @@ -2009,12 +2017,15 @@ def get_bounds(method, coords):
regridding method is not conservative. TODOUGRID
"""
if not mesh and not conservative_regridding(method):
if mesh_location == "node":
return []

if not conservative_regridding(method):
return []

bounds = [c.get_bounds(None) for c in coords]
for c, b in zip(coords, bounds):
if b is None:
if b is None :
raise ValueError(
f"All coordinates must have bounds for {method!r} "
f"regridding: {c!r}"
Expand Down Expand Up @@ -2070,7 +2081,9 @@ def get_mask(f, grid):
mask = mask[tuple(index)]

# Reorder the mask axes to grid.axes_keys
mask = da.transpose(mask, axes=np.argsort(regrid_axes).tolist())
axes = np.argsort(regrid_axes).tolist()
if len(axes) > 1:
mask = da.transpose(mask, axes=axes)

return mask

Expand Down Expand Up @@ -2264,28 +2277,16 @@ def update_non_coordinates(
src.set_coordinate_reference(ref, parent=dst, strict=True)

def ffff(f, name):

mesh = False
mesh_axes = None

"""TODOUGRID"""
key, domain_topology = f.domain_topology(item=True, default=(None, None))
if domain_topology is None:
raise ValueError("TODOUGRID")

cell_type= domain_topology.get_cell(None)
mesh = cell_type== "face"
if not mesh:
if name == "src":
raise ValueError(
"Can't regrid data defined on an unstructured mesh of "
f"{cell_type!r} cells"
)
elif name == "dst":
raise ValueError(
"Can't regrid data to an unstructured mesh of "
f"{cell_type!r} cells"
)
mesh_location = domain_topology.get_cell(None)
if mesh_location not in ("face", "node"):
raise ValueError(
"Can't regrid when the {name} grid is an unstructured mesh "
f"of {mesh_location!r} cells"
)

mesh_axes = f.get_data_axes(key)

return mesh, mesh_axes
return domain_topology, f.get_data_axes(key), mesh_location

0 comments on commit 3f0ad2a

Please sign in to comment.