-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPoseEstimator.java
182 lines (156 loc) · 7.03 KB
/
PoseEstimator.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// Copyright (c) 2023 FRC 6328
// http://github.com/Mechanical-Advantage
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file at
// the root directory of this project.
package frc.robot.commons;
import edu.wpi.first.math.Matrix;
import edu.wpi.first.math.Nat;
import edu.wpi.first.math.VecBuilder;
import edu.wpi.first.math.geometry.Pose2d;
import edu.wpi.first.math.geometry.Twist2d;
import edu.wpi.first.math.numbers.N1;
import edu.wpi.first.math.numbers.N3;
import edu.wpi.first.wpilibj.Timer;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.NavigableMap;
import java.util.TreeMap;
import org.littletonrobotics.junction.Logger;
public class PoseEstimator {
private static final double historyLengthSecs = 0.3;
private Pose2d basePose = new Pose2d();
private Pose2d latestPose = new Pose2d();
private final NavigableMap<Double, PoseUpdate> updates = new TreeMap<>();
private Matrix<N3, N1> q = new Matrix<>(Nat.N3(), Nat.N1());
public PoseEstimator(Matrix<N3, N1> stateStdDevs) {
for (int i = 0; i < 3; ++i) {
q.set(i, 0, stateStdDevs.get(i, 0) * stateStdDevs.get(i, 0));
}
}
/** Returns the latest robot pose based on drive and vision data. */
public Pose2d getLatestPose() {
return latestPose;
}
/** Resets the odometry to a known pose. */
public void resetPose(Pose2d pose) {
basePose = pose;
updates.clear();
update();
}
/** Sets the standard deviations of the model */
public void setModelStateStdDevs(Matrix<N3, N1> modelStateStdDevs) {
q = modelStateStdDevs;
}
/** Records a new drive movement. */
public void addDriveData(double timestamp, Twist2d twist) {
updates.put(timestamp, new PoseUpdate(twist, new ArrayList<>()));
update();
}
/** Records a new set of vision updates. */
public void addVisionData(List<TimestampedVisionUpdate> visionData) {
for (var timestampedVisionUpdate : visionData) {
var timestamp = timestampedVisionUpdate.timestamp();
var visionUpdate = new VisionUpdate(timestampedVisionUpdate.pose(), timestampedVisionUpdate.stdDevs());
if (updates.containsKey(timestamp)) {
// There was already an update at this timestamp, add to it
var oldVisionUpdates = updates.get(timestamp).visionUpdates();
oldVisionUpdates.add(visionUpdate);
oldVisionUpdates.sort(VisionUpdate.compareDescStdDev);
} else {
// Insert a new update
var prevUpdate = updates.floorEntry(timestamp);
var nextUpdate = updates.ceilingEntry(timestamp);
if (prevUpdate == null || nextUpdate == null) {
// Outside the range of existing data
return;
}
// Create partial twists (prev -> vision, vision -> next)
var twist0 = GeomUtil.multiplyTwist(
nextUpdate.getValue().twist(),
(timestamp - prevUpdate.getKey()) / (nextUpdate.getKey() - prevUpdate.getKey()));
var twist1 = GeomUtil.multiplyTwist(
nextUpdate.getValue().twist(),
(nextUpdate.getKey() - timestamp) / (nextUpdate.getKey() - prevUpdate.getKey()));
// Add new pose updates
var newVisionUpdates = new ArrayList<VisionUpdate>();
newVisionUpdates.add(visionUpdate);
newVisionUpdates.sort(VisionUpdate.compareDescStdDev);
updates.put(timestamp, new PoseUpdate(twist0, newVisionUpdates));
updates.put(
nextUpdate.getKey(), new PoseUpdate(twist1, nextUpdate.getValue().visionUpdates()));
}
}
// Recalculate latest pose once
update();
}
/** Clears old data and calculates the latest pose. */
private void update() {
// Clear old data and update base pose
while (updates.size() > 1
&& updates.firstKey() < Timer.getFPGATimestamp() - historyLengthSecs) {
var update = updates.pollFirstEntry();
basePose = update.getValue().apply(basePose, q);
}
// Update latest pose
latestPose = basePose;
for (var updateEntry : updates.entrySet()) {
latestPose = updateEntry.getValue().apply(latestPose, q);
}
Logger.getInstance().recordOutput("Odometry/RobotPosition", latestPose);
}
/**
* Represents a sequential update to a pose estimate, with a twist (drive
* movement) and list of
* vision updates.
*/
private static record PoseUpdate(Twist2d twist, ArrayList<VisionUpdate> visionUpdates) {
public Pose2d apply(Pose2d lastPose, Matrix<N3, N1> q) {
// Apply drive twist
var pose = lastPose.exp(twist);
// Apply vision updates
for (var visionUpdate : visionUpdates) {
// Calculate Kalman gains based on std devs
// (https://github.com/wpilibsuite/allwpilib/blob/main/wpimath/src/main/java/edu/wpi/first/math/estimator/)
Matrix<N3, N3> visionK = new Matrix<>(Nat.N3(), Nat.N3());
var r = new double[3];
for (int i = 0; i < 3; ++i) {
r[i] = visionUpdate.stdDevs().get(i, 0) * visionUpdate.stdDevs().get(i, 0);
}
for (int row = 0; row < 3; ++row) {
if (q.get(row, 0) == 0.0) {
visionK.set(row, row, 0.0);
} else {
visionK.set(
row, row, q.get(row, 0) / (q.get(row, 0) + Math.sqrt(q.get(row, 0) * r[row])));
}
}
// Calculate twist between current and vision pose
var visionTwist = pose.log(visionUpdate.pose());
// Multiply by Kalman gain matrix
var twistMatrix = visionK.times(VecBuilder.fill(visionTwist.dx, visionTwist.dy, visionTwist.dtheta));
// Apply twist
pose = pose.exp(
new Twist2d(twistMatrix.get(0, 0), twistMatrix.get(1, 0), twistMatrix.get(2, 0)));
}
return pose;
}
}
/** Represents a single vision pose with associated standard deviations. */
public static record VisionUpdate(Pose2d pose, Matrix<N3, N1> stdDevs) {
public static final Comparator<VisionUpdate> compareDescStdDev = (VisionUpdate a, VisionUpdate b) -> {
return -Double.compare(
a.stdDevs().get(0, 0) + a.stdDevs().get(1, 0),
b.stdDevs().get(0, 0) + b.stdDevs().get(1, 0));
};
}
/**
* Represents a single vision pose with a timestamp and associated standard
* deviations.
*/
public static record TimestampedVisionUpdate(
double timestamp, Pose2d pose, Matrix<N3, N1> stdDevs) {
}
}