# data_loaders.jl
export process_mnist,
    sampled_mnist,
    twenty_datasets,
    twenty_dataset_names
using CSV
using Pkg.Artifacts
using DataFrames
#####################
# Constants
#####################
const twenty_dataset_names = [
    "accidents", "ad", "baudio", "bbc", "bnetflix", "book", "c20ng", "cr52", "cwebkb",
    "dna", "jester", "kdd", "kosarek", "msnbc", "msweb", "nltcs", "plants", "pumsb_star", "tmovie", "tretail",
    "binarized_mnist"
];
#####################
# Data loaders
#####################
"""
Processes the mnist dataset using the MNIST object from MLDataSets package
`MLDS_MNIST` = the MNIST from MLDataSets
`labeled` = whether to return the lables
"""
function process_mnist(MLDS_MNIST, labeled = false)
    # transposing makes slicing by variable much faster
    # need to take a copy to physically move the data around
    train_x = collect(Float32, transpose(reshape(MLDS_MNIST.traintensor(), 28*28, :)))
    test_x = collect(Float32, transpose(reshape(MLDS_MNIST.testtensor(), 28*28, :)))
    train = DataFrame(train_x)
    valid = nothing # MLDatasets does not provide a separate validation split for MNIST
    test = DataFrame(test_x)
    if labeled
        train_y::Vector{UInt8} = MLDS_MNIST.trainlabels()
        test_y::Vector{UInt8} = MLDS_MNIST.testlabels()
        train.y = train_y
        test.y = test_y
    end
    return train, valid, test
end
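# A minimal usage sketch (not part of the original file), assuming the older
# MLDatasets API in which the `MNIST` submodule itself exposes `traintensor`,
# `testtensor`, `trainlabels`, and `testlabels`, so the submodule can be passed
# directly as the `MLDS_MNIST` argument:
#
#   using MLDatasets
#   train, valid, test = process_mnist(MLDatasets.MNIST, true)
#   size(train)  # expected (60000, 785): 784 pixel columns plus the `y` label column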
sampled_mnist() = twenty_datasets("binarized_mnist")
"""
train, valid, test = twenty_datasets(name)
Load a given dataset from the density estimation datasets. Automatically downloads the files as julia Artifacts.
See https://github.com/UCLA-StarAI/Density-Estimation-Datasets for a list of avaialble datasets.
"""
function twenty_datasets(name)
    @assert in(name, twenty_dataset_names)
    data_dir = artifact"density_estimation_datasets"
    function load(type)
        dataframe = CSV.read(data_dir*"/Density-Estimation-Datasets-1.0.1/datasets/$name/$name.$type.data", DataFrame;
                             header=false, truestrings=["1"], falsestrings=["0"], type=Bool, strict=true)
        # make sure the data is backed by a `BitArray`
        DataFrame(BitArray(Base.convert(Matrix{Bool}, dataframe)))
    end
    train = load("train")
    valid = load("valid")
    test = load("test")
    return train, valid, test
end
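# Example usage (a sketch, not part of the original file): load one of the binary
# density estimation benchmarks by name. The first call downloads the artifact;
# later calls reuse the cached copy.
#
#   train, valid, test = twenty_datasets("nltcs")
#   size(train)      # DataFrame of Bool columns backed by a BitArray
#   sampled_mnist()  # convenience wrapper for the "binarized_mnist" dataset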