Skip to content

Commit

Permalink
refactor: extract np.asarray into reshape utility
Browse files Browse the repository at this point in the history
  • Loading branch information
mrossinek committed Sep 1, 2022
1 parent 81d2534 commit e048860
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions qiskit_nature/second_q/formats/qcschema_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,33 +73,32 @@ def qcschema_to_problem(
basis_transform: ElectronicBasisTransform | None = None

def reshape_2(arr, dim, dim_2=None):
return arr.reshape((dim, dim_2 if dim_2 is not None else dim))
return np.asarray(arr).reshape((dim, dim_2 if dim_2 is not None else dim))

def reshape_4(arr, dim):
return arr.reshape((dim,) * 4)
return np.asarray(arr).reshape((dim,) * 4)

if qcschema.wavefunction.scf_fock_a is not None:
# TODO: deal with this properly
hij = np.asarray(qcschema.wavefunction.scf_fock_a)
nao = int(np.sqrt(len(hij)))
hij = reshape_2(hij, nao)
nao = int(np.sqrt(len(qcschema.wavefunction.scf_fock_a)))
hij = reshape_2(qcschema.wavefunction.scf_fock_a, nao)

if qcschema.wavefunction.scf_fock_b is not None:
hij_b = reshape_2(np.asarray(qcschema.wavefunction.scf_fock_b), nao)
hij_b = reshape_2(qcschema.wavefunction.scf_fock_b, nao)

if hij is not None:
one_body_ao = OneBodyElectronicIntegrals(ElectronicBasis.AO, (hij, hij_b))
ints.append(one_body_ao)

if qcschema.wavefunction.scf_eri is not None:
eri = reshape_4(np.asarray(qcschema.wavefunction.scf_eri), nao)
eri = reshape_4(qcschema.wavefunction.scf_eri, nao)
two_body_ao = TwoBodyElectronicIntegrals(ElectronicBasis.AO, (eri, None, None, None))
ints.append(two_body_ao)

if qcschema.wavefunction.scf_orbitals_a is not None:
coeff_a = reshape_2(np.asarray(qcschema.wavefunction.scf_orbitals_a), nao, nmo)
coeff_a = reshape_2(qcschema.wavefunction.scf_orbitals_a, nao, nmo)
if qcschema.wavefunction.scf_orbitals_b is not None:
coeff_b = reshape_2(np.asarray(qcschema.wavefunction.scf_orbitals_b), nao, nmo)
coeff_b = reshape_2(qcschema.wavefunction.scf_orbitals_b, nao, nmo)

if coeff_a is not None:
basis_transform = ElectronicBasisTransform(
Expand All @@ -117,10 +116,10 @@ def reshape_4(arr, dim):
two_body_mo: TwoBodyElectronicIntegrals

if qcschema.wavefunction.scf_fock_mo_a is not None:
hij_mo = reshape_2(np.asarray(qcschema.wavefunction.scf_fock_mo_a), nmo)
hij_mo = reshape_2(qcschema.wavefunction.scf_fock_mo_a, nmo)

if qcschema.wavefunction.scf_fock_mo_b is not None:
hij_mo_b = reshape_2(np.asarray(qcschema.wavefunction.scf_fock_mo_b), nmo)
hij_mo_b = reshape_2(qcschema.wavefunction.scf_fock_mo_b, nmo)

if hij_mo is not None:
one_body_mo = OneBodyElectronicIntegrals(ElectronicBasis.MO, (hij_mo, hij_mo_b))
Expand All @@ -135,16 +134,16 @@ def reshape_4(arr, dim):
ints.append(one_body_mo)

if qcschema.wavefunction.scf_eri_mo_aa is not None:
eri_mo = reshape_4(np.asarray(qcschema.wavefunction.scf_eri_mo_aa), nmo)
eri_mo = reshape_4(qcschema.wavefunction.scf_eri_mo_aa, nmo)

if qcschema.wavefunction.scf_eri_mo_ba is not None:
eri_mo_ba = reshape_4(np.asarray(qcschema.wavefunction.scf_eri_mo_ba), nmo)
eri_mo_ba = reshape_4(qcschema.wavefunction.scf_eri_mo_ba, nmo)

if qcschema.wavefunction.scf_eri_mo_bb is not None:
eri_mo_bb = reshape_4(np.asarray(qcschema.wavefunction.scf_eri_mo_bb), nmo)
eri_mo_bb = reshape_4(qcschema.wavefunction.scf_eri_mo_bb, nmo)

if qcschema.wavefunction.scf_eri_mo_ab is not None:
eri_mo_ab = reshape_4(np.asarray(qcschema.wavefunction.scf_eri_mo_ab), nmo)
eri_mo_ab = reshape_4(qcschema.wavefunction.scf_eri_mo_ab, nmo)

if eri_mo is not None:
two_body_mo = TwoBodyElectronicIntegrals(
Expand Down

0 comments on commit e048860

Please sign in to comment.