Skip to content

Commit

Permalink
updated prepare data script
Browse files Browse the repository at this point in the history
  • Loading branch information
tlatkowski committed Nov 7, 2019
1 parent 017a68b commit 215f21f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
11 changes: 8 additions & 3 deletions bin/prepare_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@ QQP_TEST_ID=1XD-HxzUCTHrzhfvIXOlgqN_MWiiAqM8h
SNLI_DATA=${SNLI_DATA:-https://drive.google.com/uc?export=download&id=${SNLI_ID}}
QQP_DATA_TRAIN=${QQP_DATA_TRAIN:-https://drive.google.com/uc?export=download&id=${QQP_TRAIN_ID}}
QQP_DATA_TEST=${QQP_DATA_TEST:-https://drive.google.com/uc?export=download&id=${QQP_TEST_ID}}
ANLI_DATA_LINK=https://dl.fbaipublicfiles.com/anli/anli_v0.1.zip

CORPORA_DIR=corpora
SNLI_DIR=SNLI
QQP_DIR=QQP
CORPORA_DIR=corpora
ANLI_DIR=ANLI

SNLI_FILE=train_snli.tgz
QQP_FILE_TRAIN=qqp_train.tgz
QQP_FILE_TEST=qqp_test.tgz
ANLI_FILE=anli_v0.1.zip

mkdir ../${CORPORA_DIR}
cd ../${CORPORA_DIR}
mkdir ${SNLI_DIR} ${QQP_DIR}
mkdir ${SNLI_DIR} ${QQP_DIR} ${ANLI_DIR}

function google_drive_big_file_download () {
CONFIRM=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$1" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')
Expand All @@ -29,7 +32,9 @@ function google_drive_big_file_download () {
wget --no-check-certificate ${SNLI_DATA} -O ${SNLI_DIR}/${SNLI_FILE}
wget --no-check-certificate ${QQP_DATA_TRAIN} -O ${QQP_DIR}/${QQP_FILE_TRAIN}
google_drive_big_file_download ${QQP_TEST_ID} ${QQP_DIR}/${QQP_FILE_TEST}
wget --no-check-certificate ${ANLI_DATA_LINK} -O ${ANLI_DIR}/${ANLI_FILE}

tar -xvzf ${SNLI_DIR}/${SNLI_FILE} -C ${SNLI_DIR}
tar -xvzf ${QQP_DIR}/${QQP_FILE_TRAIN} -C ${QQP_DIR}
tar -xvzf ${QQP_DIR}/${QQP_FILE_TEST} -C ${QQP_DIR}
tar -xvzf ${QQP_DIR}/${QQP_FILE_TEST} -C ${QQP_DIR}
unzip ${ANLI_DIR}/${ANLI_FILE} -d ${ANLI_DIR}
4 changes: 2 additions & 2 deletions data/anli.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def dev_set_pairs(self):
return self.dev[['hypothesis', 'reason']].as_matrix()

def dev_labels(self):
return self.dev['label'].as_matrix()
return pd.get_dummies(self.dev['label']).as_matrix()

def test_set(self):
return self.test
Expand All @@ -63,7 +63,7 @@ def test_set_pairs(self):
return self.test[['hypothesis', 'reason']].as_matrix()

def test_labels(self):
return self.test['label'].as_matrix()
return pd.get_dummies(self.test['label']).as_matrix()

def _data_path(self):
return 'corpora/ANLI/R3'
9 changes: 8 additions & 1 deletion models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@

class BaseSiameseNet:

def __init__(self, max_sequence_len, vocabulary_size, main_cfg, model_cfg, loss_function):
def __init__(
self,
max_sequence_len,
vocabulary_size,
main_cfg,
model_cfg,
loss_function,
):
self.x1 = tf.placeholder(dtype=tf.int32, shape=[None, max_sequence_len])
self.x2 = tf.placeholder(dtype=tf.int32, shape=[None, max_sequence_len])
self.is_training = tf.placeholder(dtype=tf.bool)
Expand Down

0 comments on commit 215f21f

Please sign in to comment.