Skip to content

Commit

Permalink
fix: FsNeo4jCSVLoader fails if nodes have disjoint keys (#408)
Browse files Browse the repository at this point in the history
* Prepare for test_fs_neo4j_csv_loader to run multiple tests

Signed-off-by: Joseph Atkins-Turkish <[email protected]>

* Add failing test

Signed-off-by: Joseph Atkins-Turkish <[email protected]>

* Fix failing test

Signed-off-by: Joseph Atkins-Turkish <[email protected]>

* Implement using numeric keys

Signed-off-by: Joseph Atkins-Turkish <[email protected]>
  • Loading branch information
Joseph Atkins-Turkish authored Nov 17, 2020
1 parent 8a9618e commit c07cec9
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 18 deletions.
11 changes: 8 additions & 3 deletions databuilder/loader/file_system_neo4j_csv_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from csv import DictWriter

from pyhocon import ConfigTree, ConfigFactory
from typing import Dict, Any
from typing import Dict, Any, FrozenSet

from databuilder.job.base_job import Job
from databuilder.loader.base_loader import Loader
Expand Down Expand Up @@ -40,6 +40,7 @@ class FsNeo4jCSVLoader(Loader):
def __init__(self) -> None:
self._node_file_mapping: Dict[Any, DictWriter] = {}
self._relation_file_mapping: Dict[Any, DictWriter] = {}
self._keys: Dict[FrozenSet[str], int] = {}
self._closer = Closer()

def init(self, conf: ConfigTree) -> None:
Expand Down Expand Up @@ -109,7 +110,7 @@ def load(self, csv_serializable: GraphSerializable) -> None:
node = csv_serializable.next_node()
while node:
node_dict = neo4_serializer.serialize_node(node)
key = (node.label, len(node_dict))
key = (node.label, self._make_key(node_dict))
file_suffix = '{}_{}'.format(*key)
node_writer = self._get_writer(node_dict,
self._node_file_mapping,
Expand All @@ -125,7 +126,7 @@ def load(self, csv_serializable: GraphSerializable) -> None:
key2 = (relation.start_label,
relation.end_label,
relation.type,
len(relation_dict))
self._make_key(relation_dict))

file_suffix = '{}_{}_{}'.format(key2[0], key2[1], key2[2])
relation_writer = self._get_writer(relation_dict,
Expand Down Expand Up @@ -183,3 +184,7 @@ def close(self) -> None:

def get_scope(self) -> str:
return "loader.filesystem_csv_neo4j"

def _make_key(self, record_dict: Dict[str, Any]) -> int:
""" Each unique set of record keys is assigned an increasing numeric key """
return self._keys.setdefault(frozenset(record_dict.keys()), len(self._keys))
102 changes: 87 additions & 15 deletions tests/unit/loader/test_fs_neo4j_csv_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from os import listdir
from os.path import isfile, join

from pyhocon import ConfigFactory
from typing import Dict, Iterable, Any, Callable
from pyhocon import ConfigFactory, ConfigTree
from typing import Dict, Iterable, Any, Callable, Optional, Union

from databuilder.models.graph_serializable import GraphSerializable, GraphNode, GraphRelationship
from databuilder.job.base_job import Job
from databuilder.loader.file_system_neo4j_csv_loader import FsNeo4jCSVLoader
from tests.unit.models.test_graph_serializable import Movie, Actor, City
Expand All @@ -21,12 +22,6 @@
class TestFsNeo4jCSVLoader(unittest.TestCase):
def setUp(self) -> None:
logging.basicConfig(level=logging.INFO)
prefix = '/var/tmp/TestFsNeo4jCSVLoader'
self._conf = ConfigFactory.from_dict(
{FsNeo4jCSVLoader.NODE_DIR_PATH: '{}/{}'.format(prefix, 'nodes'),
FsNeo4jCSVLoader.RELATION_DIR_PATH: '{}/{}'
.format(prefix, 'relationships'),
FsNeo4jCSVLoader.SHOULD_DELETE_CREATED_DIR: True})

def tearDown(self) -> None:
Job.closer.close()
Expand All @@ -37,25 +32,61 @@ def test_load(self) -> None:
movie = Movie('Top Gun', actors, cities)

loader = FsNeo4jCSVLoader()
loader.init(self._conf)

folder = 'movies'
conf = self._make_conf(folder)

loader.init(conf)
loader.load(movie)
loader.close()

expected_node_path = '{}/../resources/fs_neo4j_csv_loader/nodes'\
.format(os.path.join(os.path.dirname(__file__)))
expected_node_path = '{}/../resources/fs_neo4j_csv_loader/{}/nodes'\
.format(os.path.join(os.path.dirname(__file__)), folder)
expected_nodes = self._get_csv_rows(expected_node_path, itemgetter('KEY'))
actual_nodes = self._get_csv_rows(self._conf.get_string(FsNeo4jCSVLoader.NODE_DIR_PATH),
actual_nodes = self._get_csv_rows(conf.get_string(FsNeo4jCSVLoader.NODE_DIR_PATH),
itemgetter('KEY'))
self.assertEqual(expected_nodes, actual_nodes)

expected_rel_path = \
'{}/../resources/fs_neo4j_csv_loader/relationships' \
.format(os.path.join(os.path.dirname(__file__)))
'{}/../resources/fs_neo4j_csv_loader/{}/relationships' \
.format(os.path.join(os.path.dirname(__file__)), folder)
expected_relations = self._get_csv_rows(expected_rel_path, itemgetter('START_KEY', 'END_KEY'))
actual_relations = self._get_csv_rows(self._conf.get_string(FsNeo4jCSVLoader.RELATION_DIR_PATH),
actual_relations = self._get_csv_rows(conf.get_string(FsNeo4jCSVLoader.RELATION_DIR_PATH),
itemgetter('START_KEY', 'END_KEY'))
self.assertEqual(expected_relations, actual_relations)

def test_load_disjoint_properties(self) -> None:
people = [
Person("Taylor", job="Engineer"),
Person("Griffin", pet="Lion"),
]

loader = FsNeo4jCSVLoader()

folder = 'people'
conf = self._make_conf(folder)

loader.init(conf)
loader.load(people[0])
loader.load(people[1])
loader.close()

expected_node_path = '{}/../resources/fs_neo4j_csv_loader/{}/nodes'\
.format(os.path.join(os.path.dirname(__file__)), folder)
expected_nodes = self._get_csv_rows(expected_node_path, itemgetter('KEY'))
actual_nodes = self._get_csv_rows(conf.get_string(FsNeo4jCSVLoader.NODE_DIR_PATH),
itemgetter('KEY'))
self.assertEqual(expected_nodes, actual_nodes)

def _make_conf(self, test_name: str) -> ConfigTree:
prefix = '/var/tmp/TestFsNeo4jCSVLoader'

return ConfigFactory.from_dict(
{FsNeo4jCSVLoader.NODE_DIR_PATH: '{}/{}/{}'.format(prefix, test_name, 'nodes'),
FsNeo4jCSVLoader.RELATION_DIR_PATH: '{}/{}/{}'
.format(prefix, test_name, 'relationships'),
FsNeo4jCSVLoader.SHOULD_DELETE_CREATED_DIR: True})

def _get_csv_rows(self,
path: str,
sorting_key_getter: Callable) -> Iterable[Dict[str, Any]]:
Expand All @@ -71,5 +102,46 @@ def _get_csv_rows(self,
return sorted(result, key=sorting_key_getter)


class Person(GraphSerializable):
""" A Person has multiple optional attributes. When an attribute is None,
it is not included in the resulting node.
"""
LABEL = 'Person'
KEY_FORMAT = 'person://{}'

def __init__(self,
name: str,
*,
pet: Optional[str] = None,
job: Optional[str] = None,
) -> None:
self._name = name
self._pet = pet
self._job = job
self._node_iter = iter(self.create_nodes())

def create_next_node(self) -> Union[GraphNode, None]:
try:
return next(self._node_iter)
except StopIteration:
return None

def create_next_relation(self) -> Union[GraphRelationship, None]:
return None

def create_nodes(self) -> Iterable[GraphNode]:
attributes = {"name": self._name}
if self._pet:
attributes['pet'] = self._pet
if self._job:
attributes['job'] = self._job

return [GraphNode(
key=Person.KEY_FORMAT.format(self._name),
label=Person.LABEL,
attributes=attributes
)]


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"name","job","KEY","LABEL"
"Taylor","Engineer","person://Taylor","Person"
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"name","pet","KEY","LABEL"
"Griffin","Lion","person://Griffin","Person"

0 comments on commit c07cec9

Please sign in to comment.