-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_benchmark_trials.m
76 lines (57 loc) · 1.92 KB
/
generate_benchmark_trials.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
% Benchmark trial generation script.
%
% For every dataset listed below, and for every demo type reported by
% load_dataset_param, random goal/via-point trials are produced through
% generate_trials (local function at the end of this file).
%
% Author
% Sipu Ruan, 2023
close all; clear; clc;
add_paths()

dataset_name = {'panda_arm', 'lasa_handwriting/pose_data'};
n_trial = 50;

% Deviation settings forwarded to generate_trials for every demo type
scale = struct('mean', zeros(6,1), 'covariance', 1e-5);

% Factor for extrapolation (with 0 the terms below reduce to their base values)
lambda_ex = 0;
% lambda_ex = 10;

for j = 1:numel(dataset_name)
    current_dataset = dataset_name{j};
    demo_types = load_dataset_param(current_dataset);

    % Via point deviation based on dataset name.
    % NOTE(review): scale.mean is not reset between datasets, so the
    % entries 1:3 written for 'panda_arm' are still in effect while
    % 'lasa_handwriting/pose_data' is processed -- confirm this is intended.
    if strcmp(current_dataset, 'panda_arm')
        scale.mean(1:3) = 1e-4 * (lambda_ex + 1) * ones(3,1);
        scale.mean(4:6) = lambda_ex + 1e-3 * rand(3,1);
    elseif strcmp(current_dataset, 'lasa_handwriting/pose_data')
        % Only two random components; the last entry is fixed to zero
        scale.mean(4:6) = [lambda_ex + 1e-3 * rand(2,1); 0];
    end

    % One batch of trials per demo type, plotting disabled
    for i = 1:length(demo_types)
        generate_trials(current_dataset, demo_types{i}, n_trial, scale, false);
    end
end
%% Function for generating benchmark trials
function generate_trials(dataset_name, demo_type, n_trial, scale, isplot)
% GENERATE_TRIALS Generate random goal/via-point trials for one demo type.
%
%   generate_trials(dataset_name, demo_type, n_trial, scale, isplot)
%
%   Loads the demonstrations stored as *.json files under
%   ../data/<dataset_name>/<demo_type>/, picks one of them at random and
%   passes it to generate_random_trials together with the result folder
%   ../result/benchmark/<dataset_name>/<demo_type>/ (created if missing).
%
%   Inputs
%     dataset_name : dataset name (char), used in the data/result paths
%     demo_type    : demo type name (char), used in the data/result paths
%     n_trial      : number of trials, i.e. number of rows of t_via
%     scale        : forwarded unchanged to generate_random_trials
%     isplot       : if true, plot the positions of the generated via poses
%
%   Errors
%     generate_trials:mkdirFailed  - the result folder cannot be created
%     generate_trials:noDemo       - no demonstration was parsed

% Clears the command window so only the current dataset/demo is shown
clc;

disp(['Dataset: ', dataset_name])
disp(['Demo type: ', demo_type])

data_folder = strcat("../data/", dataset_name, "/", demo_type, "/");
result_folder = strcat("../result/benchmark/", dataset_name, "/", demo_type, "/");

% Requesting the outputs keeps mkdir quiet when the folder already exists
% (re-runs), while a genuine failure is still reported as an error.
[mk_ok, mk_msg] = mkdir(result_folder);
if ~mk_ok
    error('generate_trials:mkdirFailed', ...
        'Cannot create result folder "%s": %s', result_folder, mk_msg);
end

% Load and parse demo data
argin.n_step = 50;
argin.data_folder = data_folder;
argin.group_name = 'SE';

filenames = dir(strcat(argin.data_folder, "*.json"));
g_demo = parse_demo_trajectory(filenames, argin);

% Without this guard an empty demo set leads to id_demo = 0 and an
% obscure indexing error below.
if isempty(g_demo)
    error('generate_trials:noDemo', ...
        'No demonstrations parsed from "%s" (no *.json files found?).', ...
        data_folder);
end

% Generate random goal/via points
% rand lies in the open interval (0,1), so id_demo is in 1..length(g_demo)
id_demo = ceil(rand*length(g_demo));

% One row per trial: first column fixed to 1, second column random in (0,1)
% (presumably the time parameters of goal and via point -- confirm against
% generate_random_trials)
t_via = [ones(n_trial, 1), rand(n_trial, 1)];
% t_via = [zeros(n_trial, 1), ones(n_trial, 1)];

trials = generate_random_trials(g_demo{id_demo}, t_via, scale, result_folder);

if isplot
    % Plot trials: entries (1:3, 4) of every slice along the third
    % dimension of each g_via array are drawn as one 3D point
    figure; hold on; axis equal;

    for i = 1:length(trials.g_via)
        g_via = trials.g_via{i};
        plot3(squeeze(g_via(1,4,:)), squeeze(g_via(2,4,:)), ...
            squeeze(g_via(3,4,:)), 'b*');
    end
end
end