
Commit

refactor pymatviz imports as namespace
janosh committed Aug 22, 2024
1 parent e467338 commit ee35ac5
Showing 6 changed files with 46 additions and 48 deletions.
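The core change is swapping per-function imports from pymatviz (and from aviary.wren.utils) for module-namespace imports. Below is a minimal sketch of the pattern the commit applies, not copied from any one file in the diff; the DataFrame and its spg_num column are invented for illustration.

```python
import pandas as pd
import pymatviz as pmv

# toy stand-in data: a few space-group numbers, not taken from the repo
df = pd.DataFrame({"spg_num": [62, 221, 14]})

# old style (removed by this commit):
#   from pymatviz import spacegroup_sunburst
#   from pymatviz.utils import crystal_sys_from_spg_num
#   fig = spacegroup_sunburst(df.spg_num)

# new style: one namespace import, every call qualified with pmv.
fig = pmv.spacegroup_sunburst(df.spg_num)
df["crys_sys"] = df.spg_num.map(pmv.utils.crystal_sys_from_spg_num)
```

The namespace style keeps call sites self-documenting about where each helper comes from and avoids re-touching the import block every time a new pymatviz function is used.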
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -5,7 +5,7 @@ ci:

 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.5.5
+    rev: v0.6.1
     hooks:
       - id: ruff
         args: [--fix]
@@ -30,7 +30,7 @@ repos:
         args: [--check-filenames]

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.11.0
+    rev: v1.11.1
     hooks:
       - id: mypy
         exclude: (tests|examples)/
4 changes: 2 additions & 2 deletions aviary/train.py
@@ -428,7 +428,7 @@ def checkpoint_model(
     if checkpoint_endpoint == "local":
         os.makedirs(f"{ROOT}/models", exist_ok=True)
         checkpoint_path = (
-            f"{ROOT}/models/{timestamp+'-' if timestamp else ''}{run_name}-{epochs}.pth"
+            f"{ROOT}/models/{timestamp + '-' if timestamp else ''}{run_name}-{epochs}.pth"
         )
         torch.save(checkpoint_dict, checkpoint_path)

@@ -438,7 +438,7 @@ def checkpoint_model(
         ), "can't save model checkpoint to Weights and Biases, wandb.run is None"
         torch.save(
             checkpoint_dict,
-            f"{wandb.run.dir}/{timestamp+'-' if timestamp else ''}{run_name}-{epochs}.pth",
+            f"{wandb.run.dir}/{timestamp + '-' if timestamp else ''}{run_name}-{epochs}.pth",
         )
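Aside from the namespace refactor, the aviary/train.py hunks only add spaces inside the f-strings. If the conditional segment there is unfamiliar, here is a small sketch of how it resolves; the function name and values are placeholders, not from the repo.

```python
from __future__ import annotations


def checkpoint_name(timestamp: str | None, run_name: str, epochs: int) -> str:
    """Mimic the filename pattern in checkpoint_model() above (illustration only)."""
    # the dash is only prepended when a timestamp is actually given
    return f"{timestamp + '-' if timestamp else ''}{run_name}-{epochs}.pth"


print(checkpoint_name("2024-08-22", "wrenformer", 100))  # 2024-08-22-wrenformer-100.pth
print(checkpoint_name(None, "wrenformer", 100))  # wrenformer-100.pth
```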
58 changes: 27 additions & 31 deletions
@@ -2,18 +2,14 @@
 import os

 import pandas as pd
+import pymatviz as pmv
 from matminer.datasets import load_dataset
 from pymatgen.core import Structure
 from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
-from pymatviz import sankey_from_2_df_cols, spacegroup_sunburst
-from pymatviz.utils import crystal_sys_from_spg_num
 from tqdm import tqdm

+import aviary.wren.utils as wren_utils
 from aviary import ROOT
-from aviary.wren.utils import (
-    get_protostructure_label_from_aflow,
-    get_protostructure_label_from_spglib,
-)
 from examples.wrenformer.mat_bench import DATA_PATHS

 __author__ = "Janosh Riebesell"
@@ -31,92 +27,92 @@


 # %%
-df_perovskites = pd.read_json(DATA_PATHS["matbench_perovskites"]).set_index("mbid")
-df_perovskites = df_perovskites.rename(columns={"wyckoff": "spglib_wyckoff"})
-df_perovskites["structure"] = [
-    Structure.from_dict(struct) for struct in df_perovskites.structure
-]
+df_perov = pd.read_json(DATA_PATHS["matbench_perovskites"]).set_index("mbid")
+df_perov = df_perov.rename(columns={"wyckoff": "spglib_wyckoff"})
+df_perov["structure"] = df_perov.structure.map(Structure.from_dict)


 # %%
 # takes ~6h (when running uninterrupted)
-for idx, struct in tqdm(df_perovskites.structure.items(), total=len(df_perovskites)):
-    if pd.isna(df_perovskites.aflow_wyckoff[idx]):
-        df_perovskites.loc[idx, "aflow_wyckoff"] = get_protostructure_label_from_aflow(
-            struct, "/Users/janosh/bin/aflow"
+for idx, struct in tqdm(df_perov.structure.items(), total=len(df_perov)):
+    if pd.isna(df_perov.aflow_wyckoff[idx]):
+        df_perov.loc[idx, "aflow_wyckoff"] = (
+            wren_utils.get_protostructure_label_from_aflow(
+                struct, "/Users/janosh/bin/aflow"
+            )
         )


 # %%
 # takes ~30 sec
-for struct in tqdm(df_perovskites.structure, total=len(df_perovskites)):
-    get_protostructure_label_from_spglib(struct)
+for struct in tqdm(df_perov.structure, total=len(df_perov)):
+    wren_utils.get_protostructure_label_from_spglib(struct)


 # %%
-df_perovskites.dropna().query("wyckoff != aflow_wyckoff")
+df_perov.dropna().query("wyckoff != aflow_wyckoff")


 # %%
 print(
     "Percentage of materials with spglib label != aflow label: "
-    f"{len(df_perovskites.query('wyckoff != aflow_wyckoff')) / len(df_perovskites):.0%}"
+    f"{len(df_perov.query('wyckoff != aflow_wyckoff')) / len(df_perov):.0%}"
 )


 # %%
-df_perovskites.drop("structure", axis=1).to_csv(
+df_perov.drop("structure", axis=1).to_csv(
     f"{ROOT}/datasets/matbench_perovskites_protostructure_labels.csv"
 )


 # %%
-df_perovskites = pd.read_csv(
+df_perov = pd.read_csv(
     f"{ROOT}/datasets/matbench_perovskites_protostructure_labels.csv"
 ).set_index("mbid")


 # %%
 for src in ("aflow", "spglib"):
-    df_perovskites[f"{src}_spg_num"] = (
-        df_perovskites[f"{src}_wyckoff"].str.split("_").str[2].astype(int)
+    df_perov[f"{src}_spg_num"] = (
+        df_perov[f"{src}_wyckoff"].str.split("_").str[2].astype(int)
     )


 # %%
-fig = spacegroup_sunburst(df_perovskites.spglib_spg)
+fig = pmv.spacegroup_sunburst(df_perov.spglib_spg)
 fig.update_layout(title=dict(text="Spglib Spacegroups", x=0.5, y=0.93))
 # fig.write_image(f"{MODULE_DIR}/plots/matbench_perovskites_aflow_sunburst.pdf")


 # %%
-fig = spacegroup_sunburst(df_perovskites.aflow_spg, title="Aflow")
+fig = pmv.spacegroup_sunburst(df_perov.aflow_spg, title="Aflow")
 fig.update_layout(title=dict(text="Aflow Spacegroups", x=0.5, y=0.85))
 # fig.write_image(f"{MODULE_DIR}/plots/matbench_perovskites_spglib_sunburst.pdf")


 # %%
-df_perovskites = load_dataset("matbench_perovskites")
+df_perov = load_dataset("matbench_perovskites")

-df_perovskites["spglib_spg_num"] = df_perovskites.structure.map(
+df_perov["spglib_spg_num"] = df_perov.structure.map(
     lambda struct: SpacegroupAnalyzer(struct).get_space_group_number()
 )


 # %%
 for src in ("aflow", "spglib"):
-    df_perovskites[f"{src}_crys_sys"] = df_perovskites[f"{src}_spg_num"].map(
-        crystal_sys_from_spg_num
+    df_perov[f"{src}_crys_sys"] = df_perov[f"{src}_spg_num"].map(
+        pmv.utils.crystal_sys_from_spg_num
     )


 # %%
-fig = sankey_from_2_df_cols(df_perovskites, ["aflow_spg_num", "spglib_spg_num"])
+fig = pmv.sankey_from_2_df_cols(df_perov, ["aflow_spg_num", "spglib_spg_num"])

 fig.update_layout(title="Matbench Perovskites Aflow vs Spglib Spacegroups")


 # %%
-fig = sankey_from_2_df_cols(df_perovskites, ["aflow_crys_sys", "spglib_crys_sys"])
+fig = pmv.sankey_from_2_df_cols(df_perov, ["aflow_crys_sys", "spglib_crys_sys"])

 fig.update_layout(title="Aflow vs Spglib Crystal Systems")
4 changes: 2 additions & 2 deletions examples/wrenformer/mat_bench/make_plots.py
@@ -10,10 +10,10 @@

 import pandas as pd
 import plotly.express as px
+import pymatviz as pmv
 from matbench import MatbenchBenchmark
 from matbench.constants import CLF_KEY, REG_KEY
 from matbench.metadata import mbv01_metadata as matbench_metadata
-from pymatviz.powerups import add_identity_line
 from sklearn.metrics import r2_score, roc_auc_score

 from examples.wrenformer.mat_bench import DATA_PATHS
@@ -209,7 +209,7 @@
         "value ": "Predicted formation energy (eV/atom)",
     },
 )
-add_identity_line(fig)
+pmv.powerups.add_identity_line(fig)

 fig.update_layout(legend=dict(x=0.02, y=0.95, xanchor="left", title="Models"))
12 changes: 6 additions & 6 deletions examples/wrenformer/mat_bench/utils.py
@@ -14,14 +14,14 @@ def _int_keys(dct: dict) -> dict:
     return {int(k) if k.lstrip("-").isdigit() else k: v for k, v in dct.items()}


-def recursive_dict_merge(d1: dict, d2: dict) -> dict:
+def recursive_dict_merge(dict1: dict, dict2: dict) -> dict:
     """Merge two dicts recursively."""
-    for key in d2:
-        if key in d1 and isinstance(d1[key], dict) and isinstance(d2[key], dict):
-            recursive_dict_merge(d1[key], d2[key])
+    for key, val2 in dict2.items():
+        if key in dict1 and isinstance(dict1[key], dict) and isinstance(val2, dict):
+            recursive_dict_merge(dict1[key], val2)
         else:
-            d1[key] = d2[key]
-    return d1
+            dict1[key] = val2
+    return dict1


 def merge_json_on_disk(
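For context on the renamed arguments above, a quick usage sketch of recursive_dict_merge follows; the dictionaries are invented, and the import assumes the repo root is on the Python path.

```python
from examples.wrenformer.mat_bench.utils import recursive_dict_merge

defaults = {"optimizer": {"name": "AdamW", "lr": 1e-3}, "epochs": 10}
overrides = {"optimizer": {"lr": 3e-4}, "epochs": 20}

merged = recursive_dict_merge(defaults, overrides)
print(merged)  # {'optimizer': {'name': 'AdamW', 'lr': 0.0003}, 'epochs': 20}
print(merged is defaults)  # True: the merge mutates and returns the first dict
```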
12 changes: 7 additions & 5 deletions pyproject.toml
@@ -73,8 +73,10 @@ no_implicit_optional = false
 [tool.ruff]
 line-length = 90
 target-version = "py39"
-extend-include = ["*.ipynb"]
-lint.select = [
+output-format = "concise"
+
+[tool.ruff.lint]
+select = [
     "B", # flake8-bugbear
     "C4", # flake8-comprehensions
     "D", # pydocstyle
@@ -105,7 +107,7 @@ lint.select = [
     "W", # pycodestyle warning
     "YTT", # flake8-2020
 ]
-lint.ignore = [
+ignore = [
     "C408", # Unnecessary dict call - rewrite as a literal
     "D100", # Missing docstring in public module
     "D104", # Missing docstring in public package
@@ -116,8 +118,8 @@ lint.ignore = [
     "PLR", # pylint refactor
     "PT006", # pytest-parametrize-names-wrong-type
 ]
-lint.pydocstyle.convention = "google"
-lint.isort.known-third-party = ["wandb"]
+pydocstyle.convention = "google"
+isort.known-third-party = ["wandb"]

 [tool.ruff.lint.per-file-ignores]
 "tests/*" = ["D"]
