-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathfield_matrix_solver.jl
489 lines (432 loc) · 20.1 KB
/
field_matrix_solver.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
"""
FieldMatrixSolverAlgorithm
Description of how to solve an equation of the form `A * x = b` for `x`, where
`A` is a `FieldMatrix` and where `x` and `b` are both `FieldVector`s. Different
algorithms can be nested inside each other, enabling the construction of
specialized linear solvers that fully utilize the sparsity pattern of `A`.
# Interface
Every subtype of `FieldMatrixSolverAlgorithm` must implement methods for the
following functions:
- [`field_matrix_solver_cache`](@ref)
- [`check_field_matrix_solver`](@ref)
- [`run_field_matrix_solver!`](@ref)
"""
abstract type FieldMatrixSolverAlgorithm end
"""
field_matrix_solver_cache(alg, A, b)
Allocates the cache required by the `FieldMatrixSolverAlgorithm` `alg` to solve
the equation `A * x = b`.
"""
function field_matrix_solver_cache end
"""
check_field_matrix_solver(alg, cache, A, b)
Checks that the sparsity structure of `A` is supported by the
`FieldMatrixSolverAlgorithm` `alg`, and that `A` is compatible with `b` in the
equation `A * x = b`.
"""
function check_field_matrix_solver end
"""
run_field_matrix_solver!(alg, cache, x, A, b)
Sets `x` to the value that solves the equation `A * x = b` using the
`FieldMatrixSolverAlgorithm` `alg`.
"""
function run_field_matrix_solver! end
"""
FieldMatrixSolver(alg, A, b)
Combination of a `FieldMatrixSolverAlgorithm` `alg` and the cache it requires to
solve the equation `A * x = b` for `x`. The values of `A` and `b` that get
passed to this constructor should be `similar` to the ones that get passed to
`field_matrix_solve!` in order to ensure that the cache gets allocated
correctly.
"""
struct FieldMatrixSolver{A <: FieldMatrixSolverAlgorithm, C}
alg::A
cache::C
end
function FieldMatrixSolver(
    alg::FieldMatrixSolverAlgorithm,
    A::FieldMatrix,
    b::Fields.FieldVector,
)
    # Re-index A with the name tree of b's view so that the matrix and vector
    # address their entries consistently, then allocate and validate the cache.
    vector_view = field_vector_view(b)
    matrix = replace_name_tree(A, keys(vector_view).name_tree)
    solver_cache = field_matrix_solver_cache(alg, matrix, vector_view)
    check_field_matrix_solver(alg, solver_cache, matrix, vector_view)
    return FieldMatrixSolver(alg, solver_cache)
end
"""
field_matrix_solve!(solver, x, A, b)
Solves the equation `A * x = b` for `x` using the `FieldMatrixSolver` `solver`.
"""
NVTX.@annotate function field_matrix_solve!(
solver::FieldMatrixSolver,
x::Fields.FieldVector,
A::FieldMatrix,
b::Fields.FieldVector,
)
(; alg, cache) = solver
x_view = field_vector_view(x)
b_view = field_vector_view(b)
keys(x_view) == keys(b_view) || error(
"The linear system cannot be solved because x and b have incompatible \
keys: $(set_string(keys(x_view))) vs. $(set_string(keys(b_view)))",
)
A_with_tree = replace_name_tree(A, keys(b_view).name_tree)
check_field_matrix_solver(alg, cache, A_with_tree, b_view)
run_field_matrix_solver!(alg, cache, x_view, A_with_tree, b_view)
return x
end
function check_block_diagonal_matrix_has_no_missing_blocks(A, b)
    # Find the rows of b for which A has no corresponding diagonal block.
    diagonal_row_keys = matrix_row_keys(matrix_diagonal_keys(keys(A)))
    rows_with_missing_blocks = setdiff(keys(b), diagonal_row_keys)
    missing_keys = corresponding_matrix_keys(rows_with_missing_blocks)
    # The missing keys correspond to zeros, and det(A) = 0 when A is a block
    # diagonal matrix with zeros along its diagonal. We can only solve A * x = b
    # if det(A) != 0, so we throw an error whenever there are missing keys.
    # Although it might still be the case that det(A) = 0 even if there are no
    # missing keys, this cannot be inferred during compilation.
    isempty(missing_keys) ||
        error("The linear system cannot be solved because A does not have any \
              entries at the following keys: $(set_string(missing_keys))")
end
function partition_blocks(names₁, A, b = nothing, x = nothing)
    # Split the keys into the user-specified set (₁) and its complement (₂).
    keys₁ = FieldVectorKeys(names₁, keys(A).name_tree)
    keys₂ = set_complement(keys₁)
    # Extract the four sub-matrices of A induced by the partition, in the
    # order (A₁₁, A₁₂, A₂₁, A₂₂).
    sub_matrices = map(
        key_pair -> A[cartesian_product(key_pair...)],
        ((keys₁, keys₁), (keys₁, keys₂), (keys₂, keys₁), (keys₂, keys₂)),
    )
    b_blocks = isnothing(b) ? () : (b[keys₁], b[keys₂])
    x_blocks = isnothing(x) ? () : (x[keys₁], x[keys₂])
    return (sub_matrices..., b_blocks..., x_blocks...)
end
function similar_to_x(A, b)
    # Allocate one entry per diagonal block of A, with the element type that
    # results from solving that block's equation against the matching entry of b.
    row_keys = matrix_row_keys(keys(A))
    entries = map(row_keys) do name
        similar(b[name], x_eltype(A[name, name], b[name]))
    end
    return FieldNameDict(row_keys, entries)
end
################################################################################
# Lazy (i.e., as matrix-free as possible) operations for FieldMatrix and
# analogues of FieldMatrix
# Lazy elementwise analogues of inv, +, and -: each returns an unmaterialized
# Broadcasted object instead of evaluating the operation immediately, so that
# several operations can later be fused into a single `Base.materialize!` call.
function lazy_inv(A)
    return Base.Broadcast.broadcasted(inv, A)
end
function lazy_add(As...)
    return Base.Broadcast.broadcasted(+, As...)
end
function lazy_sub(As...)
    return Base.Broadcast.broadcasted(-, As...)
end
"""
lazy_mul(A, args...)
Constructs a lazy `FieldMatrix` that represents the product `@. *(A, args...)`.
This involves regular broadcasting when `A` is a `FieldMatrix`, but it has more
complex behavior for other objects like the [`LazySchurComplement`](@ref).
"""
lazy_mul(A, args...) = Base.Broadcast.broadcasted(*, A, args...)
"""
LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂, [alg₁, cache₁, A₁₂_x₂, invA₁₁_A₁₂_x₂])
An analogue of a `FieldMatrix` that represents the Schur complement of `A₁₁` in
`A`, `A₂₂ - A₂₁ * inv(A₁₁) * A₁₂`. Since `inv(A₁₁)` will generally be a dense
matrix, it would not be efficient to directly compute the Schur complement. So,
this object only supports the "lazy" functions [`lazy_mul`](@ref), which allows
it to be multiplied by the vector `x₂`, and [`lazy_preconditioner`](@ref), which
allows it to be approximated with a `FieldMatrix`.
The values `alg₁`, `cache₁`, `A₁₂_x₂`, and `invA₁₁_A₁₂_x₂` need to be specified
in order for `lazy_mul` to be able to compute `inv(A₁₁) * A₁₂ * x₂`. When a
`LazySchurComplement` is not passed to `lazy_mul`, these values can be omitted.
"""
struct LazySchurComplement{M11, M12, M21, M22, A1, C1, V1, V2}
A₁₁::M11
A₁₂::M12
A₂₁::M21
A₂₂::M22
alg₁::A1
cache₁::C1
A₁₂_x₂::V1
invA₁₁_A₁₂_x₂::V2
end
LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂) =
LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂, nothing, nothing, nothing, nothing)
# Computes (A₂₂ - A₂₁ * inv(A₁₁) * A₁₂) * x₂ lazily, materializing only the two
# intermediate vectors A₁₂_x₂ and invA₁₁_A₁₂_x₂.
NVTX.@annotate function lazy_mul(A₂₂′::LazySchurComplement, x₂)
    (; A₁₁, A₁₂, A₂₁, A₂₂, alg₁, cache₁, A₁₂_x₂, invA₁₁_A₁₂_x₂) = A₂₂′
    # Some rows of A₁₂_x₂ may have no corresponding entries in A₁₂; adding
    # zero(A₁₂_x₂[zero_rows]) appears to ensure that every entry of A₁₂_x₂ is
    # still written by the broadcast — TODO confirm against FieldNameDict
    # broadcasting semantics.
    zero_rows = setdiff(keys(A₁₂_x₂), matrix_row_keys(keys(A₁₂)))
    @. A₁₂_x₂ = A₁₂ * x₂ + zero(A₁₂_x₂[zero_rows])
    # Apply inv(A₁₁) by solving A₁₁ * invA₁₁_A₁₂_x₂ = A₁₂_x₂ with alg₁.
    run_field_matrix_solver!(alg₁, cache₁, invA₁₁_A₁₂_x₂, A₁₁, A₁₂_x₂)
    # Return the unmaterialized difference A₂₂ * x₂ - A₂₁ * invA₁₁_A₁₂_x₂.
    return lazy_sub(lazy_mul(A₂₂, x₂), lazy_mul(A₂₁, invA₁₁_A₁₂_x₂))
end
"""
LazyFieldMatrixSolverAlgorithm
A `FieldMatrixSolverAlgorithm` that does not require `A` to be a `FieldMatrix`,
i.e., a "matrix-free" algorithm. Internally, a `FieldMatrixSolverAlgorithm`
(for example, [`SchurComplementReductionSolve`](@ref)) might run a
`LazyFieldMatrixSolverAlgorithm` on a "lazy" representation of a `FieldMatrix`
(like a [`LazySchurComplement`](@ref)).
The only operations used by a `LazyFieldMatrixSolverAlgorithm` that depend on
`A` are [`lazy_mul`](@ref) and, when required, [`lazy_preconditioner`](@ref).
These and other lazy operations are used to minimize the number of calls to
`Base.materialize!`, since each call comes with a small performance penalty.
"""
abstract type LazyFieldMatrixSolverAlgorithm <: FieldMatrixSolverAlgorithm end
################################################################################
"""
BlockDiagonalSolve()
A `FieldMatrixSolverAlgorithm` for a block diagonal matrix:
```math
A = \\begin{bmatrix}
A_{11} & \\mathbf{0} & \\mathbf{0} & \\cdots & \\mathbf{0} \\\\
\\mathbf{0} & A_{22} & \\mathbf{0} & \\cdots & \\mathbf{0} \\\\
\\mathbf{0} & \\mathbf{0} & A_{33} & \\cdots & \\mathbf{0} \\\\
\\vdots & \\vdots & \\vdots & \\ddots & \\vdots \\\\
\\mathbf{0} & \\mathbf{0} & \\mathbf{0} & \\cdots & A_{NN}
\\end{bmatrix}
```
This algorithm solves the `N` block equations `Aₙₙ * xₙ = bₙ` in sequence (though
we might want to parallelize it in the future).
If `Aₙₙ` is a diagonal matrix, the equation `Aₙₙ * xₙ = bₙ` is solved by making a
single pass over the data, setting each `xₙ[i]` to `inv(Aₙₙ[i, i]) * bₙ[i]`.
Otherwise, the equation `Aₙₙ * xₙ = bₙ` is solved using Gaussian elimination
(without pivoting), which makes two passes over the data. This is currently only
implemented for tri-diagonal and penta-diagonal matrices `Aₙₙ`. In Gaussian
elimination, `Aₙₙ` is effectively factorized into the product `Lₙ * Dₙ * Uₙ`,
where `Dₙ` is a diagonal matrix, and where `Lₙ` and `Uₙ` are unit lower and upper
triangular matrices, respectively. The first pass multiplies both sides of the
equation by `inv(Lₙ * Dₙ)`, replacing `Aₙₙ` with `Uₙ` and `bₙ` with `Uₙxₙ`, which
is referred to as putting `Aₙₙ` into "reduced row echelon form". The second pass
solves `Uₙ * xₙ = Uₙxₙ` for `xₙ` with a unit upper triangular matrix solver, which
is referred to as "back substitution". These operations can become numerically
unstable when `Aₙₙ` has entries with large disparities in magnitude, but avoiding
this would require swapping the rows of `Aₙₙ` (i.e., replacing `Dₙ` with a
partial pivoting matrix).
"""
struct BlockDiagonalSolve <: FieldMatrixSolverAlgorithm end
function field_matrix_solver_cache(::BlockDiagonalSolve, A, b)
    # Allocate one single-field solver cache per diagonal block of A.
    row_keys = matrix_row_keys(keys(A))
    solver_caches = map(row_keys) do name
        single_field_solver_cache(A[name, name], b[name])
    end
    return FieldNameDict(row_keys, solver_caches)
end
function check_field_matrix_solver(::BlockDiagonalSolve, _, A, b)
    # A must be block diagonal, with a solvable block at every diagonal key.
    check_block_diagonal_matrix(
        A,
        "BlockDiagonalSolve cannot be used because A",
    )
    check_block_diagonal_matrix_has_no_missing_blocks(A, b)
    foreach(
        name -> check_single_field_solver(A[name, name], b[name]),
        matrix_row_keys(keys(A)),
    )
end
# Predicates that indicate whether inverting a block is cheap enough that
# launching a separate kernel for it can beat one fused multi-block kernel.
cheap_inv(::Any) = false
cheap_inv(::UniformScaling) = true
cheap_inv(A::ColumnwiseBandMatrixField) = eltype(A) <: DiagonalMatrixRow
NVTX.@annotate function run_field_matrix_solver!(
    ::BlockDiagonalSolve,
    cache,
    x,
    A,
    b,
)
    names = matrix_row_keys(keys(A))
    # Performance optimization: calling
    # `foreach(name -> single_field_solve!(cache[name], x[name], A[name, name], b[name]), names)`
    # is always correct, but it may launch many GPU kernels (one per block).
    # `multiple_field_solve!` fuses these kernels into one, but it launches
    # threads horizontally and loops vertically (which is slow) to perform the
    # solve. When a vertical loop is not needed (e.g., for UniformScaling
    # blocks), launching several cheap kernels can beat launching one slower
    # fused kernel, so we first check for types that lead to fast per-block
    # kernels.
    case1 = length(names) == 1 # a single block, so there is nothing to fuse
    case2 = all(name -> cheap_inv(A[name, name]), names.values)
    case3 = any(name -> cheap_inv(A[name, name]), names.values)
    # TODO: remove case3 and implement _single_field_solve_diag_matrix_row!
    # in multiple_field_solve!
    if case1 || case2 || case3
        foreach(names) do name
            single_field_solve!(cache[name], x[name], A[name, name], b[name])
        end
    else
        multiple_field_solve!(cache, x, A, b)
    end
    return nothing
end
"""
BlockLowerTriangularSolve(names₁...; [alg₁], [alg₂])
A `FieldMatrixSolverAlgorithm` for a 2×2 block lower triangular matrix:
```math
A = \\begin{bmatrix} A_{11} & \\mathbf{0} \\\\ A_{21} & A_{22} \\end{bmatrix}
```
The `FieldName`s in `names₁` correspond to the subscript `₁`, while all other
`FieldName`s correspond to the subscript `₂`. This algorithm has 2 steps:
1. Solve `A₁₁ * x₁ = b₁` for `x₁` using the algorithm `alg₁`, which is set to a
[`BlockDiagonalSolve`](@ref) by default.
2. Solve `A₂₂ * x₂ = b₂ - A₂₁ * x₁` for `x₂` using the algorithm `alg₂`, which
is also set to a `BlockDiagonalSolve` by default.
"""
struct BlockLowerTriangularSolve{
N <: NTuple{<:Any, FieldName},
A1 <: FieldMatrixSolverAlgorithm,
A2 <: FieldMatrixSolverAlgorithm,
} <: FieldMatrixSolverAlgorithm
names₁::N
alg₁::A1
alg₂::A2
end
BlockLowerTriangularSolve(
names₁...;
alg₁ = BlockDiagonalSolve(),
alg₂ = BlockDiagonalSolve(),
) = BlockLowerTriangularSolve(names₁, alg₁, alg₂)
function field_matrix_solver_cache(alg::BlockLowerTriangularSolve, A, b)
    # Only the diagonal blocks of A and the two halves of b are needed here.
    (A₁₁, _, _, A₂₂, b₁, b₂) = partition_blocks(alg.names₁, A, b)
    b₂′ = similar(b₂) # scratch for the modified right-hand side b₂ - A₂₁ * x₁
    cache₁ = field_matrix_solver_cache(alg.alg₁, A₁₁, b₁)
    cache₂ = field_matrix_solver_cache(alg.alg₂, A₂₂, b₂′)
    return (; cache₁, b₂′, cache₂)
end
function check_field_matrix_solver(alg::BlockLowerTriangularSolve, cache, A, b)
    (A₁₁, A₁₂, _, A₂₂, b₁, _) = partition_blocks(alg.names₁, A, b)
    # A lower triangular solve is only valid when the upper triangular block
    # A₁₂ is empty.
    if !isempty(keys(A₁₂))
        error(
            "BlockLowerTriangularSolve cannot be used because A has entries at the \
            following upper triangular keys: $(set_string(keys(A₁₂)))",
        )
    end
    check_field_matrix_solver(alg.alg₁, cache.cache₁, A₁₁, b₁)
    check_field_matrix_solver(alg.alg₂, cache.cache₂, A₂₂, cache.b₂′)
end
NVTX.@annotate function run_field_matrix_solver!(
    alg::BlockLowerTriangularSolve,
    cache,
    x,
    A,
    b,
)
    (A₁₁, _, A₂₁, A₂₂, b₁, b₂, x₁, x₂) = partition_blocks(alg.names₁, A, b, x)
    # Step 1: solve the upper-left block equation A₁₁ * x₁ = b₁ for x₁.
    run_field_matrix_solver!(alg.alg₁, cache.cache₁, x₁, A₁₁, b₁)
    # Step 2: substitute x₁ into the second block row and solve for x₂.
    cache.b₂′ .= b₂ .- A₂₁ .* x₁
    run_field_matrix_solver!(alg.alg₂, cache.cache₂, x₂, A₂₂, cache.b₂′)
end
"""
BlockArrowheadSolve(names₁...; [alg₂])
A `FieldMatrixSolverAlgorithm` for a 2×2 block arrowhead matrix:
```math
A = \\begin{bmatrix} A_{11} & A_{12} \\\\ A_{21} & A_{22} \\end{bmatrix}, \\quad
\\text{where } A_{11} \\text{ is a diagonal matrix}
```
The `FieldName`s in `names₁` correspond to the subscript `₁`, while all other
`FieldName`s correspond to the subscript `₂`. This algorithm has only 1 step:
1. Solve `(A₂₂ - A₂₁ * inv(A₁₁) * A₁₂) * x₂ = b₂ - A₂₁ * inv(A₁₁) * b₁` for `x₂`
using the algorithm `alg₂`, which is set to a [`BlockDiagonalSolve`](@ref) by
default, and set `x₁` to `inv(A₁₁) * (b₁ - A₁₂ * x₂)`.
Since `A₁₁` is a diagonal matrix, `inv(A₁₁)` is easy to compute, which means
that the Schur complement of `A₁₁` in `A`, `A₂₂ - A₂₁ * inv(A₁₁) * A₁₂`, as well
as the vectors `b₂ - A₂₁ * inv(A₁₁) * b₁` and `inv(A₁₁) * (b₁ - A₁₂ * x₂)`, are
also easy to compute.
This algorithm is equivalent to block Gaussian elimination with all operations
inlined into a single step.
"""
struct BlockArrowheadSolve{
N <: NTuple{<:Any, FieldName},
A <: FieldMatrixSolverAlgorithm,
} <: FieldMatrixSolverAlgorithm
names₁::N
alg₂::A
end
BlockArrowheadSolve(names₁...; alg₂ = BlockDiagonalSolve()) =
BlockArrowheadSolve(names₁, alg₂)
function field_matrix_solver_cache(alg::BlockArrowheadSolve, A, b)
    (A₁₁, A₁₂, A₂₁, A₂₂, _, b₂) = partition_blocks(alg.names₁, A, b)
    # Materialize the Schur complement of A₁₁ in A, which is cheap to compute
    # because A₁₁ is a diagonal matrix.
    A₂₂′ = A₂₂ .- A₂₁ .* inv.(A₁₁) .* A₁₂
    b₂′ = similar(b₂) # scratch for the reduced right-hand side
    cache₂ = field_matrix_solver_cache(alg.alg₂, A₂₂′, b₂′)
    return (; A₂₂′, b₂′, cache₂)
end
function check_field_matrix_solver(alg::BlockArrowheadSolve, cache, A, b)
    (A₁₁, _, _, _, b₁, _) = partition_blocks(alg.names₁, A, b)
    # The arrowhead algorithm requires a diagonal A₁₁ block with no missing
    # diagonal entries, so that inv(A₁₁) is well-defined and cheap.
    check_diagonal_matrix(A₁₁, "BlockArrowheadSolve cannot be used because A")
    check_block_diagonal_matrix_has_no_missing_blocks(A₁₁, b₁)
    check_field_matrix_solver(alg.alg₂, cache.cache₂, cache.A₂₂′, cache.b₂′)
end
NVTX.@annotate function run_field_matrix_solver!(
    alg::BlockArrowheadSolve,
    cache,
    x,
    A,
    b,
)
    (A₁₁, A₁₂, A₂₁, A₂₂, b₁, b₂, x₁, x₂) = partition_blocks(alg.names₁, A, b, x)
    # Refresh the cached Schur complement and reduced right-hand side, since
    # the entries of A and b may have changed since the cache was created.
    cache.A₂₂′ .= A₂₂ .- A₂₁ .* inv.(A₁₁) .* A₁₂
    cache.b₂′ .= b₂ .- A₂₁ .* inv.(A₁₁) .* b₁
    # Solve the Schur-complement system for x₂, then back-substitute for x₁.
    run_field_matrix_solver!(alg.alg₂, cache.cache₂, x₂, cache.A₂₂′, cache.b₂′)
    x₁ .= inv.(A₁₁) .* (b₁ .- A₁₂ .* x₂)
end
"""
SchurComplementReductionSolve(names₁...; [alg₁], alg₂)
A `FieldMatrixSolverAlgorithm` for any 2×2 block matrix:
```math
A = \\begin{bmatrix} A_{11} & A_{12} \\\\ A_{21} & A_{22} \\end{bmatrix}
```
The `FieldName`s in `names₁` correspond to the subscript `₁`, while all other
`FieldName`s correspond to the subscript `₂`. This algorithm has 3 steps:
1. Solve `A₁₁ * x₁′ = b₁` for `x₁′` using the algorithm `alg₁`, which is set to
a [`BlockDiagonalSolve`](@ref) by default.
2. Solve `(A₂₂ - A₂₁ * inv(A₁₁) * A₁₂) * x₂ = b₂ - A₂₁ * x₁′` for `x₂`
using the algorithm `alg₂`.
3. Solve `A₁₁ * x₁ = b₁ - A₁₂ * x₂` for `x₁` using the algorithm `alg₁`.
Since `A₁₁` is not necessarily a diagonal matrix, `inv(A₁₁)` will generally be a
dense matrix, which means that the Schur complement of `A₁₁` in `A`,
`A₂₂ - A₂₁ * inv(A₁₁) * A₁₂`, cannot be computed efficiently. So, `alg₂` must be
set to a `LazyFieldMatrixSolverAlgorithm`, which can evaluate the matrix-vector
product `(A₂₂ - A₂₁ * inv(A₁₁) * A₁₂) * x₂` without actually computing the Schur
complement matrix. This involves representing the Schur complement matrix by a
[`LazySchurComplement`](@ref), which uses `alg₁` to invert `A₁₁` when computing
the matrix-vector product.
This algorithm is equivalent to block Gaussian elimination, where steps 1 and 2
put `A` into reduced row echelon form, and step 3 performs back substitution.
For more information on this algorithm, see Section 5 of [Numerical solution of
saddle point problems](@cite Benzi2005).
"""
struct SchurComplementReductionSolve{
N <: NTuple{<:Any, FieldName},
A1 <: FieldMatrixSolverAlgorithm,
A2 <: LazyFieldMatrixSolverAlgorithm,
} <: FieldMatrixSolverAlgorithm
names₁::N
alg₁::A1
alg₂::A2
end
SchurComplementReductionSolve(names₁...; alg₁ = BlockDiagonalSolve(), alg₂) =
SchurComplementReductionSolve(names₁, alg₁, alg₂)
function field_matrix_solver_cache(alg::SchurComplementReductionSolve, A, b)
    (A₁₁, A₁₂, A₂₁, A₂₂, b₁, b₂) = partition_blocks(alg.names₁, A, b)
    # Scratch vectors for the modified right-hand sides of both block rows.
    b₁′ = similar(b₁)
    b₂′ = similar(b₂)
    cache₁ = field_matrix_solver_cache(alg.alg₁, A₁₁, b₁)
    # alg₂ only ever sees the Schur complement through its lazy representation.
    A₂₂′ = LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂)
    cache₂ = field_matrix_solver_cache(alg.alg₂, A₂₂′, b₂′)
    return (; b₁′, cache₁, b₂′, cache₂)
end
function check_field_matrix_solver(
    alg::SchurComplementReductionSolve,
    cache,
    A,
    b,
)
    (A₁₁, A₁₂, A₂₁, A₂₂, b₁, _) = partition_blocks(alg.names₁, A, b)
    check_field_matrix_solver(alg.alg₁, cache.cache₁, A₁₁, b₁)
    # Check alg₂ against the same lazy representation it will be run on.
    lazy_A₂₂′ = LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂)
    check_field_matrix_solver(alg.alg₂, cache.cache₂, lazy_A₂₂′, cache.b₂′)
end
NVTX.@annotate function run_field_matrix_solver!(
    alg::SchurComplementReductionSolve,
    cache,
    x,
    A,
    b,
)
    A₁₁, A₁₂, A₂₁, A₂₂, b₁, b₂, x₁, x₂ = partition_blocks(alg.names₁, A, b, x)
    x₁′ = x₁ # Use x₁ as temporary storage to avoid additional allocations.
    # cache.b₁′ and x₁′ double as the A₁₂_x₂ and invA₁₁_A₁₂_x₂ scratch vectors
    # of the LazySchurComplement. They are only written inside lazy_mul (during
    # step 2), before the final values of b₁′ and x₁ are computed in step 3, so
    # the aliasing is safe.
    schur_complement_args = (alg.alg₁, cache.cache₁, cache.b₁′, x₁′)
    A₂₂′ = LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂, schur_complement_args...)
    # Step 1: Solve A₁₁ * x₁′ = b₁ for x₁′.
    run_field_matrix_solver!(alg.alg₁, cache.cache₁, x₁′, A₁₁, b₁)
    # Step 2: Solve (A₂₂ - A₂₁ * inv(A₁₁) * A₁₂) * x₂ = b₂ - A₂₁ * x₁′ for x₂.
    @. cache.b₂′ = b₂ - A₂₁ * x₁′
    run_field_matrix_solver!(alg.alg₂, cache.cache₂, x₂, A₂₂′, cache.b₂′)
    # Step 3: Solve A₁₁ * x₁ = b₁ - A₁₂ * x₂ for x₁ (back substitution).
    @. cache.b₁′ = b₁ - A₁₂ * x₂
    run_field_matrix_solver!(alg.alg₁, cache.cache₁, x₁, A₁₁, cache.b₁′)
end