Skip to content

Commit

Permalink
Merge pull request #696 from dstl/track_copy
Browse files Browse the repository at this point in the history
Enable proper shallow copying of state mutable sequences
  • Loading branch information
sdhiscocks authored Aug 22, 2022
2 parents 7e6c19f + 368cc20 commit 2618c0b
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 2 deletions.
2 changes: 0 additions & 2 deletions stonesoup/sensormanager/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ def __call__(self, config: Mapping[Sensor, Sequence[Action]], tracks: Set[Track]
predicted_tracks = set()
for track in tracks:
predicted_track = copy.copy(track)
predicted_track.states = copy.copy(predicted_track.states)
predicted_track.metadatas = copy.copy(predicted_track.metadatas)
predicted_track.append(self.predictor.predict(predicted_track, timestamp=metric_time))
predicted_tracks.add(predicted_track)

Expand Down
13 changes: 13 additions & 0 deletions stonesoup/types/state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import datetime
import uuid
from collections import abc
Expand Down Expand Up @@ -172,6 +173,11 @@ class StateMutableSequence(Type, abc.MutableSequence):
proxying state attributes to the last state in the sequence. This sequence
can also be indexed/sliced by :class:`datetime.datetime` instances.
Notes
-----
If shallow copying, similar to a list, it is safe to add/remove states
without affecting the original sequence.
Example
-------
>>> t0 = datetime.datetime(2018, 1, 1, 14, 00)
Expand Down Expand Up @@ -261,6 +267,13 @@ def __getattribute__(self, name):
# raise the original error instead
raise original_error

def __copy__(self):
inst = self.__class__.__new__(self.__class__)
inst.__dict__.update(self.__dict__)
property_name = self.__class__.states._property_name
inst.__dict__[property_name] = copy.copy(self.__dict__[property_name])
return inst

def insert(self, index, value):
return self.states.insert(index, value)

Expand Down
18 changes: 18 additions & 0 deletions stonesoup/types/tests/test_state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import datetime

import numpy as np
Expand Down Expand Up @@ -516,6 +517,23 @@ def complicated_attribute(self):
_ = test_obj.complicated_attribute


def test_state_mutable_sequence_copy():
state_vector = StateVector([[0]])
timestamp = datetime.datetime(2018, 1, 1, 14)
delta = datetime.timedelta(minutes=1)
sequence = StateMutableSequence(
[State(state_vector, timestamp=timestamp+delta*n)
for n in range(10)])

sequence2 = copy.copy(sequence)

assert sequence2.states is not sequence.states

assert sequence2[-1] is sequence[-1]
sequence2.remove(sequence[-1])
assert sequence2[-1] is not sequence[-1]


def test_from_state():
start = datetime.datetime.now()
kwargs = {"state_vector": np.arange(4), "timestamp": start}
Expand Down
6 changes: 6 additions & 0 deletions stonesoup/types/track.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import uuid
from typing import MutableSequence, MutableMapping

Expand Down Expand Up @@ -47,6 +48,11 @@ def __setitem__(self, index, value):
index = len(self.states) + index
self._update_metadatas(index)

def __copy__(self):
inst = super().__copy__()
inst.__dict__['metadatas'] = copy.copy(self.__dict__['metadatas'])
return inst

def insert(self, index, value):
"""Insert value at index of :attr:`states`.
Expand Down

0 comments on commit 2618c0b

Please sign in to comment.