-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrefine_dataset
52 lines (29 loc) · 1.18 KB
/
refine_dataset
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Nov 5 13:02:18 2022
@author: fa19
"""
import os

import numpy as np
dataset = 'scan_age'


def rows_with_ids(table, ids):
    """Return the rows of *table* whose first-column value is in *ids*.

    Parameters
    ----------
    table : 2-D numpy array; column 0 holds the row identifier.
    ids : iterable of hashable identifiers to keep.

    Returns
    -------
    A new array with the matching rows, in their original order. The
    result has zero rows (same number of columns) when nothing matches.
    """
    wanted = set(ids)  # build once so each membership test is O(1)
    mask = np.array([row[0] in wanted for row in table], dtype=bool)
    return table[mask]


def close_age_rows(table, max_gap=1):
    """Return the rows of *table* where column 1 minus column 2 < *max_gap*.

    NOTE(review): presumably column 1 is the age at scan and column 2 the
    age at birth, so this keeps subjects scanned shortly after birth --
    confirm against the code that builds ``data/full/full.npy``.
    """
    # Evaluated row by row (not vectorised) because the table may have
    # object dtype; this mirrors the scalar comparison on each row.
    mask = np.array([bool(row[1] - row[2] < max_gap) for row in table],
                    dtype=bool)
    return table[mask]


def refine_splits(everything, full, train, val, test, max_gap=1):
    """Rebuild the full/train/validation/test splits from *everything*.

    Rows of *everything* are kept when their id (column 0) appears in
    *full* and they pass :func:`close_age_rows`. Each split is then the
    subset of those rows whose id appears in the corresponding input
    split. Only the first two columns of every result are retained.

    Parameters
    ----------
    everything : 2-D array with at least three columns (id in column 0,
        the two compared values in columns 1 and 2).
    full, train, val, test : 2-D arrays whose column 0 holds the ids that
        belong to each existing split.
    max_gap : exclusive upper bound on ``column 1 - column 2``.

    Returns
    -------
    dict mapping ``'full'``, ``'train'``, ``'validation'`` and ``'test'``
    (in that order) to the refined two-column arrays.
    """
    refined = close_age_rows(rows_with_ids(everything, full[:, 0]), max_gap)
    return {
        'full': refined[:, :2],
        'train': rows_with_ids(refined, train[:, 0])[:, :2],
        'validation': rows_with_ids(refined, val[:, 0])[:, :2],
        'test': rows_with_ids(refined, test[:, 0])[:, :2],
    }


def main():
    """Load the existing splits, refine them and save the results."""
    # SECURITY: allow_pickle=True executes arbitrary code embedded in a
    # malicious .npy file -- only load files from a trusted source.
    full = np.load(f'data/{dataset}/full.npy', allow_pickle=True)
    train = np.load(f'data/{dataset}/train.npy', allow_pickle=True)
    val = np.load(f'data/{dataset}/validation.npy', allow_pickle=True)
    test = np.load(f'data/{dataset}/test.npy', allow_pickle=True)
    everything = np.load('data/full/full.npy', allow_pickle=True)

    refined = refine_splits(everything, full, train, val, test)

    # The output directory is fixed and does not follow `dataset`.
    out_dir = 'data/scan_age_equal'
    # np.save does not create missing directories; make sure it exists so
    # the script does not fail only after all the filtering is done.
    os.makedirs(out_dir, exist_ok=True)
    for split_name, rows in refined.items():
        np.save(f'{out_dir}/{split_name}.npy', rows)


if __name__ == '__main__':
    main()