Skip to content

Commit

Permalink
Add function to get the ufl.Cell corresponding to a cell's facet (#76)
Browse files Browse the repository at this point in the history
* Add method for getting the facet cell

* Don't use cellname2facetname

* Remove num_facet_edges

* Make suggested changes
  • Loading branch information
jpdean authored Jan 10, 2022
1 parent 1de7f12 commit 7a4303f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
18 changes: 13 additions & 5 deletions ufl/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numbers
import functools

import ufl.cell
from ufl.log import error
from ufl.core.ufl_type import attach_operators_from_hash_data

Expand Down Expand Up @@ -95,6 +96,7 @@ def __lt__(self, other):
# Mapping from cell name to facet name
# Note: This is not generalizable to product elements but it's still
# in use a couple of places.
# TODO Remove
cellname2facetname = {"interval": "vertex",
"triangle": "interval",
"quadrilateral": "interval",
Expand Down Expand Up @@ -164,11 +166,17 @@ def num_facets(self):

# --- Facet properties ---

def num_facet_edges(self):
"The number of facet edges."
# This is used in geometry.py
fn = cellname2facetname[self.cellname()]
return num_cell_entities[fn][1]
def facet_types(self):
"A tuple of ufl.Cell representing the facets of self."
# TODO Move outside method?
facet_type_names = {"interval": ("vertex",),
"triangle": ("interval",),
"quadrilateral": ("interval",),
"tetrahedron": ("triangle",),
"hexahedron": ("quadrilateral",),
"prism": ("triangle", "quadrilateral")}
return tuple(ufl.Cell(facet_name, self.geometric_dimension())
for facet_name in facet_type_names[self.cellname()])

# --- Special functions for proper object behaviour ---

Expand Down
16 changes: 14 additions & 2 deletions ufl/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,13 @@ def __init__(self, domain):
@property
def ufl_shape(self):
cell = self.ufl_domain().ufl_cell()
nfe = cell.num_facet_edges()
facet_types = cell.facet_types()

# Raise exception for cells with more than one facet type e.g. prisms
if len(facet_types) > 1:
raise Exception(f"Cell type {cell} not supported.")

nfe = facet_types[0].num_edges()
t = cell.topological_dimension()
return (nfe, t)

Expand Down Expand Up @@ -470,7 +476,13 @@ def __init__(self, domain):
@property
def ufl_shape(self):
cell = self.ufl_domain().ufl_cell()
nfe = cell.num_facet_edges()
facet_types = cell.facet_types()

# Raise exception for cells with more than one facet type e.g. prisms
if len(facet_types) > 1:
raise Exception(f"Cell type {cell} not supported.")

nfe = facet_types[0].num_edges()
g = cell.geometric_dimension()
return (nfe, g)

Expand Down

0 comments on commit 7a4303f

Please sign in to comment.