Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid copy in memory of memmaps #824

Merged
merged 4 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changes/824.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Replace ``np.asarray`` with ``np.asanyarray`` everywhere, to avoid copying memory-mapped arrays into memory where possible.
55 changes: 29 additions & 26 deletions stingray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(cls, *args, **kwargs) -> None:
def main_array_length(self):
if getattr(self, self.main_array_attr, None) is None:
return 0
return np.shape(np.asarray(getattr(self, self.main_array_attr)))[0]
return np.shape(np.asanyarray(getattr(self, self.main_array_attr)))[0]

def data_attributes(self) -> list[str]:
"""Clean up the list of attributes, only giving out those pointing to data.
Expand All @@ -130,7 +130,7 @@ def data_attributes(self) -> list[str]:
and not isinstance(getattr(self.__class__, attr, None), property)
and not callable(value := getattr(self, attr))
and not isinstance(value, StingrayObject)
and not np.asarray(value).dtype == "O"
and not np.asanyarray(value).dtype == "O"
)
]

Expand Down Expand Up @@ -368,7 +368,7 @@ def to_astropy_table(self, no_longdouble=False) -> Table:
array_attrs = self.array_attrs() + [self.main_array_attr] + self.internal_array_attrs()

for attr in array_attrs:
vals = np.asarray(getattr(self, attr))
vals = np.asanyarray(getattr(self, attr))
if no_longdouble:
vals = reduce_precision_if_extended(vals)
data[attr] = vals
Expand Down Expand Up @@ -455,7 +455,7 @@ def to_xarray(self) -> Dataset:
array_attrs = self.array_attrs() + [self.main_array_attr] + self.internal_array_attrs()

for attr in array_attrs:
new_data = np.asarray(getattr(self, attr))
new_data = np.asanyarray(getattr(self, attr))
ndim = len(np.shape(new_data))
if ndim > 1:
new_data = ([attr + f"_dim{i}" for i in range(ndim)], new_data)
Expand Down Expand Up @@ -520,7 +520,7 @@ def to_pandas(self) -> DataFrame:
array_attrs = self.array_attrs() + [self.main_array_attr] + self.internal_array_attrs()

for attr in array_attrs:
values = np.asarray(getattr(self, attr))
values = np.asanyarray(getattr(self, attr))
ndim = len(np.shape(values))
if ndim > 1:
local_data = make_nd_into_arrays(values, attr)
Expand Down Expand Up @@ -758,21 +758,21 @@ def apply_mask(self, mask: npt.ArrayLike, inplace: bool = False, filtered_attrs:
setattr(
new_ts,
"_" + self.main_array_attr,
copy.deepcopy(np.asarray(getattr(self, self.main_array_attr))[mask]),
copy.deepcopy(np.asanyarray(getattr(self, self.main_array_attr))[mask]),
)
else:
setattr(
new_ts,
self.main_array_attr,
copy.deepcopy(np.asarray(getattr(self, self.main_array_attr))[mask]),
copy.deepcopy(np.asanyarray(getattr(self, self.main_array_attr))[mask]),
)

for attr in all_attrs:
if attr not in filtered_attrs:
# Eliminate all unfiltered attributes
setattr(new_ts, attr, None)
else:
setattr(new_ts, attr, copy.deepcopy(np.asarray(getattr(self, attr))[mask]))
setattr(new_ts, attr, copy.deepcopy(np.asanyarray(getattr(self, attr))[mask]))
return new_ts

def _operation_with_other_obj(
Expand Down Expand Up @@ -1030,7 +1030,7 @@ def __neg__(self):

ts_new = copy.deepcopy(self)
for attr in self._default_operated_attrs():
setattr(ts_new, attr, -np.asarray(getattr(self, attr)))
setattr(ts_new, attr, -np.asanyarray(getattr(self, attr)))

return ts_new

Expand Down Expand Up @@ -1215,7 +1215,7 @@ def __init__(
for kw in other_kw:
setattr(self, kw, other_kw[kw])
for kw in array_attrs:
new_arr = np.asarray(array_attrs[kw])
new_arr = np.asanyarray(array_attrs[kw])
if self.time.shape[0] != new_arr.shape[0]:
raise ValueError(f"Lengths of time and {kw} must be equal.")
setattr(self, kw, new_arr)
Expand Down Expand Up @@ -1246,15 +1246,15 @@ def gti(self):
dt1 = self.dt[-1]
else:
dt0 = dt1 = self.dt
self._gti = np.asarray([[self._time[0] - dt0 / 2, self._time[-1] + dt1 / 2]])
self._gti = np.asanyarray([[self._time[0] - dt0 / 2, self._time[-1] + dt1 / 2]])
return self._gti

@gti.setter
def gti(self, value):
if value is None:
self._gti = None
return
value = np.asarray(value)
value = np.asanyarray(value)
self._gti = value
self._mask = None

Expand All @@ -1278,15 +1278,15 @@ def _set_times(self, time, high_precision=False):
return
time, _ = interpret_times(time, self.mjdref)
if not high_precision:
self._time = np.asarray(time)
self._time = np.asanyarray(time)
else:
self._time = np.asarray(time, dtype=np.longdouble)
self._time = np.asanyarray(time, dtype=np.longdouble)

def __str__(self) -> str:
"""Return a string representation of the object."""
return self.pretty_print(
attrs_to_apply=["gti", "time", "tstart", "tseg", "tstop"],
func_to_apply=lambda x: (np.asarray(x) / 86400 + self.mjdref, "MJD"),
func_to_apply=lambda x: (np.asanyarray(x) / 86400 + self.mjdref, "MJD"),
attrs_to_discard=["_mask", "header"],
)

Expand Down Expand Up @@ -1318,7 +1318,7 @@ def _validate_and_format(self, value, attr_name, compare_to_attr):
"""
if value is None:
return None
value = np.asarray(value)
value = np.asanyarray(value)
if len(value.shape) < 1:
raise ValueError(f"{attr_name} array must be at least 1D")
# If the attribute we compare it with is the same and it is currently None, we assign it
Expand Down Expand Up @@ -1446,7 +1446,7 @@ def to_astropy_timeseries(self) -> TimeSeries:
for attr in array_attrs:
if attr == "time":
continue
data[attr] = np.asarray(getattr(self, attr))
data[attr] = np.asanyarray(getattr(self, attr))

if data == {}:
data = None
Expand Down Expand Up @@ -1489,7 +1489,7 @@ def from_astropy_timeseries(cls, ts: TimeSeries) -> StingrayTimeseries:

new_cls = cls()
time, mjdref = interpret_times(time, mjdref)
new_cls.time = np.asarray(time) # type: ignore
new_cls.time = np.asanyarray(time) # type: ignore

array_attrs = ts.colnames
for key, val in ts.meta.items():
Expand All @@ -1498,7 +1498,7 @@ def from_astropy_timeseries(cls, ts: TimeSeries) -> StingrayTimeseries:
for attr in array_attrs:
if attr == "time":
continue
setattr(new_cls, attr, np.asarray(ts[attr]))
setattr(new_cls, attr, np.asanyarray(ts[attr]))

return new_cls

Expand Down Expand Up @@ -1553,11 +1553,11 @@ def shift(self, time_shift: float, inplace=False) -> StingrayTimeseries:
ts = self
else:
ts = copy.deepcopy(self)
ts.time = np.asarray(ts.time) + time_shift # type: ignore
ts.time = np.asanyarray(ts.time) + time_shift # type: ignore
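# Pay attention here: the GTIs are created dynamically when the ``gti``
# property is accessed, so check the private ``_gti`` attribute directly
# and shift it only if it has already been set.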
if ts._gti is not None:
ts._gti = np.asarray(ts._gti) + time_shift # type: ignore
ts._gti = np.asanyarray(ts._gti) + time_shift # type: ignore

return ts

Expand Down Expand Up @@ -1718,7 +1718,9 @@ def __getitem__(self, index):
delta_gti_start = new_ts.dt[0] * 0.5
delta_gti_stop = new_ts.dt[-1] * 0.5

new_gti = np.asarray([[new_ts.time[0] - delta_gti_start, new_ts.time[-1] + delta_gti_stop]])
new_gti = np.asanyarray(
[[new_ts.time[0] - delta_gti_start, new_ts.time[-1] + delta_gti_stop]]
)
if step > 1 and delta_gti_start > 0:
new_gt1 = np.array(list(zip(new_ts.time - new_ts.dt / 2, new_ts.time + new_ts.dt / 2)))
new_gti = cross_two_gtis(new_gti, new_gt1)
Expand Down Expand Up @@ -1797,7 +1799,8 @@ def _truncate_by_index(self, start, stop):
dtstop = self.dt[-1]

gti = cross_two_gtis(
self.gti, np.asarray([[new_ts.time[0] - 0.5 * dtstart, new_ts.time[-1] + 0.5 * dtstop]])
self.gti,
np.asanyarray([[new_ts.time[0] - 0.5 * dtstart, new_ts.time[-1] + 0.5 * dtstop]]),
)

new_ts.gti = gti
Expand Down Expand Up @@ -2108,7 +2111,7 @@ def rebin(self, dt_new=None, f=None, method="sum"):
elif f is not None:
dt_new = f * self.dt

if np.any(dt_new < np.asarray(self.dt)):
if np.any(dt_new < np.asanyarray(self.dt)):
raise ValueError("The new time resolution must be larger than the old one!")

gti_new = []
Expand Down Expand Up @@ -2150,7 +2153,7 @@ def rebin(self, dt_new=None, f=None, method="sum"):

if len(gti_new) == 0:
raise ValueError("No valid GTIs after rebin.")
new_ts.gti = np.asarray(gti_new)
new_ts.gti = np.asanyarray(gti_new)

for attr in self.meta_attrs():
if attr == "dt":
Expand Down Expand Up @@ -2654,7 +2657,7 @@ def analyze_segments(self, func, segment_size, fraction_step=1, **kwargs):
res = np.nan
else:
lc_filt = self[st:sp]
lc_filt.gti = np.asarray([[tst, tsp]])
lc_filt.gti = np.asanyarray([[tst, tsp]])

res = func(lc_filt, **kwargs)
results.append(res)
Expand Down
12 changes: 6 additions & 6 deletions stingray/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,12 @@ def __init__(
StingrayTimeseries.__init__(
self,
time=time,
energy=None if energy is None else np.asarray(energy),
energy=None if energy is None else np.asanyarray(energy),
mjdref=mjdref,
dt=dt,
notes=notes,
gti=np.asarray(gti) if gti is not None else None,
pi=None if pi is None else np.asarray(pi),
gti=np.asanyarray(gti) if gti is not None else None,
pi=None if pi is None else np.asanyarray(pi),
high_precision=high_precision,
mission=mission,
instr=instr,
Expand Down Expand Up @@ -367,7 +367,7 @@ def to_lc_iter(self, dt, segment_size=None):
self.time[idx_st : idx_end + 1],
dt,
tstart=st,
gti=np.asarray([[st, end]]),
gti=np.asanyarray([[st, end]]),
tseg=tseg,
mjdref=self.mjdref,
use_hist=True,
Expand Down Expand Up @@ -474,8 +474,8 @@ def simulate_energies(self, spectrum, use_spline=False):
return

if isinstance(spectrum, list) or isinstance(spectrum, np.ndarray):
energy = np.asarray(spectrum)[0]
fluxes = np.asarray(spectrum)[1]
energy = np.asanyarray(spectrum)[0]
fluxes = np.asanyarray(spectrum)[1]

if not isinstance(energy, np.ndarray):
raise IndexError("Spectrum must be a 2-d array or list")
Expand Down
8 changes: 4 additions & 4 deletions stingray/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def integrate_power_in_frequency_range(
if power_err is None:
power_err_to_integrate = powers_to_integrate / np.sqrt(m)
else:
power_err_to_integrate = np.asarray(power_err)[frequency_mask]
power_err_to_integrate = np.asanyarray(power_err)[frequency_mask]

power_integrated = np.sum((powers_to_integrate - poisson_power) * dfs_to_integrate)
power_err_integrated = np.sqrt(np.sum((power_err_to_integrate * dfs_to_integrate) ** 2))
Expand Down Expand Up @@ -1250,7 +1250,7 @@ def get_average_ctrate(times, gti, segment_size, counts=None):
Examples
--------
>>> times = np.sort(np.random.uniform(0, 1000, 1000))
>>> gti = np.asarray([[0, 1000]])
>>> gti = np.asanyarray([[0, 1000]])
>>> counts, _ = np.histogram(times, bins=np.linspace(0, 1000, 11))
>>> bin_times = np.arange(50, 1000, 100)
>>> assert get_average_ctrate(bin_times, gti, 1000, counts=counts) == 1.0
Expand Down Expand Up @@ -1326,7 +1326,7 @@ def get_flux_iterable_from_segments(
dt = np.median(np.diff(times[:100]))

if binned:
fluxes = np.asarray(fluxes)
fluxes = np.asanyarray(fluxes)
if np.iscomplexobj(fluxes):
cast_kind = complex

Expand Down Expand Up @@ -2399,7 +2399,7 @@ def lsft_slow(
An array of Fourier transformed data.
"""
y_ = y - np.mean(y)
freqs = np.asarray(freqs[np.asarray(freqs) >= 0])
freqs = np.asanyarray(freqs[np.asanyarray(freqs) >= 0])

ft_real = np.zeros_like(freqs)
ft_imag = np.zeros_like(freqs)
Expand Down
Loading
Loading