pt: refactor data stat #3285

Merged
merged 42 commits into from
Feb 18, 2024
f352a67
checkpoint
njzjz Feb 16, 2024
b2b48b2
Merge branch 'devel' into env-mat-stat
njzjz Feb 16, 2024
1c3e2bb
record sum
njzjz Feb 16, 2024
d6bf4ab
compute_std
njzjz Feb 16, 2024
4f029b0
protection
njzjz Feb 16, 2024
afd6d6a
save stat
njzjz Feb 16, 2024
2f67aac
get std
njzjz Feb 16, 2024
89d413c
compute avg
njzjz Feb 16, 2024
8b6e24e
sea looks good
njzjz Feb 16, 2024
4736ae5
rm init_desc_stat
njzjz Feb 16, 2024
6cb1275
rm get_stat_name
njzjz Feb 16, 2024
b99d330
rewrite compute_input_stats
njzjz Feb 16, 2024
da1e72d
hybrid
njzjz Feb 16, 2024
959583a
fix hash
njzjz Feb 16, 2024
524366d
se atten
njzjz Feb 16, 2024
87f0d85
to make it work
njzjz Feb 16, 2024
2e815ea
compute_or_load_stat
njzjz Feb 16, 2024
ed34d59
init
njzjz Feb 16, 2024
d04d16c
fix shape
njzjz Feb 16, 2024
2580f8e
fix concat
njzjz Feb 16, 2024
2771b1e
make it work
njzjz Feb 16, 2024
3eb5577
rm save_stats and load_stats
njzjz Feb 16, 2024
3ad5484
assert_allclose
njzjz Feb 17, 2024
2b9bbd8
fix shape
njzjz Feb 17, 2024
7d40b9f
add env mat type
njzjz Feb 17, 2024
a55e21f
remove print
njzjz Feb 17, 2024
c34622d
merge methods
njzjz Feb 17, 2024
af0711c
merge
njzjz Feb 17, 2024
37e9b28
clean
njzjz Feb 17, 2024
2c63335
clean
njzjz Feb 17, 2024
d343c32
fix load stats
njzjz Feb 17, 2024
6d8955a
rm process_stat_path
njzjz Feb 17, 2024
582451e
fix py38 compatibility
njzjz Feb 17, 2024
1c1c1a5
fix typo
njzjz Feb 17, 2024
b47014b
rm unused compute_std
njzjz Feb 17, 2024
ef5a92e
bugfix
njzjz Feb 17, 2024
367f472
Merge branch 'devel' into env-mat-stat
njzjz Feb 17, 2024
a8aef18
make test work
njzjz Feb 18, 2024
6a83465
add type_map to stat_file_path
njzjz Feb 18, 2024
ba688a9
update share_params
njzjz Feb 18, 2024
7c9a66e
base_env starts from base_class.stats
njzjz Feb 18, 2024
2ad6990
fix py38 compatibility
njzjz Feb 18, 2024
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -22,7 +22,7 @@
from deepmd.pt.utils.utils import (
dict_to_device,
)
from deepmd.utils.path import (
DPPath,
)

@@ -179,13 +179,17 @@
stat_file_path
The dictionary of paths to the statistics files.
"""
if stat_file_path is not None and self.type_map is not None:
# descriptors and fitting net with different type_map
# should not share the same parameters
stat_file_path /= " ".join(self.type_map)
for data_sys in sampled:
dict_to_device(data_sys)
if sampled is None:
sampled = []
self.descriptor.compute_input_stats(sampled, stat_file_path)
if self.fitting_net is not None:
self.fitting_net.compute_output_stats(sampled, stat_file_path)

@torch.jit.export
def get_dim_fparam(self) -> int:
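
A minimal sketch of the path namespacing in the hunk above, assuming `pathlib.Path` as a stand-in for `DPPath` (both support the `/` operator). The helper name `namespaced_stat_path` and the directory name `stat_files` are hypothetical and not part of this PR; the point is only that models with different type maps end up with different statistics locations.

```python
from pathlib import Path
from typing import List, Optional


def namespaced_stat_path(
    stat_file_path: Optional[Path], type_map: Optional[List[str]]
) -> Optional[Path]:
    """Append the space-joined type map as a sub-directory (hypothetical helper).

    Mirrors `stat_file_path /= " ".join(self.type_map)` from the diff, using
    pathlib.Path instead of DPPath.
    """
    if stat_file_path is not None and type_map is not None:
        stat_file_path = stat_file_path / " ".join(type_map)
    return stat_file_path


# Two models with different type maps get separate statistics locations.
water = namespaced_stat_path(Path("stat_files"), ["O", "H"])
copper = namespaced_stat_path(Path("stat_files"), ["Cu"])
# Without a path, nothing is namespaced and nothing would be cached.
no_cache = namespaced_stat_path(None, ["O", "H"])
```

Note that the key is order-sensitive in this sketch: `["O", "H"]` and `["H", "O"]` map to different sub-directories.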
41 changes: 20 additions & 21 deletions deepmd/pt/model/descriptor/descriptor.py
@@ -15,10 +15,19 @@
from deepmd.pt.model.network.network import (
TypeEmbedNet,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSeA,
)
from deepmd.pt.utils.plugin import (
Plugin,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.path import (
DPPath,
)

@@ -171,10 +180,14 @@
"""Returns the embedding dimension."""
pass

def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
"""Update mean and stddev for DescriptorBlock elements."""
raise NotImplementedError

def get_stats(self) -> dict[str, StatItem]:
"""Get the statistics of the descriptor."""
raise NotImplementedError

def share_params(self, base_class, shared_level, resume=False):
assert (
self.__class__ == base_class.__class__
@@ -183,27 +196,13 @@
# link buffers
if hasattr(self, "mean") and not resume:
# in case of change params during resume
sumr_base, suma_base, sumn_base, sumr2_base, suma2_base = (
base_class.sumr,
base_class.suma,
base_class.sumn,
base_class.sumr2,
base_class.suma2,
)
sumr, suma, sumn, sumr2, suma2 = (
self.sumr,
self.suma,
self.sumn,
self.sumr2,
self.suma2,
)
stat_dict = {
"sumr": sumr_base + sumr,
"suma": suma_base + suma,
"sumn": sumn_base + sumn,
"sumr2": sumr2_base + sumr2,
"suma2": suma2_base + suma2,
}
base_env = EnvMatStatSeA(base_class)
for kk in base_class.get_stats():
base_env.stats[kk] += self.get_stats()[kk]
mean, stddev = base_env()
if not base_class.set_davg_zero:
base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE))
base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))
self.mean = base_class.mean
self.stddev = base_class.stddev
# self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model
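
The `share_params` hunk above replaces five hand-summed buffers with a loop that adds per-key statistics. A minimal sketch of why that works, assuming each statistic is kept as three additive accumulators (count, sum, sum of squares): merging two descriptors then reduces to element-wise addition, and mean and standard deviation are derived afterwards. `RunningStat`, `from_values`, and the `floor` default are hypothetical stand-ins for illustration, not the `StatItem` API from this PR.

```python
import math
from dataclasses import dataclass
from typing import Iterable


@dataclass
class RunningStat:
    """Hypothetical stand-in for a statistics item: three additive accumulators."""

    number: float = 0.0
    sum: float = 0.0
    squared_sum: float = 0.0

    def __add__(self, other: "RunningStat") -> "RunningStat":
        # Merging two samples is plain addition of the accumulators.
        return RunningStat(
            self.number + other.number,
            self.sum + other.sum,
            self.squared_sum + other.squared_sum,
        )

    def mean(self) -> float:
        return self.sum / self.number if self.number else 0.0

    def std(self, floor: float = 1e-2) -> float:
        # The floor keeps later divisions by stddev well-behaved (assumed value).
        if not self.number:
            return floor
        variance = self.squared_sum / self.number - self.mean() ** 2
        return max(math.sqrt(max(variance, 0.0)), floor)


def from_values(values: Iterable[float]) -> RunningStat:
    """Accumulate count, sum and sum of squares over raw values."""
    values = list(values)
    return RunningStat(
        float(len(values)),
        float(sum(values)),
        float(sum(v * v for v in values)),
    )


# Same shape as the loop in the diff: add the other descriptor's stats per key.
base = {"r": from_values([1.0, 2.0])}
other = {"r": from_values([3.0, 4.0])}
for key in base:
    base[key] += other[key]
```

Because the accumulators are additive, the merged result is the same as computing the statistics over the concatenated data in one pass.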
13 changes: 13 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
@@ -18,13 +18,16 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSeA,
)
from deepmd.pt.utils.utils import (
get_activation_fn,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.path import (
DPPath,
)

@@ -147,6 +150,7 @@
stddev = torch.ones(sshape, dtype=mydtype, device=mydev)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.stats = None

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
@@ -266,13 +270,22 @@

return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw

def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
"""Update mean and stddev for descriptor elements."""
env_mat_stat = EnvMatStatSeA(self)
if path is not None:
path = path / env_mat_stat.get_hash()
env_mat_stat.load_or_compute_stats(merged, path)
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
if not self.set_davg_zero:
self.mean.copy_(torch.tensor(mean, device=env.DEVICE))
self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))

def get_stats(self) -> dict[str, StatItem]:
"""Get the statistics of the descriptor."""
if self.stats is None:
raise RuntimeError(
"The statistics of the descriptor has not been computed."
)
return self.stats
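
The `compute_input_stats` body above appends a hash to the path and then calls `load_or_compute_stats`. A rough sketch of that load-or-compute pattern, assuming a JSON file in an ordinary directory rather than the PR's `DPPath` storage; `config_hash`, `load_or_compute`, the `stats.json` file name, and the example configuration keys are all hypothetical.

```python
import hashlib
import json
import tempfile
from pathlib import Path
from typing import Callable, Dict


def config_hash(config: dict) -> str:
    """Stable short hash of the settings that influence the statistics."""
    payload = json.dumps(config, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()[:16]


def load_or_compute(
    cache_dir: Path, compute: Callable[[], Dict[str, float]]
) -> Dict[str, float]:
    """Return cached statistics if present, otherwise compute and store them."""
    stat_file = cache_dir / "stats.json"
    if stat_file.is_file():
        return json.loads(stat_file.read_text())
    stats = compute()
    cache_dir.mkdir(parents=True, exist_ok=True)
    stat_file.write_text(json.dumps(stats))
    return stats


calls = []


def expensive() -> Dict[str, float]:
    # Stands in for a pass over the sampled training data.
    calls.append(1)
    return {"mean": 0.5, "stddev": 1.25}


with tempfile.TemporaryDirectory() as root:
    # Different settings hash to different sub-directories, so they never
    # pick up each other's cached statistics.
    cache_dir = Path(root) / config_hash({"rcut": 6.0, "sel": [46, 92]})
    first = load_or_compute(cache_dir, expensive)
    second = load_or_compute(cache_dir, expensive)
```

The second call reads the file written by the first, so the expensive pass runs only once per configuration.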

13 changes: 13 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
@@ -21,10 +21,13 @@
PRECISION_DICT,
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSeA,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.path import (
DPPath,
)

@@ -116,9 +119,9 @@
"""Returns the output dimension of this descriptor."""
return self.sea.dim_out

def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
"""Update mean and stddev for descriptor elements."""
return self.sea.compute_input_stats(merged, path)

@classmethod
def get_data_process_key(cls, config):
@@ -313,6 +316,7 @@
resnet_dt=self.resnet_dt,
)
self.filter_layers = filter_layers
self.stats = None

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
@@ -374,20 +378,29 @@
else:
raise KeyError(key)

def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
"""Update mean and stddev for descriptor elements."""
env_mat_stat = EnvMatStatSeA(self)
if path is not None:
path = path / env_mat_stat.get_hash()
env_mat_stat.load_or_compute_stats(merged, path)
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
if not self.set_davg_zero:
self.mean.copy_(torch.tensor(mean, device=env.DEVICE))
self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))
if not self.set_davg_zero:
self.mean.copy_(torch.tensor(mean, device=env.DEVICE))
self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))

def get_stats(self) -> dict[str, StatItem]:
"""Get the statistics of the descriptor."""
if self.stats is None:
raise RuntimeError(
"The statistics of the descriptor has not been computed."
)
return self.stats

def forward(
self,
nlist: torch.Tensor,
13 changes: 13 additions & 0 deletions deepmd/pt/model/descriptor/se_atten.py
@@ -20,10 +20,13 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSeA,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.path import (
DPPath,
)

@@ -137,6 +140,7 @@
)
filter_layers.append(one)
self.filter_layers = torch.nn.ModuleList(filter_layers)
self.stats = None

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
@@ -187,17 +191,26 @@
"""Returns the output dimension of embedding."""
return self.get_dim_emb()

def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
"""Update mean and stddev for descriptor elements."""
env_mat_stat = EnvMatStatSeA(self)
if path is not None:
path = path / env_mat_stat.get_hash()
env_mat_stat.load_or_compute_stats(merged, path)
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
if not self.set_davg_zero:
self.mean.copy_(torch.tensor(mean, device=env.DEVICE))
self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))

def get_stats(self) -> dict[str, StatItem]:
"""Get the statistics of the descriptor."""
if self.stats is None:
raise RuntimeError(
"The statistics of the descriptor has not been computed."
)
return self.stats

def forward(
self,
nlist: torch.Tensor,