mx.sym.constant #6087
Yeah, that's something I wanted to do but didn't have time for.
I'm interested in this. @piiswrong
My friend has implemented a simple version of a constant initializer. @panzheyi
I implemented a possible solution:

import mxnet as mx
import numpy as np


# Registering the class lets MXNet recreate it by name from the
# serialized init attribute stored on the variable.
@mx.init.register
class MyConstant(mx.init.Initializer):
    def __init__(self, value):
        super(MyConstant, self).__init__(value=value)
        self.value = value

    def _init_weight(self, _, arr):
        # Copy the stored (nested-list) value into the parameter array.
        arr[:] = mx.nd.array(self.value)


batch_size = 10
# Pass a plain list, not an ndarray, so the initializer kwargs stay JSON-serializable.
const_arr = np.ones((5, 5)).tolist()
a = mx.sym.Variable('a', shape=(5, 5), init=MyConstant(value=const_arr))
a = mx.sym.BlockGrad(a)  # now variable a is a constant (no gradient flows into it)
data = mx.sym.Variable('data')
loss = mx.sym.MakeLoss(mx.sym.broadcast_add(a, data))

mod = mx.mod.Module(loss, data_names=['data'], label_names=[])
mod.bind(data_shapes=[('data', (batch_size, 5, 5))])
# 'a' keeps its MyConstant init; the global Uniform initializer covers any other parameters.
mod.init_params(initializer=mx.init.Uniform())
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.005),))

data = np.ones((1000, 5, 5))
dataiter = mx.io.NDArrayIter(data={'data': data}, batch_size=batch_size)
dataiter.reset()
for batch_id, databatch in enumerate(dataiter):
    mod.forward_backward(databatch)
    mod.update()
print(mod.get_outputs()[0].asnumpy())

Variable a is a constant, initialized from a NumPy array.
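The .tolist() on const_arr is not cosmetic. As far as I understand MXNet 1.x, passing init=... to mx.sym.Variable stores the initializer on the symbol as a JSON string (roughly its lowercased class name plus its constructor kwargs), and the initializer is later recreated from that string by its registered name, which is also why @mx.init.register is needed. The sketch below imitates that round trip using only json and NumPy; roundtrip_init_kwargs and the exact payload layout are illustrative assumptions, not MXNet code.

```python
import json

import numpy as np


def roundtrip_init_kwargs(value, name="myconstant"):
    """Hypothetical helper: approximate how an initializer's kwargs are
    stored as JSON on a symbol and read back. Payload layout is assumed."""
    payload = json.dumps([name, {"value": value}])  # TypeError for raw ndarrays
    restored_name, kwargs = json.loads(payload)
    return restored_name, np.array(kwargs["value"], dtype=np.float32)


if __name__ == "__main__":
    const = np.ones((5, 5))
    try:
        roundtrip_init_kwargs(const)  # a raw ndarray is not JSON-serializable
    except TypeError as err:
        print("ndarray rejected:", err)
    name, arr = roundtrip_init_kwargs(const.tolist())  # nested list works
    print(name, arr.shape)  # prints: myconstant (5, 5)
```

In other words, passing the NumPy array itself as value would fail at serialization time, while the nested list survives the round trip with its values intact.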
+1
+1 good point
I have been searching for this solution for hours. Finally got one working :)
Is there any official implementation?
+1. This is an essential part of network creation and would be very helpful to include natively.
I've looked everywhere for a proper solution. I didn't like the idea of using a special constant-initializer class, since there should be no difference between a scalar constant and a tensor of any shape. Finally, what worked for me is the following:
@piiswrong Is there a plan to implement this simple feature?
@piiswrong A short example would be very helpful.
For symbol construction, there are var, zeros, and ones. Could we also get a constant that takes a list, a NumPy array, or an MXNet NDArray and turns it into a constant symbol? Working around it with variables is cumbersome, particularly because Constant initialization doesn't work for any variable whose name does not end in 'weight'.
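Whatever form an official constant op takes, it would need to accept scalars, lists, NumPy arrays, and NDArrays uniformly. A minimal sketch of that normalization step follows; constant_payload is a hypothetical helper, not an MXNet API. NDArrays are detected by duck typing on .asnumpy() so the sketch runs without MXNet installed, and scalars are promoted to shape (1,) on the assumption that legacy (non-NumPy-semantics) MXNet symbols treat an empty shape as "unknown" rather than "scalar".

```python
import json

import numpy as np


def constant_payload(value, dtype="float32"):
    """Hypothetical helper: normalize a scalar, nested list, NumPy array,
    or MXNet NDArray into (shape, nested_list).

    The nested list is plain Python data, so it survives json.dumps when
    stored as an initializer argument.
    """
    # Duck-type MXNet NDArrays (anything with .asnumpy()) so MXNet is not needed.
    if hasattr(value, "asnumpy"):
        value = value.asnumpy()
    # atleast_1d promotes a scalar to shape (1,); assumption: legacy MXNet
    # treats an empty shape as "unknown", not as a scalar.
    arr = np.atleast_1d(np.asarray(value, dtype=dtype))
    # tolist() converts NumPy scalars to Python floats, which JSON accepts.
    return arr.shape, arr.tolist()


if __name__ == "__main__":
    for v in (2.5, [[1, 2], [3, 4]], np.ones((5, 5))):
        shape, data = constant_payload(v)
        json.dumps(data)  # would raise TypeError if data were not serializable
        print(shape)  # prints (1,), then (2, 2), then (5, 5)
```

The returned shape could then be passed as shape= to mx.sym.Variable and the list as the value of a registered initializer such as MyConstant, so a scalar and a tensor go through exactly the same path.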