Skip to content

Commit

Permalink
String repr for Basis and DofsView (#779)
Browse files Browse the repository at this point in the history
* Add string representation for AbstractBasis

* String representation for DofsView

* Get rid of edgecolors

* Remove unittest main calls
  • Loading branch information
kinnala authored Nov 13, 2021
1 parent 07bf572 commit 08a695d
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 35 deletions.
2 changes: 1 addition & 1 deletion docs/examples/ex06.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@
from sys import argv
from skfem.visuals.matplotlib import *
ax = draw(m)
plot(M, X, ax=ax, shading='gouraud', edgecolors='')
plot(M, X, ax=ax, shading='gouraud')
savefig(splitext(argv[0])[0] + '_solution.png')
2 changes: 1 addition & 1 deletion docs/examples/ex26.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,5 @@

ax = draw(mesh)
plot(mesh, temperature[basis.nodal_dofs.flatten()],
ax=ax, edgecolors='none', colorbar=True)
ax=ax, colorbar=True)
ax.get_figure().savefig(splitext(argv[0])[0] + '_solution.png')
16 changes: 16 additions & 0 deletions skfem/assembly/basis/abstract_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,22 @@ def _get_dofs_normalize_elements(self, elements):
return self.mesh.subdomains[elements]
raise NotImplementedError

def __repr__(self):
size = sum([sum([y.size if hasattr(y, 'size') else 0
for y in x])
for x in self.basis[0]]) * 8 * len(self.basis)
rep = ""
rep += "<skfem {}({}, {}) object>\n".format(type(self).__name__,
type(self.mesh).__name__,
type(self.elem).__name__)
rep += " Number of elements: {}\n".format(self.nelems)
rep += " Number of DOFs: {}\n".format(self.N)
rep += " Size: {} B".format(size)
return rep

def __str__(self):
return self.__repr__()

def default_parameters(self):
"""This is used by :func:`skfem.assembly.asm` to get the default
parameters for 'w'."""
Expand Down
56 changes: 51 additions & 5 deletions skfem/assembly/dofs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Union, NamedTuple, Any, List, Optional
from warnings import warn

import numpy as np
from numpy import ndarray
Expand All @@ -20,6 +21,51 @@ class DofsView(NamedTuple):
edge_rows: Union[ndarray, slice] = slice(None)
interior_rows: Union[ndarray, slice] = slice(None)

def __repr__(self):
nnodal = len(np.unique(
self.obj.nodal_dofs[self.nodal_rows][:, self.nodal_ix]
))
nfacet = len(np.unique(
self.obj.facet_dofs[self.facet_rows][:, self.facet_ix]
))
nedge = len(np.unique(
self.obj.edge_dofs[self.edge_rows][:, self.edge_ix]
))
ninterior = len(np.unique(
self.obj.interior_dofs[self.interior_rows][:, self.interior_ix]
))
dofnames = np.array(self.obj.element.dofnames)
rep = ""
rep += "<skfem {}({}, {}) object>\n".format(
type(self).__name__,
type(self.obj.topo).__name__,
type(self.obj.element).__name__
)
rep += " Number of nodal DOFs: {} {}\n".format(
nnodal,
dofnames[self.nodal_rows]
)
rep += " Number of facet DOFs: {} {}\n".format(
nfacet,
dofnames[self.obj.nodal_dofs.shape[0]:][self.facet_rows],
)
if self.obj.topo.dim() > 2:
rep += " Number of edge DOFs: {} {}\n".format(
nedge,
dofnames[(self.obj.nodal_dofs.shape[0]
+ self.obj.facet_dofs.shape[0]):][self.edge_rows],
)
rep += " Number of interior DOFs: {} {}\n".format(
ninterior,
dofnames[(self.obj.nodal_dofs.shape[0]
+ self.obj.facet_dofs.shape[0]
+ self.obj.edge_dofs.shape[0]):][self.interior_rows],
)
return rep

def __str__(self):
return self.__repr__()

def flatten(self) -> ndarray:
"""Return all DOF indices as a single array."""
return np.unique(
Expand Down Expand Up @@ -143,7 +189,7 @@ def __getattr__(self, attr):
return getattr(self.obj, attr)

def __or__(self, other):
"""For merging two sets of DOFs."""
warn("Use numpy.hstack to combine sets of DOFs", DeprecationWarning)
return DofsView(
self.obj,
np.union1d(self.nodal_ix, other.nodal_ix),
Expand Down Expand Up @@ -380,8 +426,8 @@ def check(x, y):
interior_rows.append(i)

return (
nodal_rows if len(nodal_rows) > 0 else slice(0, 0),
facet_rows if len(facet_rows) > 0 else slice(0, 0),
edge_rows if len(edge_rows) > 0 else slice(0, 0),
interior_rows if len(interior_rows) > 0 else slice(0, 0)
np.array(nodal_rows) if len(nodal_rows) > 0 else slice(0, 0),
np.array(facet_rows) if len(facet_rows) > 0 else slice(0, 0),
np.array(edge_rows) if len(edge_rows) > 0 else slice(0, 0),
np.array(interior_rows) if len(interior_rows) > 0 else slice(0, 0)
)
1 change: 0 additions & 1 deletion skfem/visuals/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ def plot_meshtri(m: MeshTri1, z: ndarray, **kwargs) -> Axes:
figsize (optional)
Passed on to matplotlib.
shading (optional)
edgecolors (optional)
vmin (optional)
vmax (optional)
Passed on to matplotlib.
Expand Down
4 changes: 0 additions & 4 deletions tests/test_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,3 @@ def proj(v, _):
return ddot(C(sym_grad(x)), v)

y = projection(proj, basis1, basis0)


if __name__ == '__main__':
main()
4 changes: 0 additions & 4 deletions tests/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,3 @@ class FacetConvergenceTetP1(FacetConvergenceTetP2):
case = (MeshTet, ElementTetP1)
limits = (0.9, 1.1)
preref = 2


if __name__ == '__main__':
unittest.main()
4 changes: 0 additions & 4 deletions tests/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,3 @@ def square(w):
)
def test_initialize_dg_composite_elements(e, edg):
E = edg(e) * e


if __name__ == '__main__':
main()
4 changes: 0 additions & 4 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,3 @@ def runTest(self):
import docs.examples.ex42 as ex

self.assertAlmostEqual(ex.x.max(), 0.0009824131638261542, delta=1e-5)


if __name__ == '__main__':
main()
3 changes: 0 additions & 3 deletions tests/test_manufactured.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,4 @@ def func(w):

if __name__ == "__main__":
import pytest
import unittest

unittest.main()
pytest.main()
4 changes: 0 additions & 4 deletions tests/test_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,3 @@ class TestMortarPairNoMatch2(TestMortarPair):
mesh1_type = MeshQuad
mesh2_type = MeshTri
translate_y = -np.pi / 10.


if __name__ == '__main__':
unittest.main()
4 changes: 0 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,3 @@ def runTest(self):

enforce(A, D=D, overwrite=True)
assert_almost_equal(A.toarray(), np.eye(A.shape[0]))


if __name__ == '__main__':
unittest.main()

0 comments on commit 08a695d

Please sign in to comment.