Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[API] unified API for custom kvstores #17010

Merged
merged 37 commits into from
Dec 17, 2019
Merged
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add unit test for kvstore base
Ubuntu committed Dec 7, 2019

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 339fdd2dc4d7b7cc60f7f266149dd129d80b3b1c
4 changes: 3 additions & 1 deletion python/mxnet/kvstore/kvstore.py
Original file line number Diff line number Diff line change
@@ -75,7 +75,9 @@ def __del__(self):

def broadcast(self, key, value, out, priority=0):
""" Broadcast the `value` NDArray at rank 0 to all ranks,
and store the result in `out`
and store the result in `out`.

Note that the native KVStore does not support broadcasting the same key more than once.

Parameters
----------
156 changes: 156 additions & 0 deletions tests/python/unittest/test_kvstore_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: skip-file
import mxnet as mx
import numpy as np
import unittest
from mxnet.test_utils import rand_ndarray, assert_almost_equal
from common import setup_module, with_seed, assertRaises, teardown
from mxnet.base import py_str, MXNetError

shape = (4, 4)
keys = [5, 7, 11]
str_keys = ['b', 'c', 'd']

def check_diff_to_scalar(A, x):
""" assert A == x"""
assert(np.sum(np.abs((A - x).asnumpy())) == 0), (A, x)

def init_kv(name='device'):
return mx.kv.create(name)

@with_seed()
def test_broadcast_single_kv_pair():
"""single key-value pair push & pull"""
def check_single_kv_pair(kv, key):
# single output
ones = mx.nd.ones(shape)
out = mx.nd.empty(shape)
kv.broadcast(key, ones, out)
check_diff_to_scalar(out, 1)
# list output
out_list = [mx.nd.empty(shape)] * 3
key_list = key + key
kv.broadcast(key_list, ones, out_list)
for o in out_list:
check_diff_to_scalar(o, 1)

check_single_kv_pair(init_kv(), 3)
check_single_kv_pair(init_kv(), 'a')

@with_seed()
def test_broadcast_list_kv_pair():
"""list key-value pair push & pull"""
def check_list_kv_pair(kv, key):
ones = [mx.nd.ones(shape)] * len(key)
out = [mx.nd.empty(shape)] * len(key)
kv.broadcast(key, ones, out)
for o in out:
check_diff_to_scalar(o, 1)
out_list = [[mx.nd.empty(shape)] * 2 for _ in range(len(key))]
key_list = [k + k for k in key]
kv.broadcast(key_list, ones, out_list)
for o in out_list:
for oo in o:
check_diff_to_scalar(oo, 1)

check_list_kv_pair(init_kv(), keys)
check_list_kv_pair(init_kv(), str_keys)

@with_seed()
def test_pushpull_single_kv_pair():
"""aggregate value on muliple devices"""
def check_aggregator(kv, key, key_list):
num_keys = len(key_list)
kv.broadcast(key, mx.nd.zeros(shape), out=mx.nd.empty(shape))
kv.broadcast(key_list, [mx.nd.zeros(shape)] * num_keys,
out=[mx.nd.empty(shape)] * num_keys)
# devices
num_devs = 4
devs = [mx.Context('cpu', i) for i in range(num_devs)]

# single
vals = [mx.nd.ones(shape, d) for d in devs]
outs = [mx.nd.empty(shape, d) for d in devs]

kv.pushpull(key, vals, out=outs)
for out in outs:
check_diff_to_scalar(out, num_devs)

# inplace
kv.pushpull(key, vals)
for val in vals:
check_diff_to_scalar(val, num_devs)

# list
vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * num_keys
outs = [[mx.nd.empty(shape, d) for d in devs]] * num_keys
kv.pushpull(key_list, vals, out=outs)
for out in outs:
for o in out:
check_diff_to_scalar(o, num_devs * 2.0)

# inplace
kv.pushpull(key_list, vals)
for val in vals:
for v in val:
check_diff_to_scalar(v, num_devs * 2.0)

check_aggregator(init_kv(), 3, keys)
check_aggregator(init_kv(), 'a', str_keys)

@with_seed()
def test_pushpull_list_kv_pair():
"""aggregate value on muliple devices"""
def check_aggregator(kv, key, key_list):
num_keys = len(key_list)
kv.broadcast(key, mx.nd.zeros(shape), out=mx.nd.empty(shape))
kv.broadcast(key_list, [mx.nd.zeros(shape)] * num_keys,
out=[mx.nd.empty(shape)] * num_keys)
# devices
num_devs = 4
devs = [mx.Context('cpu', i) for i in range(num_devs)]

# single
vals = [mx.nd.ones(shape, d) for d in devs]
outs = [mx.nd.empty(shape, d) for d in devs]

kv.pushpull(key, vals, out=outs)
for out in outs:
check_diff_to_scalar(out, num_devs)

# list
vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * num_keys
outs = [[mx.nd.empty(shape, d) for d in devs]] * num_keys
kv.pushpull(key_list, vals, out=outs)
for out in outs:
for o in out:
check_diff_to_scalar(o, num_devs * 2.0)

check_aggregator(init_kv(), 3, keys)
check_aggregator(init_kv(), 'a', str_keys)

@with_seed()
def test_get_type_device():
kvtype = 'device'
kv = mx.kv.create(kvtype)
assert kv.type == kvtype

if __name__ == '__main__':
import nose
nose.runmodule()