merge from Qiushi's fork to yewsg/yews: add polarity and focal_mechanism models #18
Merged
Changes from 250 commits (275 commits total)
03819dd
Replace MIT license by Apache 2.0
lijunzh 50694af
include license and readme for pip package
lijunzh 8c83e13
add tests for yews.transform.functional
lijunzh f142a87
try travis from torchvision
lijunzh 076dec9
Add travis.ci badge
lijunzh 01f8d19
Fix bug in travis.yml (#1)
lijunzh cc5dc0b
update logo color
lijunzh b43159b
try torchvision’s sphinx setup
lijunzh ff88395
Update logo
lijunzh 11e2cee
Add appveyor.yml for Windows CI (#3)
lijunzh 2e4703c
add appveyor badge
lijunzh 7538b25
add bages for anaconda cloud and pypi
lijunzh b37c503
move badge below the title
lijunzh f74ca1d
remove line between logo and title
lijunzh 8b300c8
Add more test to transforms (#2)
lijunzh 210fa17
Add instructin to install pytorch first
lijunzh c95ad7f
Update conda command
lijunzh 50e0868
Uploading PyTorch builds to lijunzhu channel
lijunzh ae94d69
not import yews in docs
lijunzh 7516965
Create initial docs (#4)
lijunzh 6c86d34
use www subdomain for docs
lijunzh af97b15
Squashed commit of the following:
lijunzh f10364b
Squashed commit of the following:
lijunzh b2051a9
Add transform to convert label to int
lijunzh 35a559d
Improve conda installation
lijunzh 1cbd347
Squashed commit of the following:
lijunzh bf621c6
automate release
lijunzh 4f5bbd3
improve docs.
lijunzh c6edbc5
move docs to a separate repo.
lijunzh 720b2e8
bump version to 0.0.3
lijunzh accc005
update logo url
lijunzh cee62a2
Move metadata from setup.py to setup.cfg
lijunzh 1369b62
improve automation.
lijunzh a13bfb6
yews get version from pkg installation.
lijunzh 0c69039
Use scipy as an extra feature
lijunzh ce1587e
fix a bug to version in yews.__init__
lijunzh 8a7fc86
Remove not-runnable code from coverage report.
lijunzh 27fbb3c
add pre-commit-config
lijunzh 255f523
add changelog.rst
lijunzh 4ae530c
change .coveragerc
lijunzh 51d043c
add staticmethod valid() to check path.
lijunzh cb15b95
refactorizing BaseDataset
lijunzh ae89583
add smoke test via @pytest.mark.smoke
lijunzh 8ffab7d
modify datasets error msgs.
lijunzh 78b9426
add `yews.datasets.utils` with tests covered 100%
lijunzh 74d9c2b
check end of file
lijunzh 4680031
remove redundant __about__.py
lijunzh 3b50da0
update changelog
lijunzh 9f57f5d
add `datasets.wenchuan`
lijunzh 3f41201
fix code issues.
lijunzh edadacd
add test to datasets.wenchuan
lijunzh e3dc9e8
update wenchuan example according to new api
lijunzh 39bc9aa
clean temp files due to broken tests.
lijunzh 9c600ff
try svg for logo image
lijunzh dca3787
change back to gif
lijunzh 5a71e0c
optimize logo and readme layout for mobile
lijunzh 2a47a41
add memory_limit to control loading of .npy file.
lijunzh d7864a6
bump version to 0.0.4
lijunzh 6f20664
fix a typo.
lijunzh e96eb8c
add scipy to host environment
lijunzh 2e5c075
update docs url
lijunzh 943c22e
fix doctring typo
lijunzh be64d7d
sync meta.yaml and setup.cfg for install and test.
lijunzh 63fd322
avoid downloading large file during test
lijunzh a88cc8e
explicitly add allow_pickle for older numpy.
lijunzh a7bafc2
update CHANGELOG.rst
lijunzh 12a0c36
Update installation notes in README.rst
lijunzh 1c9fbf3
Implement original cpic model in the paper.
lijunzh 3a6ad94
create mariana dataset and tools to support it.
lijunzh b1dd25d
bump version to 0.0.5
lijunzh 7f09cd4
wenchuan dataset released to public
lijunzh 66ba8fa
add numpy verion requirement for pathlib usage
lijunzh c89c26c
fix test_datasets.py bug
lijunzh d9df016
add packaged SCSN dataset
lijunzh 73b5849
avoid large downlad on traivs-ci
lijunzh 68f7051
add detection example for mariana dataset
lijunzh 8a076bd
attempt to add OK dataset in the same way as Mariana
lijunzh c8660d6
ignore all model files
lijunzh 330c3fd
Merge branch 'master' into create_OK_dataset
lijunzh 94487f2
Merge branch 'master' into detection
lijunzh 443c094
Merge branch 'master' into new_models_module
lijunzh 0c2983b
Squashed commit of the following:
lijunzh d9d7a48
fix function name
lijunzh 7e4a8c0
fix bias bug
lijunzh 3a57f60
save results after training
lijunzh a6f5058
save model class name
lijunzh efb93c3
Allow save and load checkpoint.
lijunzh 50d4c66
ignore tags file from ctags.
lijunzh 600a18f
add resume function during training
lijunzh 79124f7
add picking
lijunzh f4a774d
remove dimension check for numpy waveform
lijunzh 389e514
rename deploy to cpic
lijunzh 604a798
Merge branch 'detection'
lijunzh 30c8d49
Merge branch 'improve_train_module'
lijunzh b142ddf
add hubconf for torch.hub module
lijunzh e6e1a82
migrate to torch.hub load_url
lijunzh 6933aca
update cpic model
lijunzh 38a2290
update wenchuan example for new models.cpic module
lijunzh f2da734
fix model_device bug
lijunzh d124ff6
test training results save
lijunzh 6ae3138
fix wenchuan example path bug
lijunzh 0ffea0a
fix wenchuan result path name
lijunzh 24aa86e
get filename as staticmethod
lijunzh 446d336
add ok_transfer example
lijunzh 8dae14f
fix bug in sac dataset
lijunzh 20c5900
fix typo in ok transfer
lijunzh fb1032c
fix typo in ok transfer
lijunzh d9d2c31
add loader to ok dataset
lijunzh e94fa4a
fix glob bug for ok transfer
lijunzh 0fa4cec
change path str for obspy read
lijunzh 4f2c566
convert path after creating label
lijunzh 232713c
try to fix appveyor
lijunzh c966ff1
install obspy for appveyor
lijunzh f3a885b
do not download large file during testing
lijunzh 2668ee5
use tar instead of tar.bz2 for packaged datasets
lijunzh aa5e3c4
use model intead of model_gen for trainer
lijunzh 9473d89
show accuracy at the end of each epoch
lijunzh 606bdb9
add cpic model pretrianed on wenchuan dataset
lijunzh a52816a
Merge branch 'master' of https://github.com/lijunzh/yews
lijunzh 26d6716
rename example files
lijunzh da0b6a4
save current and best checkpoint during training
lijunzh c1ddf50
training from initial model
lijunzh 1faf882
add scipy as a mandatory dependency
lijunzh 531049a
new deployment example for Mw 7.5 earthquake in southern pacific
lijunzh 6705a5f
start a doucment for rbp installation steps.
lijunzh 33c8296
add miniconda and build pytorch from source
lijunzh e890b60
update environmental variable
lijunzh 6373248
disable qnnpack
lijunzh eb4cf4b
fix a bug in applying transform during inference
lijunzh fec91d0
bump version before release
lijunzh b386d84
correct typo
lijunzh 03951c7
use tensor stack instead array stack
lijunzh ae57971
update example for sp deployment example
lijunzh 85a6667
add raspberry pi files
lijunzh 51a8486
marian deployment example
lijunzh e728649
Merge branch 'create_OK_dataset'
lijunzh 3195c37
update cpic with a simplified model
lijunzh 87a6010
update rbp example with plots
lijunzh c07dc3b
Merge branch 'master' of https://github.com/lijunzh/yews
lijunzh 9e3c07e
add batch_size for deployment
lijunzh c5c9eea
update rbp example
lijunzh 0f1ca9c
update rbp example with simplified model
lijunzh 8f2b4dd
update rbp example with simplified model
lijunzh d0e55ce
update rbp example
lijunzh e449efb
Correct a typo
zjzzqs a941c70
Merge pull request #5 from lijunzh/zjzzqs-patch
zjzzqs 556c655
Making some utils available outside classes
lijunzh 1142a3f
add example for preparing dataset from file names
lijunzh 522edd6
move old example to experimental
lijunzh 27bc438
reorganize dataset pacakge
lijunzh bf0e74f
correct typo
lijunzh 1c8a2a5
skipped broken waveforms.
lijunzh df0fb46
fix bug
lijunzh 0414081
fix test for new dataset package
lijunzh 2c25823
remove unused varialbe.
lijunzh 88e41a0
prepare waveform by groups
lijunzh 7dc27f8
avoid skipping the entire group for one invalid phase.
lijunzh c1fe60c
merge groups of npys into one
lijunzh d040eb7
add notes for merging large npy arrays.
lijunzh dd6b5bc
make some object available on the top level of yews package.
lijunzh c31b0bd
docs fiex
lijunzh 9d9e09b
fix bugs in example
lijunzh f74ba10
training example
lijunzh 323d4f6
increase batch size for faster training and validation
lijunzh 9e326aa
increase memory limit to load the entire dataset in memory
lijunzh d9e6bbb
run logner training
lijunzh 0a0c8d8
Update package structure
lijunzh 6e48a1b
Update conda install pytorch command for testing.
lijunzh 3829323
Local test skip downloading large files.
lijunzh f8c43a8
Update mmap store code for npy.
lijunzh d080285
Raise exception when file not exists.
lijunzh 6be71ee
Put a soft link to data inside example directory.
lijunzh f605d21
Temporarily disable tqdm in exporting data.
lijunzh da89a13
Update CHANGELOG
lijunzh c48efed
Fix meta.yaml depdendency
lijunzh 33c09bf
Improve anaconda build process.
lijunzh cdd087f
Use softlink to data path.
lijunzh 849ce4d
Update URLs to package datasets.
lijunzh c501e55
change Wenchuan data url from gt to dropbox
zjzzqs 84eea47
add packaged_datasets SCSN_polarity
zjzzqs f68d1a2
add packaged_datasets SCSN_polarity
zjzzqs 5ac23cf
change MEMORY_LIMIT from 2g to 10g
zjzzqs 5606ea7
add import polarity.py
zjzzqs 43d51f8
add polarity.py
zjzzqs a7ed9eb
add import numpy to polarity.py
zjzzqs 41eb856
added comment for pull request test
ChujieChen 4f8076a
Merge pull request #1 from ChujieChen/master
zjzzqs e59a67e
delete commit examples in polarity.py
zjzzqs 1cad706
add wenchuan cpic example
zjzzqs 7977424
add scsn polarity training example
zjzzqs 6693acc
Merge pull request #1 from zjzzqs/master
ChujieChen 65aa625
delete the note of 2d, will see it in the focal_mechanism.py
zjzzqs cfe3c9f
primitive LSTM model added in polarity.py
ChujieChen 859f289
add Taiwan_focal_mechanism dataset
zjzzqs 765e0b2
add Taiwan_focal_mechanism dataset
zjzzqs 6f8e7cb
add focal_mechanism model
zjzzqs b2b5eb1
add focal_mechanism model
zjzzqs dca9078
rename scsn.training.py to scsn_polarity_cnn.training.py
zjzzqs de770a5
rename scsn.training.py to scsn_polarity_cnn.training.py
zjzzqs edea810
add taiwan_focal_mechanism.training.py to example
zjzzqs 7179212
change the batch_size and learning rate of this example
zjzzqs c04cba6
add VGG style fm_v2 into models/focal_mechanism.py
zjzzqs f879d8b
modified VGG style fm_v2, use dropout(0.1) after each maxpool
zjzzqs 8147ae7
delete unknow label, add vgg style model, remove the last 2 cnn layers
zjzzqs 6bb404a
delete unknow label, add vgg style model, remove the last 2 cnn layers
zjzzqs 3e9922e
add vgg style model for grad-cam, remove the last 2 cnn layers, stop …
zjzzqs 57ea196
add a backup line of using AdamW instead of Adam
zjzzqs 4952abf
working LSTM (bidirectional untested)
ChujieChen 0ccc840
Merge pull request #2 from ChujieChen/develop
ChujieChen 0af73fb
finished LSTM for polarity
ChujieChen aece550
Merge pull request #3 from ChujieChen/develop
ChujieChen dba1469
Merge branch 'master' into master
zjzzqs 1c829be
Merge pull request #2 from ChujieChen/master
zjzzqs 1171936
Merge pull request #4 from zjzzqs/master
ChujieChen 28d0c1f
added example for polarity LSTM
ChujieChen f970c46
fix the indent
zjzzqs debcef5
Merge pull request #3 from ChujieChen/master
zjzzqs c41b26f
change the wenchuan example file name
zjzzqs d394f59
fix indent of polarity.py again
zjzzqs 80e49d4
add a note: please use only 1 gpu to run LSTM, https://github.com/pyt…
zjzzqs 3949d1b
add a note: please use only 1 gpu to run LSTM, https://github.com/pyt…
zjzzqs 1df3fd9
fix the dsets name in the example
zjzzqs f8dc827
add WeightedRandomSampler to balance the numbers of different labels …
zjzzqs 7cea369
add Taiwan20092010 of cpic into packaged_datasets.py and __init__.py
zjzzqs a0a98d9
add Taiwan20092010 of cpic into packaged_datasets.py and __init__.py
zjzzqs fc22fbc
add example for cpic: Taiwan20092010
zjzzqs 4cb6c4d
add vgg style model cpicv3, stop at 4 for grad-cam
zjzzqs 1e464a8
add vgg style model cpicv3, stop at 4 for grad-cam
zjzzqs 0b28b77
vgg style model FmV2 stop at 8*8
zjzzqs a8e95ee
vgg style model FmV2 stop at 8*8
zjzzqs 6f4014b
update cpic.py and wenchuan_cpic.training.py based on the test of gra…
zjzzqs d2b0be3
forget why, so just add a comment #wav = wav.astype(float) into src/y…
zjzzqs 3935031
add RemoveMean RemoveTrend Taper BandpassFilter into src/yews/transfo…
zjzzqs 773bae9
add polarity_cnn_lstm from Zijian Li
zjzzqs 2c717e7
input 600->300
zjzzqs 9120bd9
rm data in example
zjzzqs 4b2ae51
fix super
zjzzqs 1047605
update example
zjzzqs b2b40f1
update cnn_lstm
zjzzqs 898d3f5
add a line for LSTM which can only use one gpu
zjzzqs e21445f
need to be updated, how to read the pretrained model
zjzzqs 35d670d
Resolve merge conflict
zjzzqs 20a4124
Resolve merge conflict
zjzzqs 2f27820
delete train.py.bak
zjzzqs 6d84a52
fix bug <<<<<<< ======= >>>>>>>
zjzzqs 2297e5f
recover some image and target
zjzzqs 8805b78
nothing important
zjzzqs e019e7c
xxxx to null link
zjzzqs 9ed1303
delete the commit of using 1 gpu, in the future, use: device = torch.…
zjzzqs 2dadda1
remove RemoveMean, change Taper and BandpassFilter
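One of the commits above adds a `WeightedRandomSampler` "to balance the numbers of different labels". The PR does not show the weight computation, but the usual recipe is inverse-frequency per-sample weights, which one would then pass to `torch.utils.data.WeightedRandomSampler`. A minimal sketch of that assumed weighting (plain Python, no torch needed):

```python
from collections import Counter

def inverse_frequency_weights(labels):
    """Per-sample weights that equalize class draw probability.

    These are the weights one would hand to
    torch.utils.data.WeightedRandomSampler(weights, num_samples=len(labels)).
    """
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]

# A class appearing 3x gets weight 1/3 per sample; a 1x class gets 1.0,
# so each class is drawn with equal total probability.
weights = inverse_frequency_weights([0, 0, 0, 1])
```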
zjzzqs
@@ -16,15 +16,10 @@ applying deep learning techniques on seismic waveform data.

Badge directives removed in this diff:

.. image:: https://travis-ci.org/yewsg/yews.svg?branch=master
   :target: https://travis-ci.org/yewsg/yews

.. image:: https://ci.appveyor.com/api/projects/status/32r7s2skrgm9ubva?svg=true
   :target: https://ci.appveyor.com/project/lijunzh/yews

.. image:: https://codecov.io/gh/yewsg/yews/branch/master/graph/badge.svg
   :target: https://codecov.io/gh/yewsg/yews

.. image:: https://anaconda.org/lijunzhu/yews/badges/version.svg
   :target: https://anaconda.org/lijunzhu/yews

Review comment: Why do we want to remove the badges here?
Reply: Resolved. Please check it again. Thanks.
This file was deleted.
@@ -0,0 +1,104 @@
import datetime

import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.data import random_split

import yews.datasets as dsets
import yews.transforms as transforms
from yews.train import Trainer

from yews.models import polarity_cnn
model = polarity_cnn


if __name__ == '__main__':

    print("Now: start : " + str(datetime.datetime.now()))

    # Preprocessing
    waveform_transform = transforms.Compose([
        # transforms.ZeroMean(),
        # transforms.SoftClip(1e-4),
        transforms.ToTensor(),
    ])

    # Prepare dataset
    dsets.set_memory_limit(10 * 1024 ** 3)  # first number is GB
    # dset = dsets.Wenchuan(path='/home/qszhai/temp_project/deep_learning_course_project/cpic', download=False, sample_transform=waveform_transform)
    dset = dsets.SCSN(path='/data6/scsn/polarity/train_npy', download=False, sample_transform=waveform_transform)

    # Split dataset into training and validation sets (80/20)
    train_length = int(len(dset) * 0.8)
    val_length = len(dset) - train_length
    train_set, val_set = random_split(dset, [train_length, val_length])

    # Prepare dataloaders
    train_loader = DataLoader(train_set, batch_size=5000, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_set, batch_size=10000, shuffle=False, num_workers=4)

    # Prepare trainer
    trainer = Trainer(model(), CrossEntropyLoss(), lr=0.01)

    # Train model over training dataset
    trainer.train(train_loader, val_loader, epochs=100, print_freq=100)
    # resume='checkpoint_best.pth.tar')

    # Save training results to disk
    trainer.results(path='scsn_polarity_results.pth.tar')

    # Validate saved model (`net` avoids rebinding the `model` factory above)
    results = torch.load('scsn_polarity_results.pth.tar')
    net = model()
    net.load_state_dict(results['model'])
    trainer = Trainer(net, CrossEntropyLoss(), lr=0.1)
    trainer.validate(val_loader, print_freq=100)

    print("Now: end : " + str(datetime.datetime.now()))

    # Plot training/validation accuracy and loss curves
    import matplotlib.pyplot as plt
    import numpy as np

    myfontsize1 = 14
    myfontsize2 = 18
    myfontsize3 = 24

    results = torch.load('scsn_polarity_results.pth.tar')

    fig, axes = plt.subplots(2, 1, num=0, figsize=(6, 4), sharex=True)
    axes[0].plot(results['val_acc'], label='Validation')
    axes[0].plot(results['train_acc'], label='Training')

    axes[0].set_xscale('log')
    axes[0].set_xlim([1, 100])
    axes[0].xaxis.set_tick_params(labelsize=myfontsize1)

    axes[0].set_ylabel("Accuracies (%)", fontsize=myfontsize2)
    axes[0].set_ylim([0, 100])
    axes[0].set_yticks(np.arange(0, 101, 10))
    axes[0].yaxis.set_tick_params(labelsize=myfontsize1)

    axes[0].grid(True, 'both')
    axes[0].legend(loc=4)

    axes[1].plot(results['val_loss'], label='Validation')
    axes[1].plot(results['train_loss'], label='Training')

    axes[1].set_xlabel("Epochs", fontsize=myfontsize2)
    axes[1].set_xscale('log')
    axes[1].set_xlim([1, 100])
    axes[1].xaxis.set_tick_params(labelsize=myfontsize1)

    axes[1].set_ylabel("Losses", fontsize=myfontsize2)
    axes[1].set_ylim([0.0, 1.0])
    axes[1].set_yticks(np.arange(0.0, 1.01, 0.2))
    axes[1].yaxis.set_tick_params(labelsize=myfontsize1)

    axes[1].grid(True, 'both')
    axes[1].legend(loc=1)

    fig.tight_layout()
    plt.savefig('Accuracies_train_val.pdf')
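The script above splits the dataset 80/20 with `random_split`. The length arithmetic it relies on can be checked standalone (plain Python, no yews or torch needed):

```python
def split_lengths(n_samples, train_frac=0.8):
    # Same arithmetic as the script: floor the training share, then give
    # the remainder to validation so the lengths always sum to n_samples.
    train_length = int(n_samples * train_frac)
    val_length = n_samples - train_length
    return train_length, val_length

print(split_lengths(10))  # → (8, 2)
```

Because validation takes the remainder, the two lengths always sum to the dataset size, which is exactly what `random_split` requires of its length list.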
@@ -0,0 +1,116 @@
import datetime

import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.data import random_split

import yews.datasets as dsets
import yews.transforms as transforms
from yews.train import Trainer

# from yews.models import cpic
# from yews.models import cpic_v1
# from yews.models import cpic_v2
# cpic = cpic_v1

from yews.models import polarity_v1
from yews.models import polarity_v2
from yews.models import polarity_lstm
polarity = polarity_lstm


if __name__ == '__main__':

    print("Now: start : " + str(datetime.datetime.now()))

    # Preprocessing
    waveform_transform = transforms.Compose([
        transforms.ZeroMean(),
        # transforms.SoftClip(1e-4),
        transforms.ToTensor(),
    ])

    # Prepare dataset
    dsets.set_memory_limit(10 * 1024 ** 3)  # first number is GB
    # dset = dsets.Wenchuan(path='/home/qszhai/temp_project/deep_learning_course_project/cpic', download=False, sample_transform=waveform_transform)
    dset = dsets.SCSN_polarity(path='/home/qszhai/temp_project/deep_learning_course_project/first_motion_polarity/scsn_data/train_npy', download=False, sample_transform=waveform_transform)

    # Split dataset into training and validation sets (80/20)
    train_length = int(len(dset) * 0.8)
    val_length = len(dset) - train_length
    train_set, val_set = random_split(dset, [train_length, val_length])

    # Prepare dataloaders
    train_loader = DataLoader(train_set, batch_size=5000, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_set, batch_size=10000, shuffle=False, num_workers=4)

    # Prepare trainer
    # trainer = Trainer(cpic(), CrossEntropyLoss(), lr=0.1)
    # note: please use only 1 GPU to run the LSTM, see https://github.com/pytorch/pytorch/issues/21108
    model_conf = {"hidden_size": 64}
    net = polarity(**model_conf)  # renamed from `plt` to avoid clashing with matplotlib.pyplot below
    trainer = Trainer(net, CrossEntropyLoss(), lr=0.001)

    # Train model over training dataset
    trainer.train(train_loader, val_loader, epochs=50, print_freq=100)
    # resume='checkpoint_best.pth.tar')

    # Save training results to disk
    trainer.results(path='scsn_polarity_results.pth.tar')

    # Validate saved model
    results = torch.load('scsn_polarity_results.pth.tar')
    # model = cpic()
    model = net
    model.load_state_dict(results['model'])
    trainer = Trainer(model, CrossEntropyLoss(), lr=0.001)
    trainer.validate(val_loader, print_freq=100)

    print("Now: end : " + str(datetime.datetime.now()))

    # Plot training/validation accuracy and loss curves
    import matplotlib.pyplot as plt
    import numpy as np

    myfontsize1 = 14
    myfontsize2 = 18
    myfontsize3 = 24

    results = torch.load('scsn_polarity_results.pth.tar')

    fig, axes = plt.subplots(2, 1, num=0, figsize=(6, 4), sharex=True)
    axes[0].plot(results['val_acc'], label='Validation')
    axes[0].plot(results['train_acc'], label='Training')

    axes[0].set_xscale('log')
    axes[0].set_xlim([1, 100])
    axes[0].xaxis.set_tick_params(labelsize=myfontsize1)

    axes[0].set_ylabel("Accuracies (%)", fontsize=myfontsize2)
    axes[0].set_ylim([0, 100])
    axes[0].set_yticks(np.arange(0, 101, 10))
    axes[0].yaxis.set_tick_params(labelsize=myfontsize1)

    axes[0].grid(True, 'both')
    axes[0].legend(loc=4)

    axes[1].plot(results['val_loss'], label='Validation')
    axes[1].plot(results['train_loss'], label='Training')

    axes[1].set_xlabel("Epochs", fontsize=myfontsize2)
    axes[1].set_xscale('log')
    axes[1].set_xlim([1, 100])
    axes[1].xaxis.set_tick_params(labelsize=myfontsize1)

    axes[1].set_ylabel("Losses", fontsize=myfontsize2)
    axes[1].set_ylim([0.0, 1.0])
    axes[1].set_yticks(np.arange(0.0, 1.01, 0.2))
    axes[1].yaxis.set_tick_params(labelsize=myfontsize1)

    axes[1].grid(True, 'both')
    axes[1].legend(loc=1)

    fig.tight_layout()
    plt.savefig('Accuracies_train_val.pdf')
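The LSTM example above carries the note that it should run on a single GPU (pytorch issue 21108), and a later commit suggests switching to an explicit `device = torch.…` in the future. The PR does not show how single-GPU use is enforced; one common sketch (an assumption, not the authors' method) masks devices via the `CUDA_VISIBLE_DEVICES` environment variable before CUDA is initialized:

```python
import os

# Hypothetical: expose only GPU index 0 to this process. This must be set
# before the first CUDA call (in practice, before importing torch in most
# scripts), otherwise it has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Inside torch, 'cuda:0' then refers to the single visible GPU, e.g.:
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
```

Setting the variable in the shell (`CUDA_VISIBLE_DEVICES=0 python train.py`) achieves the same thing without touching the script.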
@@ -0,0 +1,71 @@
import datetime

import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.data import random_split

import yews.datasets as dsets
import yews.transforms as transforms
from yews.train import Trainer

# from yews.models import cpic
from yews.models import cpic_v1
from yews.models import cpic_v2
from yews.models import cpic_v3
cpic = cpic_v3


if __name__ == '__main__':

    print("Now: start : " + str(datetime.datetime.now()))

    # Preprocessing
    waveform_transform = transforms.Compose([
        transforms.ZeroMean(),
        # transforms.RemoveTrend(),
        # transforms.RemoveMean(),
        # transforms.Taper(),
        # transforms.BandpassFilter(),
        # transforms.SoftClip(2e-3),
        # 1e-2 = 1/100: 100 = 1% of max
        # 2e-3 = 4/2048: histogram max = 2048
        # Quick histogram check of sample amplitudes:
        # import numpy as np; import matplotlib.pyplot as plt
        # samples = np.load("samples.npy", mmap_mode='r')
        # targets = np.load("targets.npy"); targets.shape
        # plt.hist(samples[0:100000, 0, :].flatten(), bins=100); plt.ylim([0.1, 1.5e8]); plt.show()
        transforms.ToTensor(),
    ])

    # Prepare dataset
    dsets.set_memory_limit(10 * 1024 ** 3)  # first number is GB
    dset = dsets.Taiwan20092010(path='/home/qszhai/temp_project/deep_learning_course_project/cpic/Taiwan20092010', download=False, sample_transform=waveform_transform)

    # Split dataset into training and validation sets (80/20)
    train_length = int(len(dset) * 0.8)
    val_length = len(dset) - train_length
    train_set, val_set = random_split(dset, [train_length, val_length])

    # Prepare dataloaders
    train_loader = DataLoader(train_set, batch_size=2000, shuffle=True, num_workers=4)
    # train_set: batch_size = targets.shape / 500
    val_loader = DataLoader(val_set, batch_size=4000, shuffle=False, num_workers=4)
    # batch_size: larger is better if the GPU memory is enough.
    # num_workers = number of CPU cores, but limited by disk speed, so 8 is good.

    # Prepare trainer
    trainer = Trainer(cpic(), CrossEntropyLoss(), lr=0.1)

    # Train model over training dataset
    trainer.train(train_loader, val_loader, epochs=300, print_freq=100)
    # resume='checkpoint_best.pth.tar')

    # Save training results to disk
    trainer.results(path='Taiwan20092010_results.pth.tar')

    # Validate saved model
    results = torch.load('Taiwan20092010_results.pth.tar')
    model = cpic()
    model.load_state_dict(results['model'])
    trainer = Trainer(model, CrossEntropyLoss(), lr=0.1)
    trainer.validate(val_loader, print_freq=100)

    print("Now: end : " + str(datetime.datetime.now()))
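Each of the three scripts calls `dsets.set_memory_limit(10 * 1024 ** 3)` with the comment "first number is GB"; the expression is a plain GiB-to-bytes conversion, which can be verified without yews installed:

```python
def gib_to_bytes(gib):
    # 1 GiB = 1024**3 bytes; same expression as the scripts' 10 * 1024 ** 3
    return gib * 1024 ** 3

limit = gib_to_bytes(10)  # → 10737418240 bytes
```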
Review comment: This is text generated by git when resolving conflicts. It needs to be removed.
Reply: Resolved. Please check it again. Thanks.