Merge pull request #52 from jelmervdl/trainer-clean-up
Trainer clean up
XapaJIaMnu authored Dec 12, 2022
2 parents cdda47b + e0c2326 commit 9533d52
Showing 9 changed files with 1,038 additions and 373 deletions.
23 changes: 14 additions & 9 deletions trainer/README.md
@@ -8,9 +8,9 @@ Define your training process via a configuration file. You define the datasets o
```yml
# Datasets are already TSV files
datasets:
- test/data/clean
- test/data/medium
- test/data/dirty
clean: test/data/clean
medium: test/data/medium
dirty: test/data/dirty

stages:
- start
@@ -35,25 +35,30 @@ end:
- dirty 0.3
- until dirty 5 # use `inf` to mean until forever

uppercase: 0.05 # Apply uppercase randomly to 0.05% of sentences. Use 0 to disable
titlecase: 0.05 # Apply titlecase randomly to 0.05% of sentences. Use 0 to disable
modifiers:
  - uppercase 0.05 # Apply uppercase randomly with probability 0.05 (5% of sentences). Use 0 to disable
  - titlecase 0.05 # Apply titlecase randomly with probability 0.05 (5% of sentences). Use 0 to disable

seed: 1111
trainer: /path/to/trainer/run.py
```
## Usage
```bash
% ./trainer.py --help :(
usage: trainer.py [-h] --config CONFIG [--temporary-dir TEMPORARY_DIR] [--do-not-resume]
% ./trainer.py --help
usage: trainer.py [-h] --config CONFIG [--temporary-directory TEMPORARY_DIR] [--state STATE_FILE] [--do-not-resume] [--sync] [trainer-command [arguments]]

Feeds marian tsv data for training.

options:
-h, --help show this help message and exit
--config CONFIG, -c CONFIG
YML configuration input.
--temporary-dir TEMPORARY_DIR, -t TEMPORARY_DIR
--temporary-directory TEMPORARY_DIR, -t TEMPORARY_DIR
Temporary dir, used for shuffling and tracking state
--state STATE_FILE Path to trainer state file which stores how much of
each dataset has been read. Defaults to ${CONFIG}.state
--sync Do not shuffle in the background
--do-not-resume, -d Do not resume from the previous training state
```
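For example, to feed the balanced output straight into your trainer (a sketch: `/path/to/marian` and its config path are placeholders for your own trainer command and arguments):
```bash
./trainer.py -c train_config.yml /path/to/marian -c /path/to/marian-config.yml
```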
Once you fix the paths in the configuration file, `train_config.yml`, you can run a test case by doing:
@@ -62,7 +67,7 @@ Once you fix the paths in the configuration file, `train_config.yml` you can run
```
You can check the resulting mixed file in `/tmp/test`. If your neural network trainer doesn't support training from `stdin`, you can use this tool to generate a training dataset up front and then disable data reordering or shuffling in your trainer implementation, as the training input will already be balanced.
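For instance, a minimal sketch of that file-generation workflow (the output path is just an example):
```bash
./trainer.py -c train_config.yml > /tmp/mixed_training_data.tsv
```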

At the start of the training all datasets are shuffled. Each time a dataset's end is reached, it is re-shuffled. Shuffling happens inside the training directory (by default `./TMP`) where the training state is also kept. If training is interrupted, re-running the trainer should resume from where it was (ALMOST, in case the buffer wasn't consumed by the neural network trainer, it will be skipped, but this is usually only a few hundred sentence pairs, no more).
At the start of the training all datasets are shuffled. Each time a dataset's end is reached, it is re-shuffled. Shuffling happens [in the system temp directory](https://docs.python.org/3.11/library/tempfile.html#tempfile.gettempdir), but the location can be changed with `--temporary-directory` or the `TMPDIR` environment variable. By default, the training state is kept in the same place as the configuration file. If training is interrupted, re-running the trainer should resume from where it was (any lines your neural network trainer had already buffered but not consumed will be skipped).
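As a sketch, to keep shuffle chunks on a larger scratch disk and store the state file elsewhere (both paths are hypothetical):
```bash
./trainer.py -c train_config.yml --temporary-directory /scratch/shuffle --state /scratch/train_config.yml.state > /tmp/test
```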

## Generating vocabulary and placeholders before training
To use the placeholder code, augment your training data with placeholders before training; look at this example script:
3 changes: 0 additions & 3 deletions trainer/random.sh

This file was deleted.

167 changes: 167 additions & 0 deletions trainer/shuffle.py
@@ -0,0 +1,167 @@
#!/usr/bin/env python3
import subprocess
import os
from shutil import which
from argparse import ArgumentParser, FileType
from itertools import islice, chain
from tempfile import mkstemp
from typing import TypeVar, Iterable, List, Optional
from queue import Queue
from threading import Thread
from dataclasses import dataclass
from random import Random


# Buffer size for reading files. The buffer size Python assigns by default is generally too small.
BUFSIZE = 2**16

# Prefer pigz if available, but fall back to calling gzip
PATH_TO_GZIP = which("pigz") or which("gzip")


T = TypeVar('T')

def chunked(iterable: Iterable[T], chunk_size:int) -> Iterable[List[T]]:
"""Splits an iterable into shorter lists of a fixed length."""
it = iter(iterable)
while True:
chunk = list(islice(it, chunk_size))
if not chunk:
break
yield chunk
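
# For example (illustration): chunked([1, 2, 3, 4, 5], 2) yields [1, 2], [3, 4], [5].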


@dataclass(frozen=True)
class ShuffleTask:
"""Job that describes to shuffle a chunk to the shuffle_chunk_worker thread.
Passing along the seed created by the main thread because those
random.random() calls are predictable. The order in which Shuffling tasks
are picked up and finished may not be."""
fileno: int
seed: float
chunk: List[bytes]


def shuffle_chunk_worker(queue:"Queue[Optional[ShuffleTask]]"):
"""Worker thread that takes a queue of filenames and seeds, and shuffles them
in memory. Put a None in the queue to make it stop."""
while True:
task = queue.get()

if task is None:
break

random = Random(task.seed)

with os.fdopen(task.fileno, 'wb', buffering=BUFSIZE) as fh:
random.shuffle(task.chunk)
fh.writelines(task.chunk)


def shuffle(fin: Iterable[bytes], lines:int, *, seed:Optional[int]=None, threads:int=1, tmpdir:Optional[str]=None) -> Iterable[bytes]:
"""Shuffle a list by reading it into a bunch of files (of `lines` length)
and shuffling all of these with `threads` in-memory shufflers."""
random = Random(seed)

    # Limit the queue to one pending chunk per worker, otherwise we'll run out of memory quickly.
queue: "Queue[Optional[ShuffleTask]]" = Queue(maxsize=threads)

chunks: List[str] = []

try:
# Prepare shuffle workers to start shuffling chunks as soon as we've
# finished writing them.
shufflers = [
Thread(target=shuffle_chunk_worker, args=[queue])
for _ in range(threads)
]

try:
for shuffler in shufflers:
shuffler.start()

# Split the input file into separate temporary chunks
for chunk in chunked(fin, lines):
fileno, filename = mkstemp(dir=tmpdir)
# Remember the chunk's filename for later
chunks.append(filename)
# And immediately start shuffling & writing that chunk in another thread
# so we can use this thread to continue ingesting chunks
queue.put(ShuffleTask(fileno, random.random(), chunk))
finally:
# Tell shufflers that they can stop waiting
for _ in shufflers:
queue.put(None)

# Wait for them to finish shuffling the last files
for shuffler in shufflers:
shuffler.join()

# Open all chunks. We'll be reading the next line from a random one of them.
chunk_fds = [open(filename, 'rb', buffering=BUFSIZE) for filename in chunks]

# While we still have chunks to read from...
while chunk_fds:
# Pick a random chunk, read the line
fd = random.choice(chunk_fds)
line = fd.readline()
# If the line was empty, the chunk has reached EOF and we close it.
if line == b'':
fd.close()
chunk_fds.remove(fd)
continue
yield line
finally:
# Whatever happened, if a filename of a temporary file made it into the
# `chunks` list, we are responsible for cleaning it up.
for filename in chunks:
os.unlink(filename)
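
# Usage sketch (illustrative names): shuffle() lazily yields shuffled lines, e.g.
#
#     with open('corpus.tsv', 'rb') as fh:
#         for line in shuffle(fh, lines=1_000_000, seed=1111, threads=4):
#             ...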


class Reader(Iterable[bytes]):
"""Lazily opens a file only once you start trying to read it. Also magically
reads gzipped files."""
def __init__(self, filename:str):
self.filename = filename

def _read_gzip(self, filename:str) -> Iterable[bytes]:
"""Open gzipped files through gzip subprocess. It is faster than Python's
gzip submodule, and you get a bit of multiprocessing for free as the
external gzip process can decompress up to BUFSIZE while python is doing
other things."""
assert PATH_TO_GZIP is not None, 'No gzip executable found on system'
child = subprocess.Popen([PATH_TO_GZIP, '-cd', filename], stdout=subprocess.PIPE, bufsize=BUFSIZE)
assert child.stdout is not None
yield from child.stdout
if child.wait() != 0:
raise RuntimeError(f'`gzip -cd {filename}` failed with return code {child.returncode}')

def _read_plain(self, filename:str) -> Iterable[bytes]:
with open(filename, 'rb') as fh:
yield from fh

def __iter__(self) -> Iterable[bytes]:
if self.filename.endswith('.gz'):
return self._read_gzip(self.filename)
else:
return self._read_plain(self.filename)
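
# Usage sketch: Reader transparently decompresses .gz input, e.g.
#
#     lines = chain.from_iterable(Reader(f) for f in ['a.tsv', 'b.tsv.gz'])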


if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--batch-size', type=int, default=1_000_000, help='number of lines per chunk. Note that these chunks are read into memory when being shuffled')
    parser.add_argument('--threads', '-j', type=int, default=2, help='number of concurrent shuffle threads. Defaults to 2')
parser.add_argument('--temporary-directory', '-T', type=str, help='temporary directory for shuffling batches')
    parser.add_argument('seed', type=int, help='random seed')
    parser.add_argument('output', type=FileType('wb', bufsize=BUFSIZE), nargs='?', default='-', help='output file. Defaults to stdout')
parser.add_argument('files', nargs='+')

args = parser.parse_args()

# Read the lines
it = chain.from_iterable(Reader(filename) for filename in args.files)

# Shuffle the lines
it = shuffle(it, lines=args.batch_size, seed=args.seed, threads=args.threads, tmpdir=args.temporary_directory)

args.output.writelines(it)
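
# Example invocation (a sketch; file names are hypothetical):
#
#     ./shuffle.py --threads 4 1111 shuffled.tsv corpus1.tsv corpus2.tsv.gz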