Skip to content

Commit

Permalink
add read_mmn and read_chk and associated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jimustafa committed Sep 15, 2023
1 parent 92ee82b commit 55e0028
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@
::: wannier90io.parse_wout_iteration_info
::: wannier90io.read_amn
::: wannier90io.write_amn
::: wannier90io.read_chk
::: wannier90io.read_eig
::: wannier90io.write_eig
::: wannier90io.read_mmn
2 changes: 2 additions & 0 deletions src/wannier90io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from . import _schema
from ._amn import *
from ._chk import *
from ._eig import *
from ._mmn import *
from ._nnkp import *
from ._win import *
from ._wout import *
22 changes: 22 additions & 0 deletions src/wannier90io/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,22 @@ def info_eig(args):
print(f'Nb = {eig.shape[1]}')


def info_mmn(args):
with args.file:
(mmn, nnkpts) = w90io.read_mmn(args.file)

print(mmn.shape, nnkpts.shape)


def info_chk(args):
with args.file:
chk = w90io.read_chk(args.file)

print(chk['num_bands'])
print(chk['num_wann'])
print(chk['have_disentangled'])


def main():
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest='subparser', required=True)
Expand Down Expand Up @@ -135,6 +151,12 @@ def main():
#
parser_info_eig = subparsers.add_parser('info-eig', parents=[parser_common])
parser_info_eig.set_defaults(func=info_eig)
#
parser_info_mmn = subparsers.add_parser('info-mmn', parents=[parser_common])
parser_info_mmn.set_defaults(func=info_mmn)
#
parser_info_chk = subparsers.add_parser('info-chk', parents=[parser_common])
parser_info_chk.set_defaults(func=info_chk)

args = parser.parse_args()
args.func(args)
47 changes: 47 additions & 0 deletions src/wannier90io/_chk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations
import typing

import numpy as np


__all__ = ['read_chk']


def read_chk(stream: typing.TextIO) -> dict:
"""
Read checkpoint
Arguments:
stream: a file-like stream
Returns:
dict
"""
chk = {}

chk['header'] = stream.readline()
chk['num_bands'] = Nb = int(stream.readline())
chk['num_exclude_bands'] = int(stream.readline())
if chk['num_exclude_bands'] > 0:
chk['num_exclude_bands'] = np.fromstring(stream.readline(), dtype=int)
chk['real_lattice'] = np.fromstring(stream.readline(), sep=' ', dtype=float).reshape((3, 3), order='F')
chk['recip_lattice'] = np.fromstring(stream.readline(), sep=' ', dtype=float).reshape((3, 3), order='F')
chk['num_kpts'] = Nk = int(stream.readline())
chk['mp_grid'] = np.fromstring(stream.readline(), sep=' ', dtype=int)
chk['kpt_latt'] = np.zeros((chk['num_kpts'], 3))
for idx in range(chk['num_kpts']):
chk['kpt_latt'][idx] = np.fromstring(stream.readline(), sep=' ', dtype=float)
chk['nntot'] = Nn = int(stream.readline())
chk['num_wann'] = Nw = int(stream.readline())
chk['checkpoint'] = stream.readline()
chk['have_disentangled'] = bool(int(stream.readline()))
if chk['have_disentangled']:
chk['omega_invariant'] = float(stream.readline())
chk['lwindow'] = np.loadtxt(stream, max_rows=(Nk*Nb), dtype=bool).reshape((Nk, Nb))
chk['nwindim'] = np.loadtxt(stream, max_rows=Nk, dtype=int)
chk['u_matrix_opt'] = np.loadtxt(stream, max_rows=(Nk*Nw*Nb), dtype=float).view(complex).reshape((Nk, Nw, Nb))
chk['u_matrix'] = np.loadtxt(stream, max_rows=(Nk*Nw*Nw), dtype=float).view(complex).reshape((Nw, Nw, Nk), order='F').transpose((2, 0, 1))
chk['m_matrix'] = np.loadtxt(stream, max_rows=(Nk*Nn*Nw*Nw), dtype=float).view(complex).reshape((Nw, Nw, Nn, Nk), order='F').transpose((3, 2, 0, 1))

return chk
37 changes: 37 additions & 0 deletions src/wannier90io/_mmn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations
import itertools
import typing

import numpy as np


__all__ = ['read_mmn']


def read_mmn(stream: typing.TextIO) -> tuple[np.ndarray, np.ndarray]:
"""
Read overlaps matrix
Arguments:
stream: a file-like stream
Returns:
overlaps matrix (Nk, Nn, Nb, Nb)
nnkps (Nk, Nn, 5)
"""
stream.readline() # header

[Nb, Nk, Nn] = np.fromstring(stream.readline(), sep=' ', dtype=int)

mmn = np.zeros((Nk, Nn, Nb, Nb), dtype=complex)
nnkpts = np.zeros((Nk, Nn, 5), dtype=int)

for (ik, ikb) in itertools.product(range(Nk), range(Nn)):
nnkpts[ik, ikb] = np.fromstring(stream.readline(), sep=' ', dtype=int)
mmn[ik, ikb] = np.loadtxt(stream, max_rows=(Nb*Nb)).view(complex).reshape((Nb, Nb), order='F')

nnkpts[:, :, 0] -= 1
nnkpts[:, :, 1] -= 1

return (mmn, nnkpts)
5 changes: 5 additions & 0 deletions tests/fixtures/fixtures.mk
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
W90 = $(abspath ../../wannier90.x)
POSTW90 = $(abspath ../../postw90.x)
W90CHK2CHK = $(abspath ../../w90chk2chk.x)

EXAMPLES_1 = $(foreach idx,01 02 03 04,example$(idx))
EXAMPLES_2 = $(foreach idx,05 06 07 09 10 11 13 17 18 19 20,example$(idx))
Expand Down Expand Up @@ -47,6 +48,7 @@ run-example01:
$(W90) -pp wannier
echo "write_xyz=true" >> wannier.win
$(W90) wannier
$(W90CHK2CHK) -export wannier

.PHONY: run-example02
run-example02:
Expand All @@ -55,13 +57,15 @@ run-example02:
echo $(WRITE_HR) >> wannier.win
echo "write_xyz=true" >> wannier.win
$(W90) wannier
$(W90CHK2CHK) -export wannier

.PHONY: run-example03
run-example03:
$(call normalize_seedname)
$(W90) -pp wannier
$(call modify_win)
$(W90) wannier
$(W90CHK2CHK) -export wannier
echo "" >> wannier_geninterp.kpt
echo "crystal" >> wannier_geninterp.kpt
head -n 1 wannier_band.kpt >> wannier_geninterp.kpt
Expand All @@ -75,6 +79,7 @@ run-example04:
$(call modify_win)
echo "geninterp_alsofirstder=true" >> wannier.win
$(W90) wannier
$(W90CHK2CHK) -export wannier
echo "" >> wannier_geninterp.kpt
echo "crystal" >> wannier_geninterp.kpt
head -n 1 wannier_band.kpt >> wannier_geninterp.kpt
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ else
echo "LIBS=-lblas -llapack" >> make.inc
fi

make
make wannier post w90chk2chk

popd
popd
Expand Down
36 changes: 36 additions & 0 deletions tests/test_chk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import itertools
import pathlib

import numpy as np
import pytest

import w90io


@pytest.mark.parametrize('example', ['example01', 'example02'])
def test_read_chk(wannier90, example):
with open(pathlib.Path(wannier90)/f'examples/{example}/wannier.chk.fmt', 'r') as fh:
chk = w90io.read_chk(fh)
umn = chk['u_matrix']
mmn_ref = chk['m_matrix']

with open(pathlib.Path(wannier90)/f'examples/{example}/wannier.nnkp', 'r') as fh:
nnkp = w90io.parse_nnkp_raw(fh.read())

with open(pathlib.Path(wannier90)/f'examples/{example}/wannier.mmn', 'r') as fh:
(mmn_test, nnkpts) = w90io.read_mmn(fh)
nnkpts1 = np.copy(nnkpts)
nnkpts1[:, :, 0] += 1
nnkpts1[:, :, 1] += 1

for (ik, ikb) in itertools.product(range(mmn_test.shape[0]), range(mmn_test.shape[1])):
mmn_test[ik, ikb] = np.dot(np.dot(umn[ik].conj().T, mmn_test[ik, ikb]), umn[nnkpts[:, :, 1][ik, ikb]])

assert np.allclose(np.array(nnkp['nnkpts']).reshape(nnkpts.shape), nnkpts1)
assert np.allclose(mmn_test, mmn_ref)
assert np.allclose(chk['real_lattice'][0], nnkp['direct_lattice']['a1'])
assert np.allclose(chk['real_lattice'][1], nnkp['direct_lattice']['a2'])
assert np.allclose(chk['real_lattice'][2], nnkp['direct_lattice']['a3'])
assert np.allclose(chk['recip_lattice'][0], nnkp['reciprocal_lattice']['b1'])
assert np.allclose(chk['recip_lattice'][1], nnkp['reciprocal_lattice']['b2'])
assert np.allclose(chk['recip_lattice'][2], nnkp['reciprocal_lattice']['b3'])
4 changes: 4 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ def test_cli(wannier90, example):
nnkp_file = pathlib.Path(wannier90)/f'examples/{example}/wannier.nnkp'
wout_file = pathlib.Path(wannier90)/f'examples/{example}/wannier.wout'
amn_file = pathlib.Path(wannier90)/f'examples/{example}/wannier.amn'
chk_file = pathlib.Path(wannier90)/f'examples/{example}/wannier.chk.fmt'
eig_file = pathlib.Path(wannier90)/f'examples/{example}/wannier.eig'
mmn_file = pathlib.Path(wannier90)/f'examples/{example}/wannier.mmn'

assert subprocess.run(['w90io', 'parse-win', win_file]).returncode == 0
assert subprocess.run(['w90io', 'parse-nnkp', nnkp_file]).returncode == 0
assert subprocess.run(['w90io', 'parse-wout-iteration-info', wout_file]).returncode == 0
assert subprocess.run(['w90io', 'info-amn', amn_file]).returncode == 0
assert subprocess.run(['w90io', 'info-chk', chk_file]).returncode == 0
assert subprocess.run(['w90io', 'info-eig', eig_file]).returncode == 0
assert subprocess.run(['w90io', 'info-mmn', mmn_file]).returncode == 0
18 changes: 18 additions & 0 deletions tests/test_mmn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pathlib

import numpy as np
import pytest
import w90io


@pytest.mark.parametrize('example', ['example01', 'example02', 'example04'])
def test_read_mmn(wannier90, example):
with open(pathlib.Path(wannier90)/f'examples/{example}/wannier.mmn', 'r') as fh:
(mmn, nnkpts) = w90io.read_mmn(fh)
nnkpts[:, :, 0] += 1
nnkpts[:, :, 1] += 1

with open(pathlib.Path(wannier90)/f'examples/{example}/wannier.nnkp', 'r') as fh:
nnkp = w90io.parse_nnkp_raw(fh.read())

assert np.allclose(nnkpts, np.asarray(nnkp['nnkpts'], dtype=int).reshape(nnkpts.shape))

0 comments on commit 55e0028

Please sign in to comment.