Skip to content

Commit

Permalink
More informative structure mismatch error message in Linear layer (#92)
Browse files Browse the repository at this point in the history
* More informative structure mismatch error message in Linear layer

* only add error prefix if parameter

* ignore typing attribute error

---------

Co-authored-by: Ami Falk <[email protected]>
  • Loading branch information
amifalk and ami-cogscai authored Nov 21, 2024
1 parent aac725f commit e22d27f
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions penzai/nn/linear_and_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
from penzai.core import named_axes
from penzai.core import shapecheck
from penzai.core import struct
from penzai.core import variables
from penzai.nn import grouping
from penzai.nn import layer as layer_base
from penzai.nn import parameters

NamedArray = named_axes.NamedArray
Parameter = variables.Parameter
ParameterValue = variables.ParameterValue


class LinearOperatorWeightInitializer(Protocol):
Expand Down Expand Up @@ -421,12 +424,27 @@ class Linear(layer_base.Layer):
def __call__(self, in_array: NamedArray, **_unused_side_inputs) -> NamedArray:
"""Runs the linear operator."""
in_struct = self._input_structure()
dimvars = shapecheck.check_structure(in_array, in_struct)

# pytype: disable=attribute-error
if isinstance(
self.weights,
Parameter | ParameterValue,
) and self.weights.label.endswith(".weights"):
error_prefix = f"({self.weights.label[:-8]}) "
else:
error_prefix = ""
# pytype: enable=attribute-error

dimvars = shapecheck.check_structure(
in_array, in_struct, error_prefix=error_prefix
)

result = contract(self.in_axis_names, in_array, self.weights.value)

out_struct = self._output_structure()
shapecheck.check_structure(result, out_struct, known_vars=dimvars)
shapecheck.check_structure(
result, out_struct, known_vars=dimvars, error_prefix=error_prefix
)
return result

@classmethod
Expand Down

0 comments on commit e22d27f

Please sign in to comment.