-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathset.ts
125 lines (114 loc) · 3.76 KB
/
set.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import { _internal_get_array_context } from "@zarrita/core";
import type { Array, Chunk, DataType, Scalar, TypedArray } from "@zarrita/core";
import type { Mutable } from "@zarrita/storage";
import { BasicIndexer, type IndexerProjection } from "./indexer.js";
import type {
Indices,
Prepare,
SetFromChunk,
SetOptions,
SetScalar,
Slice,
} from "./types.js";
import { create_queue } from "./util.js";
/**
 * Return a copy of an indexer projection with its `from` and `to` sides
 * swapped.
 *
 * In `set`, `from` is the chunk-local side of a projection (it is what is
 * passed to `set_scalar` as the chunk selection), so the flipped projection
 * has the chunk-local side under `to`. The flipped projections are what
 * `set_from_chunk` receives.
 *
 * NOTE(review): both arms of the conditional build the same object. The
 * `m.to == null` test presumably exists so TypeScript narrows `m` separately
 * in each arm, giving a per-variant union result type instead of one object
 * with widened fields. Confirm before collapsing it.
 */
function flip_indexer_projection(m: IndexerProjection) {
	return m.to == null
		? { from: m.to, to: m.from }
		: { from: m.to, to: m.from };
}
/**
 * Write `value` into the region of `arr` described by `selection`.
 *
 * Every chunk that overlaps the selection is visited once. A chunk that the
 * selection covers completely is built from a fresh buffer without reading
 * the stored chunk. A partially covered chunk is read, modified in memory,
 * and written back whole. Either way, the resulting chunk is encoded with
 * the array's codec and stored under the key derived from its chunk
 * coordinates.
 *
 * @param arr - destination array; its store must be mutable.
 * @param selection - per-dimension selection (integer, slice, or null), or
 *   null; passed unchanged to `BasicIndexer`.
 * @param value - either a scalar, written to every selected element, or a
 *   chunk-like object whose data is copied in. The two cases are told apart
 *   with `typeof value === "object"`.
 * @param opts - `opts.create_queue`, when provided, supplies the task queue
 *   used to schedule per-chunk work; otherwise the default `create_queue()`
 *   is used.
 * @param setter - dtype-specific helpers. `prepare` is called with raw chunk
 *   data plus shape and stride, and its result is handed to `set_scalar`
 *   (scalar values) or `set_from_chunk` (chunk-like values, using the
 *   flipped projections).
 * @throws Error "Set not supported for sharded arrays." when the array's
 *   context is sharded.
 * @returns a promise that resolves once `queue.onIdle()` resolves.
 *   NOTE(review): whether chunk writes run concurrently, and how a failed
 *   write surfaces, depends on the queue implementation (not visible here).
 */
export async function set<Dtype extends DataType, Arr extends Chunk<Dtype>>(
arr: Array<Dtype, Mutable>,
selection: (number | Slice | null)[] | null,
value: Scalar<Dtype> | Arr,
opts: SetOptions,
setter: {
prepare: Prepare<Dtype, Arr>;
set_scalar: SetScalar<Dtype, Arr>;
set_from_chunk: SetFromChunk<Dtype, Arr>;
},
) {
const context = _internal_get_array_context(arr);
// Writing into sharded arrays is not implemented; reject before any chunk
// is touched.
if (context.kind === "sharded") {
throw new Error("Set not supported for sharded arrays.");
}
const indexer = new BasicIndexer({
selection,
shape: arr.shape,
chunk_shape: arr.chunks,
});
// We iterate over all chunks which overlap the selection and thus contain data
// that needs to be replaced. Each such chunk is handled by one task added to
// the queue, which extracts the necessary data from `value` and stores it into
// the chunk array.
// Number of elements in one full chunk (product of the chunk dimensions).
const chunk_size = arr.chunks.reduce((a, b) => a * b, 1);
const queue = opts.create_queue ? opts.create_queue() : create_queue();
// N.B., it is an important optimisation that we only visit chunks which overlap
// the selection. This minimises the number of iterations in the main for loop.
for (const { chunk_coords, mapping } of indexer) {
// The `from` side of each projection: the selection inside this chunk.
const chunk_selection = mapping.map((i) => i.from);
// The same projections with `from`/`to` swapped, as passed to set_from_chunk.
const flipped = mapping.map(flip_indexer_projection);
queue.add(async () => {
// obtain key for chunk storage
const chunk_path = arr.resolve(
context.encode_chunk_key(chunk_coords),
).path;
let chunk_data: TypedArray<Dtype>;
// A per-task copy of the chunk shape and the strides the context computes
// for it. Both go to `prepare` (as further copies) and to the codec.
const chunk_shape = arr.chunks.slice();
const chunk_stride = context.get_strides(chunk_shape);
if (is_total_slice(chunk_selection, chunk_shape)) {
// totally replace
chunk_data = new context.TypedArray(chunk_size);
// optimization: we are completely replacing the chunk, so no need
// to access the existing chunk data
if (typeof value === "object") {
// `value` is chunk-like: wrap the fresh buffer and copy the selected
// data from `value` into it.
const chunk = setter.prepare(
chunk_data,
chunk_shape.slice(),
chunk_stride.slice(),
);
// @ts-expect-error - Value is not a scalar
setter.set_from_chunk(chunk, value, flipped);
} else {
// Scalar value and the selection covers the whole chunk: fill every element.
// @ts-expect-error - Value is a scalar
chunk_data.fill(value);
}
} else {
// partially replace the contents of this chunk
// Read the stored chunk first so elements outside the selection are kept
// when the whole chunk is written back below.
chunk_data = await arr.getChunk(chunk_coords).then(({ data }) => data);
const chunk = setter.prepare(
chunk_data,
chunk_shape.slice(),
chunk_stride.slice(),
);
// Modify chunk data
if (typeof value === "object") {
// @ts-expect-error - Value is not a scalar
setter.set_from_chunk(chunk, value, flipped);
} else {
// Scalar value: write it only into this chunk's selection.
setter.set_scalar(chunk, chunk_selection, value);
}
}
// Encode the new or modified chunk with the array's codec and write the
// result to the store at `chunk_path`.
await arr.store.set(
chunk_path,
await context.codec.encode({
data: chunk_data,
shape: chunk_shape,
stride: chunk_stride,
}),
);
});
}
// Do not return until the queue reports it is idle.
await queue.onIdle();
}
/**
 * Report whether a per-chunk selection covers every element of the chunk.
 *
 * The answer is true only when every entry is a `[start, stop, step]` triple
 * with unit step whose extent `stop - start` equals the chunk's size along
 * that dimension. Any integer entry yields false, even on a size-1 dimension.
 * The caller then takes the read-modify-write path, which is still correct,
 * just slower.
 *
 * @param selection - per-dimension chunk selection (integers or index triples).
 * @param shape - chunk shape to compare against.
 * @returns true when the selection spans the whole chunk. The predicate also
 *   narrows `selection` to `Indices[]`. An empty selection returns true.
 */
function is_total_slice(
	selection: (number | Indices)[],
	shape: readonly number[],
): selection is Indices[] {
	// A dimension "falls short" when it is an integer pick, a strided slice,
	// or a unit-step slice narrower than the chunk along that axis.
	const falls_short = (entry: number | Indices, dim: number): boolean => {
		if (typeof entry === "number") return true;
		const [start, stop, step] = entry;
		return step !== 1 || stop - start !== shape[dim];
	};
	return !selection.some(falls_short);
}