-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathget_esnli_examples_for_human_annotation.py
275 lines (263 loc) · 17.6 KB
/
get_esnli_examples_for_human_annotation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
# example:
# --------
# python3 get_esnli_examples_for_human_annotation.py \
# --corpus esnli \
# --corpus-path "../data/corpus/esnli/esnli_test.csv" \
# --output-path "../data/human-corpus/esnli.txt" \
# --nb-examples 2000
import argparse
import random
from pathlib import Path
from spec.dataset.corpora import available_corpora
def select_indexes_stratified(words, targets, sel_nb_examples):
    """Return the examples at a fixed, hard-coded list of corpus positions.

    Despite the name, no stratified sampling is performed: the positions are
    a literal list below (presumably frozen so the same examples are chosen
    for human annotation on every run). Only the first ``sel_nb_examples``
    positions of that list are used; ``None`` (or a value at least as large
    as the list) selects all of them.

    Args:
        words: per-example strings, indexed by corpus position.
        targets: per-example labels, indexed like ``words``.
        sel_nb_examples: how many examples to return, or ``None`` for all.

    Returns:
        A ``(selected_words, selected_targets)`` pair of lists, in the order
        of the hard-coded position list.

    Raises:
        ValueError: if ``sel_nb_examples`` is negative.
        IndexError: if ``words`` or ``targets`` is too short to contain one
            of the selected positions.
    """
    # Fixed corpus positions of the examples chosen for human annotation.
    test_index = [
        1930, 6099, 1624, 4021, 4346, 1316, 1763, 8792, 6287, 2193,
        2475, 5442, 1838, 5801, 4289, 7849, 6371, 5301, 5094, 1976,
        7804, 3900, 8646, 8533, 4232, 5984, 980, 8924, 2989, 7753,
        7214, 7622, 5495, 8815, 8670, 2745, 7554, 6656, 525, 1546,
        4440, 1949, 8684, 8741, 8910, 9007, 7789, 1997, 5390, 7705,
        2631, 9631, 2863, 4014, 2079, 1, 6316, 3294, 4293, 1331, 2601,
        2931, 8352, 5481, 8040, 6869, 7029, 8195, 769, 4268, 3223,
        9696, 7241, 197, 8107, 2590, 2744, 2390, 6150, 8125, 4499,
        4895, 7511, 1466, 3591, 8837, 7686, 5222, 966, 4324, 735,
        1304, 4435, 1165, 7140, 6359, 8032, 4452, 3696, 8082, 9036,
        6238, 412, 3054, 8574, 6806, 7253, 3584, 8225, 38, 9480, 821,
        7062, 2839, 9370, 3665, 707, 6166, 9346, 9586, 4986, 3476,
        6278, 7980, 2060, 2136, 4360, 4658, 1852, 6245, 6563, 1887,
        2967, 2507, 4926, 9335, 638, 552, 3141, 7271, 6542, 4964,
        5475, 3385, 7683, 1239, 4496, 714, 3186, 7694, 5453, 5618,
        1291, 8790, 6764, 7164, 3738, 9339, 2027, 4171, 634, 974,
        2823, 688, 3503, 8403, 9073, 968, 9735, 5311, 3466, 6262,
        6799, 699, 9488, 2849, 1542, 4208, 7483, 5085, 1990, 419,
        8960, 2396, 8253, 4356, 7280, 3266, 2156, 6327, 8808, 3511,
        924, 2511, 2768, 6052, 3725, 6904, 7989, 7079, 4873, 9778,
        329, 4621, 1708, 5724, 8162, 2573, 2026, 8958, 3490, 3972,
        281, 8766, 2408, 1333, 3568, 3506, 1571, 4576, 4495, 137,
        7555, 5401, 7833, 5047, 7266, 7687, 7245, 2878, 943, 8243,
        1646, 7418, 2493, 6777, 7043, 6374, 8666, 7388, 7595, 8552,
        4648, 984, 8407, 4497, 6506, 9089, 6132, 3638, 7509, 7543,
        7124, 5291, 5534, 94, 8365, 3553, 9453, 1202, 7900, 3086,
        9687, 8053, 3005, 7059, 1191, 9268, 7860, 9550, 4177, 5726,
        7052, 5543, 7546, 7141, 440, 9062, 7419, 4178, 3224, 1296,
        4637, 454, 5704, 158, 4218, 4153, 5988, 832, 6891, 8472, 6257,
        5713, 5348, 6530, 6667, 7570, 3030, 3192, 4037, 3712, 9639,
        8524, 9107, 7368, 9257, 3041, 1098, 4566, 824, 6737, 6168,
        9649, 1681, 559, 4849, 2790, 6002, 2832, 2599, 6254, 4492,
        8934, 5631, 6617, 7723, 860, 1243, 4805, 1906, 4887, 1069,
        4330, 1167, 661, 4904, 417, 340, 2785, 6588, 4242, 3496, 9707,
        9810, 827, 8164, 3769, 850, 6652, 98, 2421, 770, 481, 3863,
        5971, 1585, 2443, 8307, 4806, 2564, 5951, 6565, 7115, 3300,
        7576, 4286, 7491, 7421, 7981, 1816, 7486, 1757, 823, 4884,
        8716, 7979, 4167, 3338, 458, 7494, 5490, 5084, 7285, 395,
        5105, 3818, 950, 4142, 2695, 6825, 6233, 853, 2334, 2106,
        6477, 4569, 6490, 3351, 1648, 5299, 3701, 1604, 3510, 6467,
        686, 954, 42, 8873, 4406, 4114, 1000, 7239, 301, 600, 6313,
        1640, 1267, 5644, 1150, 3453, 6704, 6981, 4636, 6893, 4910,
        3015, 9146, 804, 5189, 6116, 2983, 1895, 1682, 854, 9582,
        6783, 11, 7327, 9415, 891, 6738, 189, 4150, 1664, 7956, 6960,
        7916, 5892, 5370, 7858, 7564, 830, 4186, 5977, 4976, 9683,
        5463, 1555, 3600, 6456, 7099, 260, 5732, 955, 1494, 1040,
        2538, 1229, 8795, 8050, 9460, 7357, 1240, 3174, 9207, 3093,
        8428, 7215, 2371, 7112, 9102, 8938, 4632, 8538, 3744, 5782,
        8592, 659, 9288, 2482, 5043, 9319, 614, 9776, 2051, 1931,
        2717, 9178, 3310, 8006, 8335, 5237, 7548, 7490, 8905, 5877,
        3056, 1390, 4815, 3911, 4853, 6288, 3507, 2786, 6018, 4493,
        1394, 9185, 3766, 5854, 6396, 364, 7400, 8848, 6895, 791,
        4249, 3726, 1306, 3099, 8728, 4311, 5420, 9561, 4211, 7617,
        5922, 7541, 3343, 5116, 6544, 4544, 5902, 9794, 7041, 8130,
        9566, 5917, 8169, 7679, 3335, 5449, 6303, 1060, 3342, 6949,
        523, 2830, 876, 3432, 6585, 5923, 7311, 5658, 8378, 1462,
        3925, 7715, 8445, 6089, 3228, 4611, 2824, 4752, 1058, 7721,
        4225, 8609, 6677, 2950, 8074, 6780, 5966, 1601, 7314, 2425,
        2000, 1308, 7121, 7863, 9143, 64, 6434, 7002, 8034, 4703,
        2920, 5373, 3681, 7943, 249, 9678, 5776, 2719, 2356, 3188,
        6019, 6748, 2116, 2489, 956, 7629, 4599, 374, 8007, 7010,
        7125, 7022, 9113, 4662, 726, 4255, 2665, 3152, 3177, 6035,
        504, 5519, 6819, 7102, 8226, 1577, 5895, 8434, 1798, 2141,
        4448, 1761, 6570, 8715, 5073, 5342, 8175, 2593, 9131, 89,
        1759, 8631, 2469, 4909, 2200, 5240, 1804, 2383, 9282, 3242,
        7287, 3531, 3216, 3748, 9642, 5578, 4278, 5450, 7069, 7014,
        9482, 3572, 7479, 1959, 2898, 4070, 5063, 4725, 4809, 6015,
        6431, 8274, 8957, 6892, 390, 5989, 6413, 4400, 1120, 4954,
        5079, 8705, 4530, 1625, 6635, 2502, 3763, 3028, 4892, 6833,
        207, 3993, 4760, 3067, 2961, 5051, 2568, 6638, 8558, 7151,
        6881, 3444, 1994, 1913, 5179, 9575, 988, 1484, 9142, 6028,
        6612, 793, 154, 8135, 556, 1190, 4422, 9216, 8838, 8932, 1266,
        9474, 9244, 7000, 4179, 1704, 1389, 8865, 294, 1015, 7783,
        9226, 5567, 5759, 649, 4831, 3561, 5486, 1274, 9059, 8580,
        362, 4000, 6988, 6342, 9028, 2911, 3672, 630, 1021, 8312,
        8363, 1633, 8502, 1619, 2289, 2871, 6546, 7240, 9332, 1367,
        7208, 5287, 5367, 2811, 5691, 2545, 2382, 1177, 1796, 6047,
        2975, 5881, 9559, 6980, 6292, 6694, 8554, 4312, 5352, 7891,
        7424, 60, 7648, 8787, 8477, 1531, 35, 2137, 2064, 8872, 703,
        5745, 5915, 9755, 6269, 666, 130, 317, 306, 3930, 8155, 569,
        6305, 7083, 6910, 4555, 5468, 140, 1540, 1828, 3713, 9556,
        482, 2033, 2411, 8368, 3778, 7987, 5056, 995, 4443, 9167,
        5319, 2393, 9168, 1063, 5870, 2180, 5238, 7996, 8915, 7232,
        2556, 619, 5252, 2708, 1059, 1853, 6185, 7175, 6082, 8901,
        6078, 1214, 4768, 7513, 5, 7621, 4430, 4980, 9367, 4090, 1679,
        6731, 9473, 8208, 5293, 9812, 8824, 5560, 1045, 916, 8804,
        1012, 2336, 8258, 6471, 3953, 8172, 3097, 1380, 1004, 8004,
        9806, 345, 928, 4431, 1636, 7294, 2720, 6276, 3170, 6136,
        5530, 7334, 3345, 2838, 302, 2032, 6549, 4089, 4243, 8571,
        8527, 6106, 5232, 917, 2203, 2551, 1686, 9808, 9600, 6644,
        5177, 1435, 920, 384, 9493, 1241, 4033, 2604, 2772, 4061,
        8311, 2693, 2891, 5561, 8320, 1524, 3173, 9708, 2806, 6766,
        1667, 391, 4158, 8984, 5423, 3124, 8209, 8230, 5619, 5811,
        3179, 4777, 3814, 5052, 4616, 1065, 6183, 6645, 6212, 4971,
        3797, 2797, 8005, 2105, 4018, 775, 5472, 6941, 9692, 6767,
        1947, 1607, 2113, 6122, 8340, 8110, 6172, 6665, 4846, 2733,
        1939, 3371, 8669, 4488, 4804, 3446, 5161, 1481, 2190, 5788,
        8989, 8898, 3579, 7123, 224, 6484, 7339, 8945, 9579, 4432,
        1678, 3719, 3619, 5975, 7777, 2205, 1028, 1265, 6295, 3635,
        1845, 5182, 1342, 2814, 7526, 4433, 7878, 3581, 7790, 2501,
        5779, 3273, 9225, 1865, 9495, 1980, 4399, 2889, 987, 9376,
        3389, 1324, 7709, 3610, 2353, 3629, 9568, 3051, 8734, 8902,
        3157, 3440, 1290, 3735, 3042, 606, 4300, 4227, 7901, 6265,
        7446, 4784, 3420, 1205, 5142, 7385, 7158, 4642, 5997, 8772,
        1735, 9658, 8070, 9086, 8658, 8044, 6439, 5126, 2479, 1734,
        3225, 1841, 6228, 7992, 4419, 8133, 3890, 760, 5913, 7426,
        4783, 235, 2875, 9703, 7519, 539, 9585, 9, 8073, 8150, 2082,
        1645, 8207, 2923, 784, 8219, 4729, 4858, 8702, 5859, 3040,
        4524, 5031, 6732, 2327, 8238, 1969, 778, 2472, 9467, 8518,
        2852, 8002, 1301, 1581, 6882, 6514, 3061, 2560, 7027, 8232,
        899, 5083, 8875, 2385, 5334, 7056, 6813, 4859, 236, 2541,
        3374, 2143, 4835, 6600, 3840, 3850, 9377, 4548, 5699, 3430,
        8245, 3364, 5596, 1525, 6302, 7138, 8322, 2243, 9197, 9224,
        3081, 8777, 4388, 262, 6770, 21, 8041, 6787, 105, 6105, 1122,
        7192, 7667, 9387, 4946, 5207, 986, 8342, 6031, 5101, 2466,
        1802, 4732, 1854, 4820, 1066, 9689, 3327, 3391, 8529, 7717,
        2102, 5275, 5104, 2504, 2831, 878, 5010, 4197, 7186, 6016,
        933, 2050, 3598, 536, 645, 4188, 7867, 1622, 728, 4795, 7179,
        918, 4164, 7587, 4234, 77, 9147, 1148, 1460, 5541, 202, 8717,
        3654, 2049, 4008, 1249, 951, 4465, 6596, 3354, 9805, 1343,
        846, 6619, 2120, 6154, 5136, 4473, 6493, 3452, 4203, 1312,
        3155, 7392, 1365, 4358, 2636, 277, 3427, 7761, 2558, 2322,
        1998, 2047, 3976, 1750, 9287, 9451, 9481, 6368, 9183, 7270,
        4282, 6059, 8519, 5791, 5208, 6385, 2198, 5029, 1341, 7561,
        1605, 3419, 3675, 8691, 2703, 9417, 825, 161, 8633, 338, 341,
        6446, 3376, 9680, 4413, 8490, 3812, 4893, 8442, 2637, 7026,
        7657, 4116, 7560, 3230, 3072, 7312, 1560, 6616, 3627, 5897,
        7101, 3486, 4263, 1409, 1693, 3843, 9358, 5387, 7843, 2219,
        5243, 3998, 7904, 8965, 2808, 2278, 1660, 7119, 9726, 7343,
        9298, 9587, 3828, 1970, 9811, 7473, 4547, 8415, 2074, 1353,
        4109, 1875, 5109, 5305, 7892, 6986, 7939, 2734, 2192, 8549,
        8512, 5600, 8885, 8010, 1286, 7409, 9462, 1639, 9770, 4882,
        4451, 1100, 5535, 2711, 9202, 7518, 6852, 963, 4199, 2969,
        3676, 116, 2645, 8389, 3854, 4543, 6762, 7813, 9204, 7190,
        698, 3709, 5752, 4748, 4612, 66, 4494, 1014, 6353, 2368, 2747,
        6937, 6936, 9276, 5655, 4145, 8392, 9406, 3451, 9208, 5233,
        7832, 5860, 20, 4146, 269, 1692, 8152, 8647, 2635, 3760, 1347,
        7764, 6573, 2612, 8967, 6938, 225, 9769, 8818, 9261, 5767,
        7404, 1523, 4736, 9316, 6807, 2774, 3556, 125, 2731, 7575,
        2071, 8319, 6944, 2454, 9384, 3845, 9498, 8956, 1521, 5415,
        7999, 6090, 1771, 9492, 5919, 3073, 9743, 863, 6527, 3714,
        4124, 9091, 7227, 8577, 4425, 2869, 3596, 3009, 6541, 7970,
        6609, 7351, 8433, 8096, 7869, 3324, 3436, 9670, 9041, 4673,
        5068, 2712, 3906, 5949, 6376, 2866, 72, 9507, 3644, 2204,
        5191, 6284, 2427, 2713, 851, 3848, 9151, 145, 2872, 7577,
        1387, 8435, 4206, 6675, 667, 54, 6420, 574, 575, 8120, 9033,
        9289, 9802, 3683, 6235, 8567, 1733, 9555, 7196, 5588, 9432,
        153, 7274, 3741, 8855, 2562, 6971, 7644, 3373, 632, 5760,
        9779, 4790, 2125, 571, 3117, 6540, 512, 7326, 6809, 1712,
        7098, 3004, 5750, 4644, 1007, 5445, 2661, 5320, 1762, 2739,
        768, 7906, 1899, 5626, 1218, 1198, 1125, 4380, 2736, 6970,
        1792, 4470, 8356, 5864, 6668, 8154, 2946, 5855, 8383, 7806,
        8213, 7827, 9064, 8056, 4585, 961, 7330, 9139, 5280, 5565,
        1248, 6430, 5115, 8306, 6776, 9676, 7982, 4973, 7114, 8194,
        285, 13, 8496, 9485, 2062, 7668, 2343, 1707, 7147, 5343, 3249,
        5742, 1408, 953, 8767, 5734, 1252, 529, 3304, 6924, 8308,
        2401, 148, 2740, 8144, 363, 3913, 406, 7574, 4811, 4426, 3762,
        2958, 5346, 2738, 1077, 1908, 580, 9478, 4079, 8921, 717,
        1742, 9312, 255, 8058, 3330, 526, 6791, 5921, 635, 4670, 3974,
        9194, 190, 126, 4, 6804, 1154, 1843, 8190, 4839, 922, 6898,
        6674, 6873, 206, 6293, 2667, 1986, 5408, 4389, 8867, 7822,
        7693, 2938, 7009, 117, 3883, 9539, 8755, 1386, 4799, 3636,
        3356, 1810, 9249, 8849, 6722, 1663, 5188, 3865, 2265, 7291,
        4757, 5622, 5134, 213, 7035, 8996, 2557, 6320, 3368, 9217,
        5645, 8861, 8800, 9442, 583, 7673, 4565, 3022, 7641, 466,
        5016, 5797, 3239, 1687, 7674, 4131, 2043, 2844, 5256, 2491,
        7730, 8065, 3602, 3046, 4297, 1475, 671, 3957, 4133, 1072,
        1823, 2162, 1925, 6943, 3775, 4876, 6024, 1591, 5757, 5986,
        8038, 6640, 6801, 3530, 4737, 3267, 2674, 2779, 7350, 2752,
        5604, 9821, 3640, 5569, 629, 2959, 4107, 6736, 6012, 1766,
        4223, 4675, 6709, 2433, 9117, 5585, 3564, 3291, 936, 5150,
        6102, 3549, 4445, 5575, 3178, 7057, 2782, 4788, 4363, 6830,
        4551, 663, 4212, 232, 9203, 9338, 3695, 806, 9122, 9675, 6038,
        7920, 3756, 6331, 5172, 2349, 55, 9547, 959, 1685, 1584, 3594,
        1500, 545, 9233, 2255, 7734, 9103, 4780, 1136, 5393, 9169,
        1276, 74, 1440, 7437, 2139, 8228, 1637, 8220, 5203, 4869, 404,
        7184, 5954, 3462, 7643, 9435, 9270, 5124, 1184, 5941, 7455,
        3147, 4261, 6236, 8393, 7830, 3416, 892, 4841, 3687, 1927,
        6917, 7221, 7004, 5249, 1914, 2065, 9318, 4491, 4589, 1709,
        4911, 1388, 2048, 3987, 2403, 4036, 6966, 1933, 9163, 1016,
        2681, 488, 4097, 7994, 894, 7431, 6905, 7776, 9470, 7030,
        7449, 5796, 4968, 7848, 4854, 3788, 271, 5765, 8596, 7714,
        444, 5178, 6702, 1300, 9066, 1917, 4719, 1081, 7725, 1495,
        6528, 9623, 5429, 2632, 514, 3083, 5675, 7973, 3417, 372,
        3874, 3729, 6457, 1232, 2056, 2976, 5392, 6042, 6441, 9775,
        1050, 6402, 9383, 7218, 4175, 3380, 5158, 8259, 2430, 6894,
        8436, 3108, 4944, 5539, 6614, 6772, 4468, 7941, 5276, 3394,
        5536, 7984, 3876, 4931, 7805, 1062, 4016, 7206, 2914, 5258,
        3633, 6294, 8357, 9238, 9274, 6184, 8680, 9463, 133, 9314,
        9662, 4361, 4735, 5878, 5383, 537, 8338, 8662, 9536, 8644,
        6241, 3810, 6602, 9331, 5295, 5959, 7408, 6655, 3104, 7800,
        822, 814, 8754, 7651, 6469, 1279, 7005, 8202, 5625, 6812,
        4676, 9148, 9468, 9087, 8427, 8966, 7126, 1868, 7397, 3020,
        1256, 5088, 6283, 502, 1026, 7489, 8173, 7117, 352, 8659,
        3943, 716, 2539, 1528, 2365, 5005, 2679, 3472, 2817, 2428,
        3144, 5592, 2896, 6863, 6229, 3241, 7390, 9567, 3986, 3517,
        9010, 4143, 5122, 5690, 2419, 2359, 6197, 8017, 9624, 6996,
        8990, 8448, 6190, 5315,
    ]
    if sel_nb_examples is not None and sel_nb_examples < 0:
        raise ValueError(
            'sel_nb_examples must be non-negative, got {}'.format(
                sel_nb_examples))
    # Honour the requested count by taking a prefix of the fixed list; the
    # default CLI value requests the whole list, so default output is
    # unchanged.
    selected = test_index[:sel_nb_examples]
    nb_available = min(len(words), len(targets))
    if selected and max(selected) >= nb_available:
        raise IndexError(
            'corpus provides only {} examples (words/targets) but position {} '
            'is required'.format(nb_available, max(selected)))
    new_words = [words[i] for i in selected]
    new_targets = [targets[i] for i in selected]
    return new_words, new_targets
def _build_arg_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Get examples for humans")
    parser.add_argument("--corpus",
                        type=str,
                        choices=list(available_corpora.keys()),
                        default='sst',
                        help="corpus type")
    parser.add_argument("--corpus-path",
                        type=str,
                        default=None,
                        help="path to the corpus",
                        required=True)
    parser.add_argument("--output-path",
                        type=str,
                        default=None,
                        help="path to the output file",
                        required=True)
    parser.add_argument("--nb-examples",
                        type=int,
                        default=2000,
                        help="Number of examples for human annotation")
    return parser


def _read_corpus(corpus_cls, corpus_path):
    """Read all examples from *corpus_path* with *corpus_cls*.

    Returns a ``(words, targets)`` pair of lists:
      * ``words`` gets one string per example:
        ``' '.join(ex.words)``, ``' '.join(ex.words_hyp)`` and the
        space-joined ``ex.marks[0]`` values, separated by ``' ||| '``.
      * ``targets`` is the concatenation of every non-None ``ex.target``.
    The corpus is closed even if reading fails part-way.
    """
    fields_tuples = corpus_cls.create_fields_tuples()
    print('Reading corpus...')
    corpus = corpus_cls(fields_tuples, lazy=True)
    corpus_targets = []
    corpus_words = []
    try:
        for ex in corpus.read(corpus_path):
            # NOTE(review): words are appended for every example but targets
            # only when present (and via extend), so the two lists stay
            # aligned only if every example has exactly one target — confirm
            # this holds for the corpora used with this script.
            if ex.target is not None:
                corpus_targets.extend(ex.target)
            corpus_words.append(' ||| '.join([
                ' '.join(ex.words),
                ' '.join(ex.words_hyp),
                ' '.join(map(str, ex.marks[0])),
            ]))
    finally:
        corpus.close()
    return corpus_words, corpus_targets


def _write_examples(output_file, words_list, targets):
    """Write one UTF-8 line per example: the target, a tab, then the words."""
    with output_file.open('w', encoding='utf8') as f:
        for words, target in zip(words_list, targets):
            f.write('%s\t%s\n' % (target, words))


def main():
    """Select the human-annotation examples and save them to --output-path."""
    args = _build_arg_parser().parse_args()
    random.seed(42)
    corpus_cls = available_corpora[args.corpus]
    corpus_words, corpus_targets = _read_corpus(corpus_cls, args.corpus_path)

    print('Selecting {} instances...'.format(args.nb_examples))
    sel_corpus_words, sel_corpus_targets = select_indexes_stratified(
        corpus_words, corpus_targets, args.nb_examples
    )

    print('Saving corpus...')
    _write_examples(Path(args.output_path), sel_corpus_words,
                    sel_corpus_targets)


if __name__ == '__main__':
    main()