Issue
If I create a Python dataclass containing a NumPy ndarray, I can no longer use the automatically generated __eq__.
from dataclasses import dataclass
import numpy as np

@dataclass
class Instr:
    foo: np.ndarray
    bar: np.ndarray

arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
This is because ndarray.__eq__ compares element by element, pairing a[0] with b[0] and so on (broadcasting where the shapes allow it), and returns an ndarray of truth values rather than a single bool. The error is raised only when such a result has more than one element and is then coerced to bool, so whether the comparison fails depends on the shapes of the arrays involved, which is quite complex and unintuitive.
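As a minimal illustration of that behaviour (the variable names are only for demonstration), the comparison itself succeeds and returns an array. It is the later conversion to bool that raises:

```python
import numpy as np

a = np.array([1])
b = np.array([1, 2])

# The comparison broadcasts a against b and returns one truth value per element.
elementwise = a == b
print(elementwise.tolist())  # [True, False]

# Converting a multi-element result to a single bool is what fails.
try:
    bool(elementwise)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```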
How do I safely compare @dataclass instances holding NumPy arrays?
@dataclass's implementation of __eq__ is generated using exec(). Its source is missing from the stack trace and cannot be viewed using inspect, but it is actually using a tuple comparison, which calls bool() on the result of comparing each pair of fields (here self.foo == other.foo).
import dis
dis.dis(Instr.__eq__)
excerpt:
  3          12 LOAD_FAST                0 (self)
             14 LOAD_ATTR                1 (foo)
             16 LOAD_FAST                0 (self)
             18 LOAD_ATTR                2 (bar)
             20 BUILD_TUPLE              2
             22 LOAD_FAST                1 (other)
             24 LOAD_ATTR                1 (foo)
             26 LOAD_FAST                1 (other)
             28 LOAD_ATTR                2 (bar)
             30 BUILD_TUPLE              2
             32 COMPARE_OP               2 (==)
             34 RETURN_VALUE
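To see why a tuple comparison surfaces the error, here is a small sketch that does not need NumPy at all. The classes Ambiguous and FakeArray are hypothetical stand-ins invented for this example. They only mimic the way an ndarray comparison result refuses to be converted to bool.

```python
class Ambiguous:
    """Hypothetical comparison result that refuses to be used as a bool."""

    def __bool__(self):
        raise ValueError("truth value is ambiguous")


class FakeArray:
    """Hypothetical stand-in for an ndarray: == never returns a plain bool."""

    def __eq__(self, other):
        return Ambiguous()


x, y = FakeArray(), FakeArray()

# Comparing the objects directly is fine: the result is simply returned.
result = x == y

# A tuple comparison has to decide whether each pair is equal, so it calls
# bool() on the result of x == y, and that is where the error surfaces.
try:
    (x,) == (y,)
    raised = False
except ValueError:
    raised = True

# Identical objects are short-circuited by identity, so __eq__ is never called.
same_object_ok = (x,) == (x,)
```

The identity shortcut on the last line explains why comparing a dataclass instance against one that holds the very same array objects does not raise.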
Solution
The solution is to write your own __eq__ method and set eq=False so the dataclass doesn't generate its own. Checking the docs, that last step isn't strictly necessary, because an __eq__ already defined in the class body is not overwritten, but I think it's nice to be explicit anyway.
from dataclasses import dataclass
import numpy as np

def array_eq(arr1, arr2):
    # Equal only if both are arrays of the same shape with identical elements.
    return (isinstance(arr1, np.ndarray) and
            isinstance(arr2, np.ndarray) and
            arr1.shape == arr2.shape and
            bool((arr1 == arr2).all()))

@dataclass(eq=False)
class Instr:
    foo: np.ndarray
    bar: np.ndarray

    def __eq__(self, other):
        if not isinstance(other, Instr):
            return NotImplemented
        return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)
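As a possible variation on the helper above, np.array_equal already checks both shape and element equality, so it could stand in for the manual shape test. This is a sketch rather than the answer's own code: array_eq_alt is a hypothetical name, and the isinstance checks are kept because np.array_equal also accepts lists and other array-likes.

```python
from dataclasses import dataclass

import numpy as np


def array_eq_alt(arr1, arr2) -> bool:
    """Hypothetical variant of array_eq built on np.array_equal."""
    return (isinstance(arr1, np.ndarray)
            and isinstance(arr2, np.ndarray)
            and np.array_equal(arr1, arr2))


@dataclass(eq=False)
class Instr:
    foo: np.ndarray
    bar: np.ndarray

    def __eq__(self, other):
        if not isinstance(other, Instr):
            return NotImplemented
        return (array_eq_alt(self.foo, other.foo)
                and array_eq_alt(self.bar, other.bar))


arr = np.array([1])
arr2 = np.array([1, 2])

# The comparison from the question no longer raises.
different = Instr(arr, arr) == Instr(arr2, arr2)
# Equal contents in distinct array objects compare equal.
same = Instr(arr, arr) == Instr(arr.copy(), arr.copy())

print(different, same)  # False True
```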
Edit
A general and quick solution for generic dataclasses where some values are NumPy arrays and others are not:
import numpy as np
from dataclasses import dataclass, astuple

def array_safe_eq(a, b) -> bool:
    """Check if a and b are equal, even if they are numpy arrays"""
    if a is b:
        return True
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return a.shape == b.shape and bool((a == b).all())
    try:
        return a == b
    except TypeError:
        # NotImplemented is not a valid truth value inside all(), so treat
        # values that cannot be compared as unequal.
        return False

def dc_eq(dc1, dc2) -> bool:
    """checks if two dataclasses which hold numpy arrays are equal"""
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        return NotImplemented  # better than False
    t1 = astuple(dc1)
    t2 = astuple(dc2)
    return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))

# usage
@dataclass(eq=False)
class T:
    a: int
    b: np.ndarray
    c: np.ndarray

    def __eq__(self, other):
        return dc_eq(self, other)
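One design point worth noting: as far as I can tell from the dataclasses docs, astuple recurses into nested containers and deep-copies other values, so every array is copied on each comparison. The sketch below is an assumed alternative, not part of the original answer. It walks dataclasses.fields instead and reads the attributes directly. The names field_safe_eq and dc_eq_no_copy are hypothetical, and skipping fields declared with compare=False is my own choice.

```python
from dataclasses import dataclass, fields

import numpy as np


def field_safe_eq(a, b) -> bool:
    """Hypothetical helper: compare two field values, array-aware."""
    if a is b:
        return True
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        # An array only equals another array of the same shape and contents.
        return (isinstance(a, np.ndarray)
                and isinstance(b, np.ndarray)
                and a.shape == b.shape
                and bool((a == b).all()))
    return bool(a == b)


def dc_eq_no_copy(dc1, dc2):
    """Hypothetical variant of dc_eq that avoids copying the field values."""
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        return NotImplemented
    return all(
        field_safe_eq(getattr(dc1, f.name), getattr(dc2, f.name))
        for f in fields(dc1)
        if f.compare  # assumption: honour field(compare=False)
    )


@dataclass(eq=False)
class T:
    a: int
    b: np.ndarray
    c: np.ndarray

    def __eq__(self, other):
        return dc_eq_no_copy(self, other)


t1 = T(1, np.array([1, 2]), np.zeros(3))
t2 = T(1, np.array([1, 2]), np.zeros(3))
t3 = T(1, np.array([1, 2, 3]), np.zeros(3))

print(t1 == t2, t1 == t3)  # True False
```

Unlike astuple, this version does not look inside nested dataclasses or containers, so a list of arrays stored in a field would still need its own handling.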
Answered By - FHTMitchell