The network will not converge when backend is 'cupy' #106

Closed
fangwei123456 opened this issue Sep 11, 2021 · 3 comments
Labels: bug (Something isn't working)


fangwei123456 (Owner):

```python
LIFWrapper(neuron.MultiStepLIFNode(tau=10.0 / 7, surrogate_function=surrogate.Sigmoid(alpha=10.), backend='torch')),
```
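To make the layer parameters concrete, here is a plain-Python sketch of what one `MultiStepLIFNode` neuron computes over T time steps. It assumes SpikingJelly's defaults at the time (`v_threshold=1.0`, `v_reset=0.0`, hard reset) and the charge equation H[t] = V[t-1] + (X[t] - (V[t-1] - V_reset)) / tau; with `tau = 10/7` and `V_reset = 0` this reduces to H[t] = 0.3 V[t-1] + 0.7 X[t]. `lif_multistep` and `sigmoid_surrogate_grad` are hypothetical helpers for illustration, not library functions.

```python
import math


def lif_multistep(xs, tau=10.0 / 7, v_threshold=1.0, v_reset=0.0):
    """Spikes of a single LIF neuron driven by inputs xs (assumed hard reset)."""
    v = v_reset
    spikes = []
    for x in xs:
        h = v + (x - (v - v_reset)) / tau          # charge (assumed equation)
        s = 1.0 if h - v_threshold >= 0 else 0.0  # fire: Heaviside step in the forward pass
        v = v_reset if s else h                    # hard reset after a spike
        spikes.append(s)
    return spikes


def sigmoid_surrogate_grad(x, alpha=10.0):
    """Surrogate gradient used in the backward pass, with x = h - v_threshold:
    alpha * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))."""
    sg = 1.0 / (1.0 + math.exp(-alpha * x))
    return alpha * sg * (1.0 - sg)


print(lif_multistep([2.0, 0.0, 2.0]))  # [1.0, 0.0, 1.0]
print(sigmoid_surrogate_grad(0.0))     # 2.5
```

Both backends are expected to compute exactly these dynamics; only the kernel that evaluates them (pure PyTorch ops vs. a CuPy CUDA kernel) differs.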

fangwei123456 added the bug (Something isn't working) label and self-assigned this issue on Sep 11, 2021.
fangwei123456 (Owner, Author) commented Sep 15, 2021:

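```python
from tqdm import tqdm
from spikingjelly.clock_driven import neuron
from spikingjelly.clock_driven.functional import reset_net

# net, train_dataloader, criterion and optimizer are defined earlier in the script.
# Each batch goes through the network twice, once per backend, and the resulting
# parameter gradients are compared. This loop does not call optimizer.step().
for audios, labels in tqdm(train_dataloader, disable=True):
    audios = audios.cuda(non_blocking=True)
    labels = labels.cuda(non_blocking=True)
    optimizer.zero_grad()

    # Pass 1: CuPy backend
    for m in net.modules():
        if isinstance(m, neuron.MultiStepLIFNode):
            m.backend = 'cupy'

    out_spikes_counter_frequency = net(audios)
    loss = criterion(out_spikes_counter_frequency, labels)
    loss.backward()

    # Snapshot the gradients before they are cleared
    gd_cupy = [p.grad.clone() for p in net.parameters()]

    reset_net(net)  # reset neuron states before the second pass
    optimizer.zero_grad()

    # Pass 2: PyTorch backend on the same batch
    for m in net.modules():
        if isinstance(m, neuron.MultiStepLIFNode):
            m.backend = 'torch'

    out_spikes_counter_frequency = net(audios)
    loss = criterion(out_spikes_counter_frequency, labels)
    loss.backward()

    gd_torch = [p.grad.clone() for p in net.parameters()]

    reset_net(net)

    # Largest absolute gradient difference for each parameter tensor
    for i in range(len(gd_cupy)):
        print(i, gd_cupy[i].shape, (gd_cupy[i] - gd_torch[i]).abs().max().item())

    # Rate-based output decoding (uses the output of the 'torch' pass)
    correct_rate = (out_spikes_counter_frequency.argmax(dim=1) == labels).float().mean().item()
    print(net.train_times, 'acc=', correct_rate)

    net.train_times += 1
```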

```
Epoch 0
0 torch.Size([64, 1, 4, 3]) 2.525382125639908e-09
1 torch.Size([64, 64, 4, 3]) 1.087092904583642e-08
2 torch.Size([64, 64, 4, 3]) 0.0
3 torch.Size([12, 2560]) 0.0
4 torch.Size([12]) 0.0
0 acc= 0.09375
0 torch.Size([64, 1, 4, 3]) 3.250982372193789e-09
1 torch.Size([64, 64, 4, 3]) 7.123997392000092e-09
2 torch.Size([64, 64, 4, 3]) 0.0
3 torch.Size([12, 2560]) 0.0
4 torch.Size([12]) 0.0
1 acc= 0.078125
0 torch.Size([64, 1, 4, 3]) 2.8865985157722207e-09
1 torch.Size([64, 64, 4, 3]) 1.2915431391036236e-08
2 torch.Size([64, 64, 4, 3]) 0.0
3 torch.Size([12, 2560]) 0.0
4 torch.Size([12]) 0.0
2 acc= 0.0625
0 torch.Size([64, 1, 4, 3]) 3.8988936523765005e-09
1 torch.Size([64, 64, 4, 3]) 1.1206062744406609e-08
2 torch.Size([64, 64, 4, 3]) 0.0
3 torch.Size([12, 2560]) 0.0
4 torch.Size([12]) 0.0
3 acc= 0.046875
0 torch.Size([64, 1, 4, 3]) 2.381410402207962e-09
1 torch.Size([64, 64, 4, 3]) 8.244219529274233e-09
2 torch.Size([64, 64, 4, 3]) 0.0
3 torch.Size([12, 2560]) 0.0
4 torch.Size([12]) 0.0
4 acc= 0.0625
0 torch.Size([64, 1, 4, 3]) 2.9487301489439233e-09
1 torch.Size([64, 64, 4, 3]) 1.0766485480928623e-08
2 torch.Size([64, 64, 4, 3]) 0.0
3 torch.Size([12, 2560]) 0.0
4 torch.Size([12]) 0.0
5 acc= 0.078125
0 torch.Size([64, 1, 4, 3]) 2.3316282238283748e-09
1 torch.Size([64, 64, 4, 3]) 9.149445645562082e-09
2 torch.Size([64, 64, 4, 3]) 0.0
3 torch.Size([12, 2560]) 0.0
4 torch.Size([12]) 0.0
6 acc= 0.078125
0 torch.Size([64, 1, 4, 3]) 3.791615021953021e-09
1 torch.Size([64, 64, 4, 3]) 1.2025250128999687e-08
2 torch.Size([64, 64, 4, 3]) 0.0
3 torch.Size([12, 2560]) 0.0
4 torch.Size([12]) 0.0
7 acc= 0.078125
```
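In this run the largest per-tensor difference is about 1.3e-8 (on the `[64, 64, 4, 3]` tensor) and exactly 0.0 for the other three tensors, so the two backends agree up to rounding on these batches. A minimal sketch that turns the printed comparison into a pass/fail check; `compare_grads` is a hypothetical helper and the `rtol`/`atol` values are arbitrary choices, not taken from the issue.

```python
import torch


def compare_grads(grads_a, grads_b, rtol=1e-5, atol=1e-6):
    """For each pair of gradient tensors, return (index, max |a - b|, allclose?)."""
    report = []
    for i, (ga, gb) in enumerate(zip(grads_a, grads_b)):
        max_diff = (ga - gb).abs().max().item()
        report.append((i, max_diff, torch.allclose(ga, gb, rtol=rtol, atol=atol)))
    return report


# Usage with made-up tensors (not the gradients from the log above)
ga = [torch.ones(3), torch.zeros(2)]
gb = [torch.ones(3) + 1e-6, torch.tensor([0.0, 1.0])]
for i, max_diff, ok in compare_grads(ga, gb):
    print(i, max_diff, ok)
```

In the loop above, `compare_grads(gd_cupy, gd_torch)` would replace the manual print loop.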

fangwei123456 (Owner, Author) commented:

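```python
import torch

device = 'cuda:0'
# x is a non-contiguous (transposed) view; zeros_like keeps its strides,
# so y is non-contiguous as well
x = torch.rand([2, 3, 4], device=device).transpose(0, 2)
y = torch.zeros_like(x)
args_list = [x, y]
ret_list = []
for item in args_list:
    if isinstance(item, torch.Tensor):
        # .contiguous() makes a temporary copy, but only its pointer is stored;
        # the copy is released as soon as `item` is rebound on the next iteration
        item = item.contiguous()
        ret_list.append(item.data_ptr())
print(ret_list)
```

Output (both pointers are identical):

```
[47284487168, 47284487168]
```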
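The pointers collide because each contiguous copy is freed before the next one is allocated, and PyTorch's CUDA caching allocator can hand the same memory block to the next copy. A kernel launched with such pointers would read and write `x` and `y` through the same memory. A minimal sketch of one way to avoid this, keeping the contiguous tensors referenced for as long as their pointers are in use; `contiguous_args` is a hypothetical helper, and this is not necessarily what the commits referencing this issue changed.

```python
import torch


def contiguous_args(*args):
    """Return contiguous versions of all tensor arguments plus their data pointers.

    `kept` holds references to the contiguous copies, so the pointers stay
    valid as long as the caller keeps `kept` alive (e.g. until the kernel has run).
    """
    kept = [a.contiguous() if isinstance(a, torch.Tensor) else a for a in args]
    ptrs = [a.data_ptr() for a in kept if isinstance(a, torch.Tensor)]
    return kept, ptrs


# Same shapes as the repro above, on CPU so no GPU is required
x = torch.rand([2, 3, 4]).transpose(0, 2)
y = torch.zeros_like(x)
kept, ptrs = contiguous_args(x, y)
print(ptrs[0] != ptrs[1])  # True: both copies are alive at the same time
```

Because both copies exist simultaneously, the allocator cannot give them the same address.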


fangwei123456 added several commits that referenced this issue on Sep 16, 2021, including:

Revert "a temporary solution for #106"

This reverts commit c824ed0