diff --git a/examples/scsn_polarity_lstm_training.py b/examples/scsn_polarity_lstm_training.py index 485c41d..d943a10 100644 --- a/examples/scsn_polarity_lstm_training.py +++ b/examples/scsn_polarity_lstm_training.py @@ -33,7 +33,7 @@ # Prepare dataset dsets.set_memory_limit(10 * 1024 ** 3) # first number is GB # dset = dsets.Wenchuan(path='/home/qszhai/temp_project/deep_learning_course_project/cpic', download=False,sample_transform=waveform_transform) - dset = dsets.Wenchuan(path='/home/qszhai/temp_project/deep_learning_course_project/first_motion_polarity/scsn_data/train_npy', download=False, sample_transform=waveform_transform) + dset = dsets.SCSN_polarity(path='/home/qszhai/temp_project/deep_learning_course_project/first_motion_polarity/scsn_data/train_npy', download=False, sample_transform=waveform_transform) # Split datasets into training and validation train_length = int(len(dset) * 0.8) @@ -45,7 +45,8 @@ val_loader = DataLoader(val_set, batch_size=10000, shuffle=False, num_workers=4) # Prepare trainer - #trainer = Trainer(cpic(), CrossEntropyLoss(), lr=0.1) + # trainer = Trainer(cpic(), CrossEntropyLoss(), lr=0.1) + # note: please use only 1 gpu to run LSTM, https://github.com/pytorch/pytorch/issues/21108 model_conf = {"hidden_size": 64} plt = polarity(**model_conf) trainer = Trainer(plt, CrossEntropyLoss(), lr=0.001)