-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
core: add a class to concatenate envelope without deep copy
- Loading branch information
Showing
6 changed files
with
243 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
126 changes: 126 additions & 0 deletions
126
core/envelope-sim/src/main/java/fr/sncf/osrd/envelope/EnvelopeConcat.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
package fr.sncf.osrd.envelope; | ||
|
||
import static fr.sncf.osrd.envelope_sim.TrainPhysicsIntegrator.arePositionsEqual; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
/** | ||
* This class is used to concatenate envelopes without a deep copy of all the underlying data. | ||
* All envelopes are expected to start at position 0. | ||
*/ | ||
public class EnvelopeConcat implements EnvelopeTimeInterpolate { | ||
|
||
private final List<LocatedEnvelope> envelopes; | ||
private final double endPos; | ||
|
||
private EnvelopeConcat(List<LocatedEnvelope> envelopes, double endPos) { | ||
this.envelopes = envelopes; | ||
this.endPos = endPos; | ||
} | ||
|
||
/** | ||
* Creates an instance from a list of envelopes | ||
*/ | ||
public static EnvelopeConcat from(List<? extends EnvelopeTimeInterpolate> envelopes) { | ||
runSanityChecks(envelopes); | ||
var locatedEnvelopes = initLocatedEnvelopes(envelopes); | ||
var lastEnvelope = locatedEnvelopes.get(locatedEnvelopes.size() - 1); | ||
var endPos = lastEnvelope.startOffset + lastEnvelope.envelope.getEndPos(); | ||
return new EnvelopeConcat(locatedEnvelopes, endPos); | ||
} | ||
|
||
/** | ||
* Run some checks to ensure that the inputs match the assumptions made by this class | ||
*/ | ||
private static void runSanityChecks(List<? extends EnvelopeTimeInterpolate> envelopes) { | ||
assert !envelopes.isEmpty() : "concatenating no envelope"; | ||
for (var envelope : envelopes) | ||
assert arePositionsEqual(0, envelope.getBeginPos()) : "concatenated envelope doesn't start at 0"; | ||
} | ||
|
||
/** | ||
* Place all envelopes in a record containing the offset on which they start | ||
*/ | ||
private static List<LocatedEnvelope> initLocatedEnvelopes(List<? extends EnvelopeTimeInterpolate> envelopes) { | ||
double currentOffset = 0.0; | ||
double currentTime = 0.0; | ||
var res = new ArrayList<LocatedEnvelope>(); | ||
for (var envelope : envelopes) { | ||
res.add(new LocatedEnvelope(envelope, currentOffset, currentTime)); | ||
currentOffset += envelope.getEndPos(); | ||
currentTime += envelope.getTotalTime(); | ||
} | ||
return res; | ||
} | ||
|
||
@Override | ||
public double interpolateTotalTime(double position) { | ||
var envelope = findEnvelopeAt(position); | ||
assert envelope != null : "Trying to interpolate time outside of the envelope"; | ||
return envelope.startTime + envelope.envelope.interpolateTotalTime(position - envelope.startOffset); | ||
} | ||
|
||
@Override | ||
public long interpolateTotalTimeMS(double position) { | ||
return (long) (interpolateTotalTime(position) * 1000); | ||
} | ||
|
||
@Override | ||
public double interpolateTotalTimeClamp(double position) { | ||
var clamped = Math.max(0, Math.min(position, endPos)); | ||
return interpolateTotalTime(clamped); | ||
} | ||
|
||
@Override | ||
public double getBeginPos() { | ||
return 0; | ||
} | ||
|
||
@Override | ||
public double getEndPos() { | ||
return endPos; | ||
} | ||
|
||
@Override | ||
public double getTotalTime() { | ||
return interpolateTotalTime(endPos); | ||
} | ||
|
||
@Override | ||
public List<EnvelopePoint> iteratePoints() { | ||
return envelopes.stream() | ||
.flatMap(locatedEnvelope -> | ||
locatedEnvelope.envelope.iteratePoints() | ||
.stream() | ||
.map(p -> new EnvelopePoint( | ||
p.time() + locatedEnvelope.startTime, | ||
p.speed(), | ||
p.position() + locatedEnvelope.startOffset | ||
)) | ||
).toList(); | ||
} | ||
|
||
/** | ||
* Returns the envelope at the given position. On transitions, the rightmost envelope is returned. | ||
*/ | ||
private LocatedEnvelope findEnvelopeAt(double position) { | ||
if (position < 0) | ||
return null; | ||
for (var envelope : envelopes) { | ||
if (position < envelope.startOffset + envelope.envelope.getEndPos()) | ||
return envelope; | ||
} | ||
var lastEnvelope = envelopes.get(envelopes.size() - 1); | ||
if (arePositionsEqual(position, lastEnvelope.startOffset + lastEnvelope.envelope.getEndPos())) | ||
return lastEnvelope; | ||
return null; | ||
} | ||
|
||
private record LocatedEnvelope( | ||
EnvelopeTimeInterpolate envelope, | ||
double startOffset, | ||
double startTime | ||
) { | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
core/envelope-sim/src/test/java/fr/sncf/osrd/envelope/EnvelopeConcatTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package fr.sncf.osrd.envelope; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
import org.junit.jupiter.api.Test; | ||
import java.util.List; | ||
import java.util.function.Function; | ||
|
||
public class EnvelopeConcatTest { | ||
|
||
@Test | ||
public void testSingleEnvelope() { | ||
var envelope = Envelope.make( | ||
EnvelopeTestUtils.generateTimes(new double[]{0, 1}, new double[]{1, 1}), | ||
EnvelopeTestUtils.generateTimes(new double[]{1, 2}, new double[]{1, 1}) | ||
); | ||
var concatenated = EnvelopeConcat.from(List.of(envelope)); | ||
|
||
// List of functions to call, they should return the same result for the envelope and the concatenated version | ||
var functions = List.<Function<EnvelopeTimeInterpolate, Double>>of( | ||
in -> in.interpolateTotalTime(0), | ||
in -> in.interpolateTotalTime(1), | ||
in -> in.interpolateTotalTime(2), | ||
in -> (double) in.interpolateTotalTimeMS(1.5), | ||
in -> in.interpolateTotalTimeClamp(-1), | ||
in -> in.interpolateTotalTimeClamp(0.5), | ||
EnvelopeTimeInterpolate::getBeginPos, | ||
EnvelopeTimeInterpolate::getEndPos, | ||
EnvelopeTimeInterpolate::getTotalTime | ||
); | ||
|
||
for (var f : functions) | ||
assertEquals(f.apply(envelope), f.apply(concatenated)); | ||
assertEquals(envelope.iteratePoints(), concatenated.iteratePoints()); | ||
} | ||
|
||
@Test | ||
public void testTwoEnvelopes() { | ||
final var envelopes = List.of( | ||
Envelope.make( | ||
EnvelopeTestUtils.generateTimes(new double[]{0, 1}, new double[]{1, 2}), | ||
EnvelopeTestUtils.generateTimes(new double[]{1, 2}, new double[]{2, 3}) | ||
), | ||
Envelope.make( | ||
EnvelopeTestUtils.generateTimes(new double[]{0, 1}, new double[]{3, 4}), | ||
EnvelopeTestUtils.generateTimes(new double[]{1, 2}, new double[]{4, 5}) | ||
) | ||
); | ||
final var concatenated = EnvelopeConcat.from(envelopes); | ||
final var firstEnvelopeTime = envelopes.get(0).getTotalTime(); | ||
final var secondEnvelopeTime = envelopes.get(1).getTotalTime(); | ||
|
||
assertEquals( | ||
firstEnvelopeTime + envelopes.get(1).interpolateTotalTime(1), | ||
concatenated.interpolateTotalTime(3) | ||
); | ||
assertEquals(0, concatenated.getBeginPos()); | ||
assertEquals(4, concatenated.getEndPos()); | ||
assertEquals(firstEnvelopeTime + secondEnvelopeTime, concatenated.getTotalTime()); | ||
|
||
final var points = concatenated.iteratePoints(); | ||
final var firstPoint = points.get(0); | ||
final var lastPoint = points.get(points.size() - 1); | ||
assertEquals(0, firstPoint.time()); | ||
assertEquals(0, firstPoint.position()); | ||
assertEquals(1, firstPoint.speed()); | ||
assertEquals(firstEnvelopeTime, secondEnvelopeTime, lastPoint.time()); | ||
assertEquals(4, lastPoint.position()); | ||
assertEquals(5, lastPoint.speed()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters