Skip to content

Commit

Permalink
Update download and partition of imdb dataset
Browse files — browse the repository at this point in the history
  • Loading branch information
mtreviso committed Jan 18, 2021
1 parent 49db8e9 commit b812860
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
3 changes: 2 additions & 1 deletion generate_dataset_partitions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ python3 scripts/partition_agnews_corpus.py \
# the original training file
mv data/corpus/imdb/train data/corpus/imdb/train-original
mkdir data/corpus/imdb/train
mkdir data/corpus/imdb/dev

# generate dev data from the training data for IMDB
python3 scripts/partition_imdb_corpus.py \
'data/corpus/imdb/train/data.txt' \
'data/corpus/imdb/train-original/' \
'data/corpus/imdb/train/data.txt' \
'data/corpus/imdb/dev/data.txt'

Expand Down
24 changes: 19 additions & 5 deletions scripts/partition_imdb_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""
import random
import sys

from itertools import chain
from pathlib import Path

if __name__ == '__main__':
seed = 42
Expand All @@ -17,9 +18,22 @@
test_output_path = sys.argv[3]

nb_lines = 0
with open(file_input_path, 'r', encoding='utf8') as f:
for _ in f:
nb_lines += 1
neg_files = sorted(Path(file_input_path, 'neg').glob('*.txt'))
pos_files = sorted(Path(file_input_path, 'pos').glob('*.txt'))
paths = chain(neg_files, pos_files)
new_file_path = Path(file_input_path, 'data.txt')
new_file = new_file_path.open('w', encoding='utf8')
for file_path in paths:
content = file_path.read_text().strip()
content = content.replace('<br>', ' <br> ')
content = content.replace('<br >', ' <br> ')
content = content.replace('<br />', ' <br> ')
content = content.replace('<br/>', ' <br> ')
label = '1' if '/pos/' in str(file_path) else '0'
new_file.write(label + ' ' + content + '\n')
nb_lines += 1
new_file.seek(0)
new_file.close()

data_indexes = list(range(nb_lines))
random.shuffle(data_indexes)
Expand All @@ -35,7 +49,7 @@
f_train = open(train_output_path, 'w', encoding='utf8')
f_test = open(test_output_path, 'w', encoding='utf8')
idx = 0
with open(file_input_path, 'r', encoding='utf8') as f:
with open(new_file_path, 'r', encoding='utf8') as f:
for line in f:
print('{}/{}'.format(idx, nb_lines), end='\r')
# ele = json.loads(line.strip())
Expand Down

0 comments on commit b812860

Please sign in to comment.