This repository has been archived by the owner on Nov 15, 2023. It is now read-only.
Commit
rafal committed on Oct 20, 2016
1 parent 219d302, commit 79ae890
Showing 16 changed files with 795,411 additions and 0 deletions.
@@ -0,0 +1,97 @@
import os
import time
import numpy as np
import tensorflow as tf


def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
    # Device function for tf.device(): variables go to ps_dev, everything else to the given GPU.
    def _assign(op):
        node_def = op if isinstance(op, tf.NodeDef) else op.node_def
        if node_def.op == "Variable":
            return ps_dev
        else:
            return "/gpu:%d" % gpu
    return _assign


def find_trainable_variables(key):
    # Returns all trainable variables whose names match the given key.
    return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, ".*{}.*".format(key))


def load_from_checkpoint(saver, logdir):
    sess = tf.get_default_session()
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt and ckpt.model_checkpoint_path:
        if os.path.isabs(ckpt.model_checkpoint_path):
            # Restores from checkpoint with absolute path.
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Restores from checkpoint with relative path.
            saver.restore(sess, os.path.join(logdir, ckpt.model_checkpoint_path))
        return True
    return False


class CheckpointLoader(object):
    # Polls logdir until a checkpoint newer than the last one loaded appears.
    def __init__(self, saver, global_step, logdir):
        self.saver = saver
        self.global_step_tensor = global_step
        self.logdir = logdir
        # TODO(rafal): make it restart-proof?
        self.last_global_step = 0

    def load_checkpoint(self):
        while True:
            if load_from_checkpoint(self.saver, self.logdir):
                global_step = int(self.global_step_tensor.eval())
                if global_step <= self.last_global_step:
                    print("Waiting for a new checkpoint...")
                    time.sleep(60)
                    continue
                print("Successfully loaded model at step=%s." % global_step)
            else:
                print("No checkpoint file found. Waiting...")
                time.sleep(60)
                continue
            self.last_global_step = global_step
            return True


def average_grads(tower_grads):
    def average_dense(grad_and_vars):
        if len(grad_and_vars) == 1:
            return grad_and_vars[0][0]

        grad = grad_and_vars[0][0]
        for g, _ in grad_and_vars[1:]:
            grad += g
        return grad / len(grad_and_vars)

    def average_sparse(grad_and_vars):
        if len(grad_and_vars) == 1:
            return grad_and_vars[0][0]

        indices = []
        values = []
        for g, _ in grad_and_vars:
            indices += [g.indices]
            values += [g.values]
        # Pre-1.0 TensorFlow argument order: tf.concat(concat_dim, values).
        indices = tf.concat(0, indices)
        values = tf.concat(0, values)
        return tf.IndexedSlices(values, indices, grad_and_vars[0][0].dense_shape)

    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        if grad_and_vars[0][0] is None:
            grad = None
        elif isinstance(grad_and_vars[0][0], tf.IndexedSlices):
            grad = average_sparse(grad_and_vars)
        else:
            grad = average_dense(grad_and_vars)
        # The Variables are redundant because they are shared across towers,
        # so we just return the first tower's pointer to the Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads
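
These helpers are meant to be composed into a multi-tower training loop. Below is a minimal sketch of that wiring; it is not part of this commit, and the toy loss, optimizer, and two-GPU setup are illustrative assumptions only:

# Hypothetical usage sketch (not from this commit): variables are pinned to the CPU by
# assign_to_gpu(), each tower computes its own gradients, and average_grads() merges them.
optimizer = tf.train.GradientDescentOptimizer(0.1)
tower_grads = []
for gpu in range(2):
    with tf.device(assign_to_gpu(gpu)), tf.variable_scope("model", reuse=gpu > 0):
        w = tf.get_variable("w", [10], initializer=tf.constant_initializer(0.0))
        loss = tf.reduce_sum(tf.square(w - 1.0))  # toy per-tower loss
        tower_grads.append(optimizer.compute_gradients(loss, [w]))
train_op = optimizer.apply_gradients(average_grads(tower_grads))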
@@ -0,0 +1,136 @@
import codecs
import glob
import json
import random

import numpy as np


class Vocabulary(object):

    def __init__(self):
        self._token_to_id = {}
        self._token_to_count = {}
        self._id_to_token = []
        self._num_tokens = 0
        self._s_id = None
        self._unk_id = None

    @property
    def num_tokens(self):
        return self._num_tokens

    @property
    def unk(self):
        return "<UNK>"

    @property
    def unk_id(self):
        return self._unk_id

    @property
    def s(self):
        return "<S>"

    @property
    def s_id(self):
        return self._s_id

    def add(self, token, count):
        self._token_to_id[token] = self._num_tokens
        self._token_to_count[token] = count
        self._id_to_token.append(token)
        self._num_tokens += 1

    def finalize(self):
        self._s_id = self.get_id(self.s)
        self._unk_id = self.get_id(self.unk)

    def get_id(self, token):
        return self._token_to_id.get(token, self.unk_id)

    def get_token(self, id_):
        return self._id_to_token[id_]

    @staticmethod
    def from_file(filename):
        vocab = Vocabulary()
        with codecs.open(filename, "r", "utf-8") as f:
            for line in f:
                word, count = line.strip().split()
                vocab.add(word, int(count))
        vocab.finalize()
        return vocab


class Dataset(object):

    def __init__(self, vocab, file_pattern, deterministic=False):
        self._vocab = vocab
        self._file_pattern = file_pattern
        self._deterministic = deterministic

    def _parse_sentence(self, line):
        # Wraps every sentence in <S> markers.
        s_id = self._vocab.s_id
        return [s_id] + [self._vocab.get_id(word) for word in line.strip().split()] + [s_id]

    def _parse_file(self, file_name):
        print("Processing file: %s" % file_name)
        with codecs.open(file_name, "r", "utf-8") as f:
            lines = [line.strip() for line in f]
            if not self._deterministic:
                random.shuffle(lines)
            print("Finished processing!")
            for line in lines:
                yield self._parse_sentence(line)

    def _sentence_stream(self, file_stream):
        for file_name in file_stream:
            for sentence in self._parse_file(file_name):
                yield sentence

    def _iterate(self, sentences, batch_size, num_steps):
        # x holds input token ids, y the next-token targets, w a mask of positions
        # that were actually filled with data.
        streams = [None] * batch_size
        x = np.zeros([batch_size, num_steps], np.int32)
        y = np.zeros([batch_size, num_steps], np.int32)
        w = np.zeros([batch_size, num_steps], np.uint8)
        while True:
            x[:] = 0
            y[:] = 0
            w[:] = 0
            for i in range(batch_size):
                tokens_filled = 0
                try:
                    while tokens_filled < num_steps:
                        if streams[i] is None or len(streams[i]) <= 1:
                            streams[i] = next(sentences)
                        num_tokens = min(len(streams[i]) - 1, num_steps - tokens_filled)
                        x[i, tokens_filled:tokens_filled + num_tokens] = streams[i][:num_tokens]
                        y[i, tokens_filled:tokens_filled + num_tokens] = streams[i][1:num_tokens + 1]
                        w[i, tokens_filled:tokens_filled + num_tokens] = 1
                        streams[i] = streams[i][num_tokens:]
                        tokens_filled += num_tokens
                except StopIteration:
                    pass
            if not np.any(w):
                return

            yield x, y, w

    def iterate_once(self, batch_size, num_steps):
        def file_stream():
            for file_name in glob.glob(self._file_pattern):
                yield file_name
        for value in self._iterate(self._sentence_stream(file_stream()), batch_size, num_steps):
            yield value

    def iterate_forever(self, batch_size, num_steps):
        def file_stream():
            while True:
                file_patterns = glob.glob(self._file_pattern)
                if not self._deterministic:
                    random.shuffle(file_patterns)
                for file_name in file_patterns:
                    yield file_name
        for value in self._iterate(self._sentence_stream(file_stream()), batch_size, num_steps):
            yield value
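
As a rough sketch of how this reader is driven (not part of this commit; the vocabulary path, shard pattern, and batch shape below are made-up placeholders):

# Hypothetical usage: stream (input, target, mask) batches once over all shards.
vocab = Vocabulary.from_file("vocab.txt")        # placeholder vocabulary file
dataset = Dataset(vocab, "training-shards/*")    # placeholder file pattern
for x, y, w in dataset.iterate_once(batch_size=128, num_steps=20):
    pass  # x: token ids, y: next-token targets, w: 1 where a real token was filled in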
@@ -0,0 +1,35 @@
import unittest

from data_utils import Vocabulary, Dataset


class DataUtilsTestCase(unittest.TestCase):
    def test_vocabulary(self):
        vocab = Vocabulary.from_file("testdata/test_vocab.txt")
        self.assertEqual(vocab.num_tokens, 1000)
        self.assertEqual(vocab.s_id, 2)
        self.assertEqual(vocab.s, "<S>")
        self.assertEqual(vocab.unk_id, 38)
        self.assertEqual(vocab.unk, "<UNK>")

    def test_dataset(self):
        vocab = Vocabulary.from_file("testdata/test_vocab.txt")
        dataset = Dataset(vocab, "testdata/*")

        def generator():
            for i in range(1, 10):
                yield [0] + list(range(1, i + 1)) + [0]
        counts = [0] * 10
        for seq in generator():
            for v in seq:
                counts[v] += 1

        counts2 = [0] * 10
        for x, y, w in dataset._iterate(generator(), 2, 4):
            for v in x.ravel():
                counts2[v] += 1
        for i in range(1, 10):
            self.assertEqual(counts[i], counts2[i], "Mismatch at i=%d" % i)


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,30 @@
class HParams(object):

    def __init__(self, **kwargs):
        self._items = {}
        for k, v in kwargs.items():
            self._set(k, v)

    def _set(self, k, v):
        self._items[k] = v
        setattr(self, k, v)

    def parse(self, str_value):
        hps = HParams(**self._items)
        for entry in str_value.strip().split(","):
            entry = entry.strip()
            if not entry:
                continue
            key, sep, value = entry.partition("=")
            if not sep:
                raise ValueError("Unable to parse: %s" % entry)
            default_value = hps._items[key]
            # Check bool before int: bool is a subclass of int in Python.
            if isinstance(default_value, bool):
                hps._set(key, value.lower() == "true")
            elif isinstance(default_value, int):
                hps._set(key, int(value))
            elif isinstance(default_value, float):
                hps._set(key, float(value))
            else:
                hps._set(key, value)
        return hps
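
For illustration only (the hyperparameter names and defaults here are invented, not taken from this commit), overrides are applied as a comma-separated string of key=value pairs, with each value cast to the type of its default:

# Hypothetical usage of HParams.parse(); names and defaults are made up.
hps = HParams(batch_size=128, learning_rate=0.2, use_sampled_softmax=True)
hps = hps.parse("learning_rate=0.1,use_sampled_softmax=false")
# learning_rate is now 0.1, use_sampled_softmax is now False; batch_size keeps its default.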
@@ -0,0 +1,24 @@
import unittest
from hparams import HParams


class HParamsTestCase(unittest.TestCase):
    def test_basic(self):
        hps = HParams(int_value=13, float_value=17.5, bool_value=True, str_value="test")
        self.assertEqual(hps.int_value, 13)
        self.assertEqual(hps.float_value, 17.5)
        self.assertEqual(hps.bool_value, True)
        self.assertEqual(hps.str_value, "test")

    def test_parse(self):
        hps = HParams(int_value=13, float_value=17.5, bool_value=True, str_value="test")
        self.assertEqual(hps.parse("int_value=10").int_value, 10)
        self.assertEqual(hps.parse("float_value=10").float_value, 10)
        self.assertEqual(hps.parse("float_value=10.3").float_value, 10.3)
        self.assertEqual(hps.parse("bool_value=true").bool_value, True)
        self.assertEqual(hps.parse("bool_value=True").bool_value, True)
        self.assertEqual(hps.parse("bool_value=false").bool_value, False)
        self.assertEqual(hps.parse("str_value=value").str_value, "value")


if __name__ == '__main__':
    unittest.main()