This repository has been archived by the owner on Nov 15, 2023. It is now read-only.

Commit

initial commit
rafal committed Oct 20, 2016
1 parent 219d302 commit 79ae890
Showing 16 changed files with 795,411 additions and 0 deletions.
793,470 changes: 793,470 additions & 0 deletions 1b_word_vocab.txt

Large diffs are not rendered by default.

Empty file added __init__.py
Empty file.
97 changes: 97 additions & 0 deletions common.py
@@ -0,0 +1,97 @@
import os
import time
import numpy as np
import tensorflow as tf


def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
    def _assign(op):
        node_def = op if isinstance(op, tf.NodeDef) else op.node_def
        if node_def.op == "Variable":
            return ps_dev
        else:
            return "/gpu:%d" % gpu
    return _assign


def find_trainable_variables(key):
    return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, ".*{}.*".format(key))


def load_from_checkpoint(saver, logdir):
    sess = tf.get_default_session()
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt and ckpt.model_checkpoint_path:
        if os.path.isabs(ckpt.model_checkpoint_path):
            # Restores from checkpoint with absolute path.
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Restores from checkpoint with relative path.
            saver.restore(sess, os.path.join(logdir, ckpt.model_checkpoint_path))
        return True
    return False


class CheckpointLoader(object):
    def __init__(self, saver, global_step, logdir):
        self.saver = saver
        self.global_step_tensor = global_step
        self.logdir = logdir
        # TODO(rafal): make it restart-proof?
        self.last_global_step = 0

    def load_checkpoint(self):
        # Blocks until a checkpoint newer than the last one loaded appears in logdir.
        while True:
            if load_from_checkpoint(self.saver, self.logdir):
                global_step = int(self.global_step_tensor.eval())
                if global_step <= self.last_global_step:
                    print("Waiting for a new checkpoint...")
                    time.sleep(60)
                    continue
                print("Successfully loaded model at step=%s." % global_step)
            else:
                print("No checkpoint file found. Waiting...")
                time.sleep(60)
                continue
            self.last_global_step = global_step
            return True


def average_grads(tower_grads):
    def average_dense(grad_and_vars):
        if len(grad_and_vars) == 1:
            return grad_and_vars[0][0]

        grad = grad_and_vars[0][0]
        for g, _ in grad_and_vars[1:]:
            grad += g
        return grad / len(grad_and_vars)

    def average_sparse(grad_and_vars):
        if len(grad_and_vars) == 1:
            return grad_and_vars[0][0]

        indices = []
        values = []
        for g, _ in grad_and_vars:
            indices += [g.indices]
            values += [g.values]
        indices = tf.concat(0, indices)
        values = tf.concat(0, values)
        return tf.IndexedSlices(values, indices, grad_and_vars[0][0].dense_shape)

    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        if grad_and_vars[0][0] is None:
            grad = None
        elif isinstance(grad_and_vars[0][0], tf.IndexedSlices):
            grad = average_sparse(grad_and_vars)
        else:
            grad = average_dense(grad_and_vars)
        # Keep in mind that the Variables are redundant because they are shared
        # across towers. So we will just return the first tower's pointer to
        # the Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads
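
A note on usage: assign_to_gpu and average_grads are the multi-GPU plumbing in this file. The following is only a rough sketch of how they might fit together, assuming a hypothetical build_loss(), and an optimizer and num_gpus defined elsewhere (none of these are part of this commit):

tower_grads = []
for gpu in range(num_gpus):
    # Variables land on the CPU parameter-server device; other ops go to this GPU.
    with tf.device(assign_to_gpu(gpu)):
        loss = build_loss()  # hypothetical per-tower loss
        tower_grads.append(optimizer.compute_gradients(loss))
# Average the per-tower (gradient, variable) lists and apply a single update.
train_op = optimizer.apply_gradients(average_grads(tower_grads))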
136 changes: 136 additions & 0 deletions data_utils.py
@@ -0,0 +1,136 @@
import codecs
import glob
import json
import random

import numpy as np


class Vocabulary(object):

    def __init__(self):
        self._token_to_id = {}
        self._token_to_count = {}
        self._id_to_token = []
        self._num_tokens = 0
        self._s_id = None
        self._unk_id = None

    @property
    def num_tokens(self):
        return self._num_tokens

    @property
    def unk(self):
        return "<UNK>"

    @property
    def unk_id(self):
        return self._unk_id

    @property
    def s(self):
        return "<S>"

    @property
    def s_id(self):
        return self._s_id

    def add(self, token, count):
        self._token_to_id[token] = self._num_tokens
        self._token_to_count[token] = count
        self._id_to_token.append(token)
        self._num_tokens += 1

    def finalize(self):
        self._s_id = self.get_id(self.s)
        self._unk_id = self.get_id(self.unk)

    def get_id(self, token):
        return self._token_to_id.get(token, self.unk_id)

    def get_token(self, id_):
        return self._id_to_token[id_]

    @staticmethod
    def from_file(filename):
        vocab = Vocabulary()
        with codecs.open(filename, "r", "utf-8") as f:
            for line in f:
                word, count = line.strip().split()
                vocab.add(word, int(count))
        vocab.finalize()
        return vocab


class Dataset(object):

    def __init__(self, vocab, file_pattern, deterministic=False):
        self._vocab = vocab
        self._file_pattern = file_pattern
        self._deterministic = deterministic

    def _parse_sentence(self, line):
        s_id = self._vocab.s_id
        return [s_id] + [self._vocab.get_id(word) for word in line.strip().split()] + [s_id]

    def _parse_file(self, file_name):
        print("Processing file: %s" % file_name)
        with codecs.open(file_name, "r", "utf-8") as f:
            lines = [line.strip() for line in f]
            if not self._deterministic:
                random.shuffle(lines)
            print("Finished processing!")
            for line in lines:
                yield self._parse_sentence(line)

    def _sentence_stream(self, file_stream):
        for file_name in file_stream:
            for sentence in self._parse_file(file_name):
                yield sentence

    def _iterate(self, sentences, batch_size, num_steps):
        # Packs sentences into fixed-size [batch_size, num_steps] batches.
        # x holds input token ids, y the next-token targets, and w marks the
        # positions that were actually filled.
        streams = [None] * batch_size
        x = np.zeros([batch_size, num_steps], np.int32)
        y = np.zeros([batch_size, num_steps], np.int32)
        w = np.zeros([batch_size, num_steps], np.uint8)
        while True:
            x[:] = 0
            y[:] = 0
            w[:] = 0
            for i in range(batch_size):
                tokens_filled = 0
                try:
                    while tokens_filled < num_steps:
                        if streams[i] is None or len(streams[i]) <= 1:
                            streams[i] = next(sentences)
                        num_tokens = min(len(streams[i]) - 1, num_steps - tokens_filled)
                        x[i, tokens_filled:tokens_filled + num_tokens] = streams[i][:num_tokens]
                        y[i, tokens_filled:tokens_filled + num_tokens] = streams[i][1:num_tokens + 1]
                        w[i, tokens_filled:tokens_filled + num_tokens] = 1
                        streams[i] = streams[i][num_tokens:]
                        tokens_filled += num_tokens
                except StopIteration:
                    pass
            if not np.any(w):
                return

            yield x, y, w

    def iterate_once(self, batch_size, num_steps):
        def file_stream():
            for file_name in glob.glob(self._file_pattern):
                yield file_name
        for value in self._iterate(self._sentence_stream(file_stream()), batch_size, num_steps):
            yield value

    def iterate_forever(self, batch_size, num_steps):
        def file_stream():
            while True:
                file_patterns = glob.glob(self._file_pattern)
                if not self._deterministic:
                    random.shuffle(file_patterns)
                for file_name in file_patterns:
                    yield file_name
        for value in self._iterate(self._sentence_stream(file_stream()), batch_size, num_steps):
            yield value
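
For orientation, a minimal sketch of driving Vocabulary and Dataset together; the vocabulary file name is the one added in this commit, while the training-data glob and the consumer are placeholders:

vocab = Vocabulary.from_file("1b_word_vocab.txt")
dataset = Dataset(vocab, "/path/to/training/shards-*")  # placeholder pattern
for x, y, w in dataset.iterate_once(batch_size=128, num_steps=20):
    # x: input ids, y: next-token targets, w: 1 where a position holds real data
    consume_batch(x, y, w)  # hypothetical consumer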
35 changes: 35 additions & 0 deletions data_utils_test.py
@@ -0,0 +1,35 @@
import unittest

from data_utils import Vocabulary, Dataset


class DataUtilsTestCase(unittest.TestCase):
    def test_vocabulary(self):
        vocab = Vocabulary.from_file("testdata/test_vocab.txt")
        self.assertEqual(vocab.num_tokens, 1000)
        self.assertEqual(vocab.s_id, 2)
        self.assertEqual(vocab.s, "<S>")
        self.assertEqual(vocab.unk_id, 38)
        self.assertEqual(vocab.unk, "<UNK>")

    def test_dataset(self):
        vocab = Vocabulary.from_file("testdata/test_vocab.txt")
        dataset = Dataset(vocab, "testdata/*")

        def generator():
            for i in range(1, 10):
                yield [0] + list(range(1, i + 1)) + [0]

        counts = [0] * 10
        for seq in generator():
            for v in seq:
                counts[v] += 1

        counts2 = [0] * 10
        for x, y, w in dataset._iterate(generator(), 2, 4):
            for v in x.ravel():
                counts2[v] += 1
        for i in range(1, 10):
            self.assertEqual(counts[i], counts2[i], "Mismatch at i=%d" % i)


if __name__ == '__main__':
    unittest.main()
30 changes: 30 additions & 0 deletions hparams.py
@@ -0,0 +1,30 @@
class HParams(object):

    def __init__(self, **kwargs):
        self._items = {}
        for k, v in kwargs.items():
            self._set(k, v)

    def _set(self, k, v):
        self._items[k] = v
        setattr(self, k, v)

    def parse(self, str_value):
        # Returns a copy of these hyperparameters with overrides from a
        # comma-separated "name=value" string applied; the original object is
        # left unchanged.
        hps = HParams(**self._items)
        for entry in str_value.strip().split(","):
            entry = entry.strip()
            if not entry:
                continue
            key, sep, value = entry.partition("=")
            if not sep:
                raise ValueError("Unable to parse: %s" % entry)
            default_value = hps._items[key]
            # Note: bool is checked before int because bool is a subclass of int.
            if isinstance(default_value, bool):
                hps._set(key, value.lower() == "true")
            elif isinstance(default_value, int):
                hps._set(key, int(value))
            elif isinstance(default_value, float):
                hps._set(key, float(value))
            else:
                hps._set(key, value)
        return hps
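
A quick usage sketch for HParams; the particular hyperparameter names below are made up for illustration:

hps = HParams(batch_size=128, learning_rate=0.2, average_params=True, vocab_path="1b_word_vocab.txt")
hps2 = hps.parse("learning_rate=0.05,average_params=false")
print(hps2.learning_rate, hps2.average_params)  # 0.05 False
print(hps.learning_rate)  # 0.2 -- parse returns a modified copy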
24 changes: 24 additions & 0 deletions hparams_test.py
@@ -0,0 +1,24 @@
import unittest
from hparams import HParams


class HParamsTestCase(unittest.TestCase):
    def test_basic(self):
        hps = HParams(int_value=13, float_value=17.5, bool_value=True, str_value="test")
        self.assertEqual(hps.int_value, 13)
        self.assertEqual(hps.float_value, 17.5)
        self.assertEqual(hps.bool_value, True)
        self.assertEqual(hps.str_value, "test")

    def test_parse(self):
        hps = HParams(int_value=13, float_value=17.5, bool_value=True, str_value="test")
        self.assertEqual(hps.parse("int_value=10").int_value, 10)
        self.assertEqual(hps.parse("float_value=10").float_value, 10)
        self.assertEqual(hps.parse("float_value=10.3").float_value, 10.3)
        self.assertEqual(hps.parse("bool_value=true").bool_value, True)
        self.assertEqual(hps.parse("bool_value=True").bool_value, True)
        self.assertEqual(hps.parse("bool_value=false").bool_value, False)
        self.assertEqual(hps.parse("str_value=value").str_value, "value")


if __name__ == '__main__':
    unittest.main()
