Subgraph refactor (#2278)
* Rebase with latest main

* refactor subgraph code

* Enable subgraph parallelism unit tests

* code cleanup

* code cleanup and resolve comments

* clang formatting

* Apply suggestions from code review

Co-authored-by: Tal Ben-Nun <[email protected]>

---------

Co-authored-by: Brian Van Essen <[email protected]>
Co-authored-by: Tal Ben-Nun <[email protected]>
3 people authored Aug 22, 2023
1 parent 662e222 commit e8cf85e
Showing 25 changed files with 2,940 additions and 2,955 deletions.
applications/nlp/transformer/subgraph/dataset.py: 89 changes (52 additions, 37 deletions)
@@ -6,13 +6,16 @@
import numpy as np
import torchnlp.datasets

-def env2int(env_list, default = -1):
-    for e in env_list:
-        val = int(os.environ.get(e, -1))
-        if val >= 0: return val
-    return default

-data_size = env2int(['DATA_SIZE'])
+def env2int(env_list, default=-1):
+    for e in env_list:
+        val = int(os.environ.get(e, -1))
+        if val >= 0:
+            return val
+    return default


+data_size = env2int(["DATA_SIZE"])
# Local imports
current_file = os.path.realpath(__file__)
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
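
As an aside on the hunk above: a minimal sketch of how the DATA_SIZE override is meant to be used. The snippet is illustrative only and not part of the commit; the flat `import dataset` path is an assumption, the module actually lives at applications/nlp/transformer/subgraph/dataset.py.

import os

# Cap the training split before the data module is imported; env2int(["DATA_SIZE"])
# reads the variable at import time and truncates dataset_train accordingly.
os.environ["DATA_SIZE"] = "1000"

import dataset  # hypothetical flat import of the dataset module shown in this diff

print(dataset.num_train_samples())  # expected to print at most 1000
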
@@ -41,25 +44,26 @@ def env2int(env_list, default = -1):
)


-if(data_size!=-1):

+if data_size != -1:
    dataset_train = dataset_train[:data_size]
    dataset_val = dataset_val[:1024]

# Load token vocabulary
-with open(os.path.join(data_dir, 'vocab.bpe.32000')) as f:
+token_file = os.path.join(data_dir, "vocab.bpe.32000")
+with open(token_file, "r", encoding="utf-8") as f:
    tokens = f.read().splitlines()
-tokens.extend(['<unk>', '<s>', '</s>', '<pad>'])
+tokens.extend(["<unk>", "<s>", "</s>", "<pad>"])
token_indices = dict(zip(tokens, range(len(tokens))))
-unk_index = token_indices.get('<unk>', -1)
-bos_index = token_indices.get('<s>', -1)
-eos_index = token_indices.get('</s>', -1)
-pad_index = token_indices.get('<pad>', -1)
+unk_index = token_indices.get("<unk>", -1)
+bos_index = token_indices.get("<s>", -1)
+eos_index = token_indices.get("</s>", -1)
+pad_index = token_indices.get("<pad>", -1)

# ----------------------------------------------
# Tokenization
# ----------------------------------------------


def tokenize(text):
    """Convert string to list of token indices.
@@ -68,34 +72,34 @@ def tokenize(text):
    """
    indices = [bos_index]
-    indices.extend(
-        token_indices.get(token, unk_index)
-        for token in text.split(' ')
-    )
+    indices.extend(token_indices.get(token, unk_index) for token in text.split(" "))
    indices.append(eos_index)
    return indices


def detokenize(indices):
    """Convert token indices to string.
    Stops at the first EOS token. All other special tokens are
    ignored.
    """
-    text = ''
+    text = ""
    for index in indices:
        if index == eos_index:
            break
        elif index in (unk_index, bos_index, pad_index):
            continue
        else:
-            text += f' {tokens[index]}'
+            text += f" {tokens[index]}"
    return text


# ----------------------------------------------
# Sample access functions
# ----------------------------------------------


def get_train_sample(index):
    """Token indices for a data sample from the training set.
@@ -106,26 +110,28 @@ def get_train_sample(index):

    # Tokenize text data
    text = dataset_train[index]
-    sample_en = tokenize(text['en'])
-    sample_de = tokenize(text['de'])
+    sample_en = tokenize(text["en"])
+    sample_de = tokenize(text["de"])

    # Randomly subsample sequences if they are too long
    if len(sample_en) > sequence_length or len(sample_de) > sequence_length:
        pos = np.random.rand()
        if len(sample_en) > sequence_length:
            offset = (len(sample_en) - sequence_length + 1) * pos
            offset = int(np.floor(offset))
-            sample_en = sample_en[offset:offset+sequence_length]
+            sample_en = sample_en[offset : offset + sequence_length]
        if len(sample_de) > sequence_length:
            offset = (len(sample_de) - sequence_length + 1) * pos
            offset = int(np.floor(offset))
-            sample_de = sample_de[offset:offset+sequence_length]
+            sample_de = sample_de[offset : offset + sequence_length]

    # Concatenate sequences and return
-    sample = np.full(2*sequence_length, pad_index, dtype=int)
-    sample[0:len(sample_en)] = sample_en
-    sample[sequence_length:sequence_length+len(sample_de)] = sample_de
+    sample = np.full(2 * sequence_length, pad_index, dtype=int)
+    sample[0 : len(sample_en)] = sample_en
+    sample[sequence_length : sequence_length + len(sample_de)] = sample_de
    return sample


def get_test_sample(index):
    """Token indices for a data sample from the training set.
@@ -136,38 +142,47 @@ def get_test_sample(index):

    # Tokenize text data
    text = dataset_train[index]
-    sample_en = tokenize(text['en'])
-    sample_de = tokenize(text['de'])
+    sample_en = tokenize(text["en"])
+    sample_de = tokenize(text["de"])

    # Randomly subsample sequences if they are too long
    if len(sample_en) > sequence_length or len(sample_de) > sequence_length:
        pos = np.random.rand()
        if len(sample_en) > sequence_length:
            offset = (len(sample_en) - sequence_length + 1) * pos
            offset = int(np.floor(offset))
-            sample_en = sample_en[offset:offset+sequence_length]
+            sample_en = sample_en[offset : offset + sequence_length]
        if len(sample_de) > sequence_length:
            offset = (len(sample_de) - sequence_length + 1) * pos
            offset = int(np.floor(offset))
-            sample_de = sample_de[offset:offset+sequence_length]
+            sample_de = sample_de[offset : offset + sequence_length]

    # Concatenate sequences and return
-    sample = np.full(2*sequence_length, pad_index, dtype=int)
-    sample[0:len(sample_en)] = sample_en
-    sample[sequence_length:sequence_length+len(sample_de)] = sample_de
+    sample = np.full(2 * sequence_length, pad_index, dtype=int)
+    sample[0 : len(sample_en)] = sample_en
+    sample[sequence_length : sequence_length + len(sample_de)] = sample_de
    return sample


def get_val_sample(index):
    """Token indices for a data sample from the validation set."""
    text = dataset_val[index]
-    sample_en = tokenize(text['en'])
-    sample_de = tokenize(text['de'])
+    sample_en = tokenize(text["en"])
+    sample_de = tokenize(text["de"])
    return sample_en, sample_de


def num_train_samples():
    return len(dataset_train)


def num_val_samples():
    return len(dataset_val)


def sample_dims():
-    return (2*sequence_length+1,)
+    return (2 * sequence_length + 1,)


def vocab_size():
    return len(tokens)
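
To make the sample packing in get_train_sample/get_test_sample easier to follow, here is a minimal standalone sketch. The concrete values of sequence_length, pad_index, and the token indices are made up for the demo; only the np.full/slice pattern mirrors the code in the diff.

import numpy as np

# Hypothetical values standing in for the module-level configuration in dataset.py.
sequence_length = 64              # stands in for the real sequence_length
pad_index = 3                     # stands in for token_indices["<pad>"]
sample_en = [1, 17, 42, 2]        # tokenized English side: <s> ... </s>
sample_de = [1, 99, 7, 23, 2]     # tokenized German side: <s> ... </s>

# Pack both sequences into one fixed-length vector: English tokens fill the
# first half, German tokens fill the second half, everything else stays <pad>.
sample = np.full(2 * sequence_length, pad_index, dtype=int)
sample[0 : len(sample_en)] = sample_en
sample[sequence_length : sequence_length + len(sample_de)] = sample_de

print(sample[:8])                                     # [ 1 17 42  2  3  3  3  3]
print(sample[sequence_length : sequence_length + 8])  # [ 1 99  7 23  2  3  3  3]
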