diff --git a/stonesoup/sensormanager/reward.py b/stonesoup/sensormanager/reward.py index f65a930e4..670375866 100644 --- a/stonesoup/sensormanager/reward.py +++ b/stonesoup/sensormanager/reward.py @@ -95,8 +95,6 @@ def __call__(self, config: Mapping[Sensor, Sequence[Action]], tracks: Set[Track] predicted_tracks = set() for track in tracks: predicted_track = copy.copy(track) - predicted_track.states = copy.copy(predicted_track.states) - predicted_track.metadatas = copy.copy(predicted_track.metadatas) predicted_track.append(self.predictor.predict(predicted_track, timestamp=metric_time)) predicted_tracks.add(predicted_track) diff --git a/stonesoup/types/state.py b/stonesoup/types/state.py index 78274e5ac..2e3cca889 100644 --- a/stonesoup/types/state.py +++ b/stonesoup/types/state.py @@ -1,3 +1,4 @@ +import copy import datetime import uuid from collections import abc @@ -172,6 +173,11 @@ class StateMutableSequence(Type, abc.MutableSequence): proxying state attributes to the last state in the sequence. This sequence can also be indexed/sliced by :class:`datetime.datetime` instances. + Notes + ----- + If shallow copying, similar to a list, it is safe to add/remove states + without affecting the original sequence. + Example ------- >>> t0 = datetime.datetime(2018, 1, 1, 14, 00) @@ -261,6 +267,13 @@ def __getattribute__(self, name): # raise the original error instead raise original_error + def __copy__(self): + inst = self.__class__.__new__(self.__class__) + inst.__dict__.update(self.__dict__) + property_name = self.__class__.states._property_name + inst.__dict__[property_name] = copy.copy(self.__dict__[property_name]) + return inst + def insert(self, index, value): return self.states.insert(index, value) diff --git a/stonesoup/types/tests/test_state.py b/stonesoup/types/tests/test_state.py index e9f5636ea..fe7b06836 100644 --- a/stonesoup/types/tests/test_state.py +++ b/stonesoup/types/tests/test_state.py @@ -1,3 +1,4 @@ +import copy import datetime import numpy as np @@ -516,6 +517,23 @@ def complicated_attribute(self): _ = test_obj.complicated_attribute +def test_state_mutable_sequence_copy(): + state_vector = StateVector([[0]]) + timestamp = datetime.datetime(2018, 1, 1, 14) + delta = datetime.timedelta(minutes=1) + sequence = StateMutableSequence( + [State(state_vector, timestamp=timestamp+delta*n) + for n in range(10)]) + + sequence2 = copy.copy(sequence) + + assert sequence2.states is not sequence.states + + assert sequence2[-1] is sequence[-1] + sequence2.remove(sequence[-1]) + assert sequence2[-1] is not sequence[-1] + + def test_from_state(): start = datetime.datetime.now() kwargs = {"state_vector": np.arange(4), "timestamp": start} diff --git a/stonesoup/types/track.py b/stonesoup/types/track.py index 5b5267a9a..99baa22ee 100644 --- a/stonesoup/types/track.py +++ b/stonesoup/types/track.py @@ -1,3 +1,4 @@ +import copy import uuid from typing import MutableSequence, MutableMapping @@ -47,6 +48,11 @@ def __setitem__(self, index, value): index = len(self.states) + index self._update_metadatas(index) + def __copy__(self): + inst = super().__copy__() + inst.__dict__['metadatas'] = copy.copy(self.__dict__['metadatas']) + return inst + def insert(self, index, value): """Insert value at index of :attr:`states`.