Skip to content

Commit

Permalink
windows doesn't like opening tempfile twice
Browse files Browse the repository at this point in the history
  • Loading branch information
andrrizzi committed Apr 17, 2023
1 parent 33fa009 commit 54ef4b9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
12 changes: 9 additions & 3 deletions mlcvs/tests/test_cvs_multitask_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# GLOBAL IMPORTS
# =============================================================================

import os
import tempfile

import pytest
Expand Down Expand Up @@ -283,8 +284,13 @@ def test_multitask_training(main_cv_name, weights, auxiliary_loss_names, loss_co
assert x_hat.shape == (x.shape[0], N_CVS)

# Do round-trip through torchscript.
with tempfile.NamedTemporaryFile('wb', suffix='.ptc') as f:
multi_cv.to_torchscript(file_path=f.name, method='trace')
multi_cv_loaded = torch.jit.load(f.name)
# This try-finally clause is a workaround for windows not allowing opening temp files twice.
try:
tmp_file = tempfile.NamedTemporaryFile('wb', suffix='.ptc', delete=False)
tmp_file.close()
multi_cv.to_torchscript(file_path=tmp_file.name, method='trace')
multi_cv_loaded = torch.jit.load(tmp_file.name)
finally:
os.unlink(tmp_file.name)
x_hat2 = multi_cv_loaded(x)
assert torch.allclose(x_hat, x_hat2)
12 changes: 9 additions & 3 deletions mlcvs/tests/test_cvs_unsupervised_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# GLOBAL IMPORTS
# =============================================================================

import os
import tempfile

import pytest
Expand Down Expand Up @@ -63,8 +64,13 @@ def test_vae_cv_training(weights):
assert x_hat.shape == (batch_size, n_cvs)

# Test export to torchscript.
with tempfile.NamedTemporaryFile('wb', suffix='.ptc') as f:
model.to_torchscript(file_path=f.name, method='trace')
model_loaded = torch.jit.load(f.name)
# This try-finally clause is a workaround for windows not allowing opening temp files twice.
try:
tmp_file = tempfile.NamedTemporaryFile('wb', suffix='.ptc', delete=False)
tmp_file.close()
model.to_torchscript(file_path=tmp_file.name, method='trace')
model_loaded = torch.jit.load(tmp_file.name)
finally:
os.unlink(tmp_file.name)
x_hat2 = model_loaded(x)
assert torch.allclose(x_hat, x_hat2)

0 comments on commit 54ef4b9

Please sign in to comment.