#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
from pathlib import Path

import torch

import mtdg.data
# The code below uses this module's read_ubuntu_convs and get_fields; it is
# assumed to live at mtdg.inputters.text_dataset (per the import path), so
# alias it here for consistent access.
import mtdg.inputters.text_dataset as text_dataset
import mtdg.opts as opts

# Expected on-disk layout of the Ubuntu Dialogue corpus, relative to this file.
project_dir = Path(__file__).resolve().parent
datasets_dir = project_dir.joinpath('data/')
ubuntu_dir = datasets_dir.joinpath('ubuntu/')
ubuntu_meta_dir = ubuntu_dir.joinpath('meta/')
dialogs_dir = ubuntu_dir.joinpath('dialogs/')


def parse_args():
    parser = argparse.ArgumentParser(description='preprocess.py')
    opts.preprocess_opts(parser)
    opt = parser.parse_args()
    # Seed torch so any randomness during preprocessing is reproducible.
    torch.manual_seed(opt.seed)
    return opt


def build_save_dataset(corpus_type, fields, opt, save=True):
    """Process the text corpus for `corpus_type` into a Dataset of examples."""
    assert corpus_type in ['train', 'valid', 'test']

    if corpus_type == 'train':
        corpus_file = opt.train_data
    elif corpus_type == 'valid':
        corpus_file = opt.valid_data
    else:
        corpus_file = opt.test_data

    conversations = text_dataset.read_ubuntu_convs(
        corpus_file,
        min_turn=opt.min_turn_length, max_turn=opt.max_turn_length,
        min_seq=opt.min_seq_length, max_seq=opt.max_seq_length,
        n_workers=opt.n_workers)
    # elif opt.data == "dailydialog":
    #     if corpus_type == 'train':
    #         corpus_file = opt.train_data
    #     elif corpus_type == 'valid':
    #         corpus_file = opt.valid_data
    #     else:
    #         corpus_file = opt.test_data
    #     conversations = mtdg.data.read_dailydialog_file(
    #         corpus_file, opt.max_turns, opt.max_seq_length)

    dataset = mtdg.data.Dataset(conversations, fields)
    if save:
        # Strip the fields before serialization (they hold objects that do
        # not pickle cleanly); build_save_vocab restores them afterwards.
        dataset.fields = []
        pt_file = "{:s}.{:s}.pt".format(opt.save_data, corpus_type)
        print(pt_file)
        torch.save(dataset, pt_file)
    return dataset


def build_save_vocab(train_dataset, fields, opt, save=True):
    # We emptied each dataset's `fields` attribute when saving it,
    # so restore the fields before building the vocabulary.
    train_dataset.fields = fields
    fields["conversation"].build_vocab(train_dataset, max_size=opt.vocab_size,
                                       min_freq=opt.words_min_frequency)
    if save:
        # Fields themselves can't be pickled, so save only their vocabs and
        # reconstruct the fields at training time.
        torch.save(mtdg.data.save_fields_to_vocab(fields),
                   opt.save_data + '.vocab.pt')
    return fields
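
# At training time the fields must be reconstructed from the saved vocab file.
# Assuming mtdg.data mirrors OpenNMT-py's pairing of save_fields_to_vocab with
# a load_fields_from_vocab helper (an assumption -- the loader is not shown in
# this file), the restore step would look roughly like:
#
#   vocab = torch.load(opt.save_data + '.vocab.pt')
#   fields = mtdg.data.load_fields_from_vocab(vocab)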


def main():
    opt = parse_args()
    print('Preparing for training ...')
    fields = text_dataset.get_fields(opt)

    print("Building & saving training data...")
    train_dataset = build_save_dataset('train', fields, opt)

    print("Building & saving vocabulary...")
    fields = build_save_vocab(train_dataset, fields, opt)

    print("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)

    print("Building & saving test data...")
    build_save_dataset('test', fields, opt)


if __name__ == "__main__":
    main()
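
# Example invocation (a sketch; the exact flag names are defined in
# mtdg.opts.preprocess_opts and are assumed here from the opt attributes
# consumed above):
#
#   python preprocess.py -train_data data/ubuntu/train.txt \
#       -valid_data data/ubuntu/valid.txt -test_data data/ubuntu/test.txt \
#       -save_data data/ubuntu/processed -vocab_size 20000
#
# which would write processed.train.pt, processed.valid.pt, processed.test.pt,
# and processed.vocab.pt under data/ubuntu/.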