diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 23757110c4aa8..20008676543d0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -4,12 +4,11 @@ import warnings from abc import ABC, abstractmethod from argparse import Namespace +import csv - -import pandas as pd import torch import torch.distributed as dist -# + from pytorch_lightning.core.decorators import data_loader from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import ModelHooks @@ -1217,10 +1216,12 @@ def load_hparams_from_tags_csv(tags_csv): logging.warning(f'Missing Tags: {tags_csv}.') return Namespace() - tags_df = pd.read_csv(tags_csv) - dic = tags_df.to_dict(orient='records') - ns_dict = {row['key']: convert(row['value']) for row in dic} - ns = Namespace(**ns_dict) + tags = {} + with open(tags_csv) as f: + csv_reader = csv.reader(f, delimiter=',') + for row in list(csv_reader)[1:]: + tags[row[0]] = convert(row[1]) + ns = Namespace(**tags) return ns diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index 3437ef3aa6260..269282d5529c2 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -9,7 +9,6 @@ from subprocess import PIPE import numpy as np -import pandas as pd import torch @@ -146,24 +145,14 @@ def make_summary(self): Layer Name, Layer Type, Input Size, Output Size, Number of Parameters ''' - - cols = ['Name', 'Type', 'Params'] - if self.model.example_input_array is not None: - cols.extend(['In_sizes', 'Out_sizes']) - - df = pd.DataFrame(np.zeros((len(self.layer_names), len(cols)))) - df.columns = cols - - df['Name'] = self.layer_names - df['Type'] = self.layer_types - df['Params'] = self.param_nums - df['Params'] = df['Params'].map(get_human_readable_count) - + arrays = [['Name', self.layer_names], + ['Type', self.layer_types], + ['Params', list(map(get_human_readable_count, self.param_nums))]] if self.model.example_input_array is not None: - df['In_sizes'] = self.in_sizes - df['Out_sizes'] = self.out_sizes + arrays.append(['In sizes', self.in_sizes]) + arrays.append(['Out sizes', self.out_sizes]) - self.summary = df + self.summary = _format_summary_table(*arrays) return def summarize(self): @@ -176,6 +165,51 @@ def summarize(self): self.make_summary() +def _format_summary_table(*cols): + ''' + Takes in a number of arrays, each specifying a column in + the summary table, and combines them all into one big + string defining the summary table that are nicely formatted. + ''' + n_rows = len(cols[0][1]) + n_cols = 1 + len(cols) + + # Layer counter + counter = list(map(str, list(range(n_rows)))) + counter_len = max([len(c) for c in counter]) + + # Get formatting length of each column + length = [] + for c in cols: + str_l = len(c[0]) # default length is header length + for a in c[1]: + if isinstance(a, np.ndarray): + array_string = '[' + ', '.join([str(j) for j in a]) + ']' + str_l = max(len(array_string), str_l) + else: + str_l = max(len(a), str_l) + length.append(str_l) + + # Formatting + s = '{:<{}}' + full_length = sum(length) + 3 * n_cols + header = [s.format(' ', counter_len)] + [s.format(c[0], l) for c, l in zip(cols, length)] + + # Summary = header + divider + Rest of table + summary = ' | '.join(header) + '\n' + '-' * full_length + for i in range(n_rows): + line = s.format(counter[i], counter_len) + for c, l in zip(cols, length): + if isinstance(c[1][i], np.ndarray): + array_string = '[' + ', '.join([str(j) for j in c[1][i]]) + ']' + line += ' | ' + array_string + ' ' * (l - len(array_string)) + else: + line += ' | ' + s.format(c[1][i], l) + summary += '\n' + line + + return summary + + def print_mem_stack(): # pragma: no cover for obj in gc.get_objects(): try: diff --git a/pytorch_lightning/logging/tensorboard.py b/pytorch_lightning/logging/tensorboard.py index 73862a0755bed..63df5a6443025 100644 --- a/pytorch_lightning/logging/tensorboard.py +++ b/pytorch_lightning/logging/tensorboard.py @@ -4,7 +4,7 @@ from pkg_resources import parse_version import torch -import pandas as pd +import csv from torch.utils.tensorboard import SummaryWriter from .base import LightningLoggerBase, rank_zero_only @@ -108,12 +108,17 @@ def save(self): dir_path = os.path.join(self.save_dir, self.name, 'version_%s' % self.version) if not os.path.isdir(dir_path): dir_path = self.save_dir + # prepare the file path meta_tags_path = os.path.join(dir_path, self.NAME_CSV_TAGS) + # save the metatags file - df = pd.DataFrame({'key': list(self.tags.keys()), - 'value': list(self.tags.values())}) - df.to_csv(meta_tags_path, index=False) + with open(meta_tags_path, 'w', newline='') as csvfile: + fieldnames = ['key', 'value'] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writerow({'key': 'key', 'value': 'value'}) + for k, v in self.tags.items(): + writer.writerow({'key': k, 'value': v}) @rank_zero_only def finalize(self, status): diff --git a/requirements.txt b/requirements.txt index 87f915c8a7d06..90cec5a5f5606 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,6 @@ tqdm>=4.35.0 numpy>=1.16.4 torch>=1.1 torchvision>=0.4.0, < 0.5 # the 0.5. has some issues with torch JIT -pandas>=0.24 # lower version do not support py3.7 tensorboard>=1.14 -future>=0.17.1 # required for builtins in setup.py \ No newline at end of file +future>=0.17.1 # required for builtins in setup.py +