Skip to content

Commit

Permalink
core: add a class to concatenate envelope without deep copy
Browse files Browse the repository at this point in the history
  • Loading branch information
eckter committed Dec 21, 2023
1 parent ea45fc5 commit 181b557
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import fr.sncf.osrd.envelope.part.EnvelopePart;
import fr.sncf.osrd.reporting.exceptions.ErrorType;
import fr.sncf.osrd.reporting.exceptions.OSRDError;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.stream.Stream;

Expand Down Expand Up @@ -324,8 +326,25 @@ public EnvelopePart next() {
};
}

@Override
public List<EnvelopePoint> iteratePoints() {
var res = new ArrayList<EnvelopePoint>();
double time = 0;
for (var part : this) {
// Add head position points
for (int i = 0; i < part.pointCount(); i++) {
var pos = part.getPointPos(i);
var speed = part.getPointSpeed(i);
res.add(new EnvelopePoint(time, speed, pos));
if (i < part.stepCount())
time += part.getStepTime(i);
}
}
return res;
}

/** Makes a stream from the parts */
public Stream<EnvelopePart> stream() {
return Arrays.stream(parts);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package fr.sncf.osrd.envelope;

import static fr.sncf.osrd.envelope_sim.TrainPhysicsIntegrator.arePositionsEqual;

import java.util.ArrayList;
import java.util.List;

/**
* This class is used to concatenate envelopes without a deep copy of all the underlying data.
* All envelopes are expected to start at position 0.
*/
public class EnvelopeConcat implements EnvelopeTimeInterpolate {

private final List<LocatedEnvelope> envelopes;
private final double endPos;

private EnvelopeConcat(List<LocatedEnvelope> envelopes, double endPos) {
this.envelopes = envelopes;
this.endPos = endPos;
}

/**
* Creates an instance from a list of envelopes
*/
public static EnvelopeConcat from(List<? extends EnvelopeTimeInterpolate> envelopes) {
runSanityChecks(envelopes);
var locatedEnvelopes = initLocatedEnvelopes(envelopes);
var lastEnvelope = locatedEnvelopes.get(locatedEnvelopes.size() - 1);
var endPos = lastEnvelope.startOffset + lastEnvelope.envelope.getEndPos();
return new EnvelopeConcat(locatedEnvelopes, endPos);
}

/**
* Run some checks to ensure that the inputs match the assumptions made by this class
*/
private static void runSanityChecks(List<? extends EnvelopeTimeInterpolate> envelopes) {
assert !envelopes.isEmpty() : "concatenating no envelope";
for (var envelope : envelopes)
assert arePositionsEqual(0, envelope.getBeginPos()) : "concatenated envelope doesn't start at 0";
}

/**
* Place all envelopes in a record containing the offset on which they start
*/
private static List<LocatedEnvelope> initLocatedEnvelopes(List<? extends EnvelopeTimeInterpolate> envelopes) {
double currentOffset = 0.0;
double currentTime = 0.0;
var res = new ArrayList<LocatedEnvelope>();
for (var envelope : envelopes) {
res.add(new LocatedEnvelope(envelope, currentOffset, currentTime));
currentOffset += envelope.getEndPos();
currentTime += envelope.getTotalTime();
}
return res;
}

@Override
public double interpolateTotalTime(double position) {
var envelope = findEnvelopeAt(position);
assert envelope != null : "Trying to interpolate time outside of the envelope";
return envelope.startTime + envelope.envelope.interpolateTotalTime(position - envelope.startOffset);
}

@Override
public long interpolateTotalTimeMS(double position) {
return (long) (interpolateTotalTime(position) * 1000);
}

@Override
public double interpolateTotalTimeClamp(double position) {
var clamped = Math.max(0, Math.min(position, endPos));
return interpolateTotalTime(clamped);
}

@Override
public double getBeginPos() {
return 0;
}

@Override
public double getEndPos() {
return endPos;
}

@Override
public double getTotalTime() {
return interpolateTotalTime(endPos);
}

@Override
public List<EnvelopePoint> iteratePoints() {
return envelopes.stream()
.flatMap(locatedEnvelope ->
locatedEnvelope.envelope.iteratePoints()
.stream()
.map(p -> new EnvelopePoint(
p.time() + locatedEnvelope.startTime,
p.speed(),
p.position() + locatedEnvelope.startOffset
))
).toList();
}

/**
* Returns the envelope at the given position. On transitions, the rightmost envelope is returned.
*/
private LocatedEnvelope findEnvelopeAt(double position) {
if (position < 0)
return null;
for (var envelope : envelopes) {
if (position < envelope.startOffset + envelope.envelope.getEndPos())
return envelope;
}
var lastEnvelope = envelopes.get(envelopes.size() - 1);
if (arePositionsEqual(position, lastEnvelope.startOffset + lastEnvelope.envelope.getEndPos()))
return lastEnvelope;
return null;
}

private record LocatedEnvelope(
EnvelopeTimeInterpolate envelope,
double startOffset,
double startTime
) {
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package fr.sncf.osrd.envelope;

import java.util.List;

public interface EnvelopeTimeInterpolate {

/** Computes the time required to get to a given point of the envelope */
Expand All @@ -20,4 +22,8 @@ public interface EnvelopeTimeInterpolate {

/** Returns the total time of the envelope */
double getTotalTime();

record EnvelopePoint(double time, double speed, double position){}

List<EnvelopePoint> iteratePoints();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package fr.sncf.osrd.envelope;

import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.function.Function;

public class EnvelopeConcatTest {

@Test
public void testSingleEnvelope() {
var envelope = Envelope.make(
EnvelopeTestUtils.generateTimes(new double[]{0, 1}, new double[]{1, 1}),
EnvelopeTestUtils.generateTimes(new double[]{1, 2}, new double[]{1, 1})
);
var concatenated = EnvelopeConcat.from(List.of(envelope));

// List of functions to call, they should return the same result for the envelope and the concatenated version
var functions = List.<Function<EnvelopeTimeInterpolate, Double>>of(
in -> in.interpolateTotalTime(0),
in -> in.interpolateTotalTime(1),
in -> in.interpolateTotalTime(2),
in -> (double) in.interpolateTotalTimeMS(1.5),
in -> in.interpolateTotalTimeClamp(-1),
in -> in.interpolateTotalTimeClamp(0.5),
EnvelopeTimeInterpolate::getBeginPos,
EnvelopeTimeInterpolate::getEndPos,
EnvelopeTimeInterpolate::getTotalTime
);

for (var f : functions)
assertEquals(f.apply(envelope), f.apply(concatenated));
assertEquals(envelope.iteratePoints(), concatenated.iteratePoints());
}

@Test
public void testTwoEnvelopes() {
final var envelopes = List.of(
Envelope.make(
EnvelopeTestUtils.generateTimes(new double[]{0, 1}, new double[]{1, 2}),
EnvelopeTestUtils.generateTimes(new double[]{1, 2}, new double[]{2, 3})
),
Envelope.make(
EnvelopeTestUtils.generateTimes(new double[]{0, 1}, new double[]{3, 4}),
EnvelopeTestUtils.generateTimes(new double[]{1, 2}, new double[]{4, 5})
)
);
final var concatenated = EnvelopeConcat.from(envelopes);
final var firstEnvelopeTime = envelopes.get(0).getTotalTime();
final var secondEnvelopeTime = envelopes.get(1).getTotalTime();

assertEquals(
firstEnvelopeTime + envelopes.get(1).interpolateTotalTime(1),
concatenated.interpolateTotalTime(3)
);
assertEquals(0, concatenated.getBeginPos());
assertEquals(4, concatenated.getEndPos());
assertEquals(firstEnvelopeTime + secondEnvelopeTime, concatenated.getTotalTime());

final var points = concatenated.iteratePoints();
final var firstPoint = points.get(0);
final var lastPoint = points.get(points.size() - 1);
assertEquals(0, firstPoint.time());
assertEquals(0, firstPoint.position());
assertEquals(1, firstPoint.speed());
assertEquals(firstEnvelopeTime, secondEnvelopeTime, lastPoint.time());
assertEquals(4, lastPoint.position());
assertEquals(5, lastPoint.speed());
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package fr.sncf.osrd.standalone_sim;

import fr.sncf.osrd.envelope.Envelope;
import static fr.sncf.osrd.envelope_sim.TrainPhysicsIntegrator.arePositionsEqual;

import fr.sncf.osrd.envelope.EnvelopeTimeInterpolate;
import fr.sncf.osrd.train.TrainStop;
import java.util.ArrayList;
import java.util.List;

public class EnvelopeStopWrapper implements EnvelopeTimeInterpolate {
public final Envelope envelope;
public final EnvelopeTimeInterpolate envelope;
public final List<TrainStop> stops;

public EnvelopeStopWrapper(Envelope envelope, List<TrainStop> stops) {
public EnvelopeStopWrapper(EnvelopeTimeInterpolate envelope, List<TrainStop> stops) {
this.envelope = envelope;
this.stops = stops;
}
Expand Down Expand Up @@ -51,33 +52,22 @@ public double getTotalTime() {
return envelope.getTotalTime() + stops.stream().mapToDouble(stop -> stop.duration).sum();
}

public record CurvePoint(double time, double speed, double position){}

/** Returns all the points as (time, speed, position), with time adjusted for stop duration */
public List<CurvePoint> iterateCurve() {
var res = new ArrayList<CurvePoint>();
double time = 0;
for (var part : envelope) {
// Add head position points
for (int i = 0; i < part.pointCount(); i++) {
var pos = part.getPointPos(i);
var speed = part.getPointSpeed(i);
res.add(new CurvePoint(time, speed, pos));
if (i < part.stepCount())
time += part.getStepTime(i);
}

if (part.getEndSpeed() > 0)
continue;

// Add stop duration
for (var stop : stops) {
if (stop.duration == 0. || stop.position < part.getEndPos())
continue;
if (stop.position > part.getEndPos())
break;
time += stop.duration;
res.add(new CurvePoint(time, 0, part.getEndPos()));
@Override
public List<EnvelopePoint> iteratePoints() {
var res = new ArrayList<EnvelopePoint>();
double sumPreviousStopDuration = 0;
int stopIndex = 0;
for (var point : envelope.iteratePoints()) {
var shiftedPoint = new EnvelopePoint(
point.time() + sumPreviousStopDuration,
point.speed(), point.position()
);
res.add(shiftedPoint);
if (stopIndex < stops.size() && arePositionsEqual(point.position(), stops.get(stopIndex).position)) {
var stopDuration = stops.get(stopIndex).duration;
stopIndex++;
sumPreviousStopDuration += stopDuration;
}
}
return res;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ fun run(envelope: Envelope, trainPath: PathProperties, chunkPath: ChunkPath, sch
val trainLength = schedule.rollingStock.length
var speeds = ArrayList<ResultSpeed>()
var headPositions = ArrayList<ResultPosition>()
for (point in envelopeWithStops.iterateCurve()) {
for (point in envelopeWithStops.iteratePoints()) {
speeds.add(ResultSpeed(point.time, point.speed, point.position))
headPositions.add(ResultPosition.from(point.time, point.position, trainPath, rawInfra))
}
Expand Down

0 comments on commit 181b557

Please sign in to comment.