diff --git a/source/tests/tf/test_gen_stat_data.py b/source/tests/tf/test_gen_stat_data.py index c3f7e765f7..ebede15fbb 100644 --- a/source/tests/tf/test_gen_stat_data.py +++ b/source/tests/tf/test_gen_stat_data.py @@ -122,7 +122,7 @@ def test_ener_shift(self): data = DeepmdDataSystem(["system_0", "system_1"], 5, 10, 1.0) data.add("energy", 1, must=True) ener_shift0 = data.compute_energy_shift(rcond=1) - all_stat = make_stat_input(data, 4, merge_sys=False) + all_stat = make_stat_input(data, 6, merge_sys=False) descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16) fitting = EnerFitting( descrpt.get_ntypes(), @@ -138,7 +138,7 @@ def test_ener_shift_assigned(self): ae0 = dp_random.random() data = DeepmdDataSystem(["system_0"], 5, 10, 1.0) data.add("energy", 1, must=True) - all_stat = make_stat_input(data, 4, merge_sys=False) + all_stat = make_stat_input(data, 6, merge_sys=False) descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16) fitting = EnerFitting( descrpt.get_ntypes(),