Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read MDS in tiles using Zarr #331

Draft
wants to merge 16 commits into
base: master
Choose a base branch
from
20 changes: 20 additions & 0 deletions xmitgcm/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

import os

from .utils import parse_meta_file


def collect_meta(path: str):
    """Group parsed MDS .meta files found directly under *path*.

    Returns a dict mapping each metadata "basename" to the list of parsed
    .meta dicts that share it; each entry gains a "filename" key holding the
    source file name. Files that fail to parse are reported and skipped.
    """
    output = {}
    for m in os.listdir(path):
        if not m.endswith(".meta"):
            continue
        try:
            meta = parse_meta_file(os.path.join(path, m))
            meta["filename"] = m
            output.setdefault(meta["basename"], []).append(meta)
        except Exception:  # was a bare except: also swallowed SystemExit etc.
            print(f"Failed parsing: {m}")

    return output
250 changes: 250 additions & 0 deletions xmitgcm/mds_tilez.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@

from collections.abc import Container
from collections import UserDict
from dataclasses import dataclass
import json
import os.path
import re

from fsspec.implementations.reference import ReferenceFileSystem
import numpy as np
import xarray as xr

from xmitgcm.utils import parse_meta_file

@dataclass()
class Chunk():
    """
    One chunk (tile) of one or more variables stored in an MDS
    .data/.meta file pair. For instance:
        S.0000000000.001.001.data
        S.0000000000.001.001.meta
    """
    # Path to the binary .data file holding the values.
    filename: str
    # Parsed contents of the companion .meta file
    # (see xmitgcm.utils.parse_meta_file).
    metadata: dict

    def __fspath__(self):
        # Bug fix: previously referenced undefined attributes
        # (self.root, self._fdata). Expose the .data file path.
        return self.filename

    def __getitem__(self, key):
        """Return the [filename, offset, length] reference for variable *key*."""
        return self._labels[key]

    @property
    def _labels(self):
        """Map each variable name to its [filename, offset, length] reference.

        Variables are stored back to back in the .data file, so variable i
        starts at byte i * size.
        """
        # Bytes per variable: product of per-dimension tile extents x itemsize.
        size = int(np.prod([(d[2] - d[1] + 1) for d in self.metadata['dimList']]))
        size *= self.metadata['dataprec'].itemsize
        # Bug fix: fsspec reference entries are [path, offset, LENGTH];
        # the third element was previously the end offset (i+1)*size.
        return {v: [self.filename, i * size, size]
                for i, v in enumerate(self.varnames)}

    @property
    def varnames(self):
        """Names of the variables in this chunk, always as a list."""
        try:
            return self._varnames
        except AttributeError:  # narrowed: was a bare except
            if "fldList" in self.metadata:
                self._varnames = self.metadata['fldList']
            else:
                # Single-variable files encode the name in the filename,
                # e.g. "S" from "S.0000000000.001.001.data".
                # Bug fix: wrap in a list so iteration, len(), and [0]
                # behave the same as the fldList case.
                self._varnames = [os.path.basename(self.filename).split('.')[0]]
            return self._varnames

    @property
    def dtype(self):
        """numpy dtype string of the stored values, e.g. '<f4'."""
        return self.metadata['dataprec'].str

    @property
    def index(self):
        """Spatial chunk index as a dotted string, e.g. '0.1'."""
        idx = []
        for d in self.metadata['dimList']:
            # d = [global size, first index (1-based), last index];
            # tile position along the dimension = (first - 1) // tile extent.
            idx.append(str((d[1] - 1) // (d[2] - d[1] + 1)))

        return ".".join(idx)

    @property
    def missing_value(self):
        """Fill value, cast to match the data type kind."""
        if self.metadata['dataprec'].kind == 'i':
            return int(self.metadata['missingValue'])
        else:
            return float(self.metadata['missingValue'])

    @property
    def shape(self):
        """Global (whole-domain) shape across all tiles."""
        return tuple(s[0] for s in self.metadata['dimList'])

    @property
    def chunks(self):
        """Shape of this tile.

        Bug fix: previously returned the *last index* (d[2]), which only
        equals the tile extent for the first tile of each dimension and
        would break the cross-tile equality check in VarZ.push.
        """
        return tuple((d[2] - d[1] + 1) for d in self.metadata['dimList'])

    @property
    def time_step_number(self):
        return int(self.metadata['timeStepNumber'])

    @staticmethod
    def from_meta(filename):
        """Build a Chunk from a .meta path (the matching .data must exist)."""
        assert os.path.exists(filename.replace(".meta", ".data"))
        metadata = parse_meta_file(filename)
        return Chunk(
            filename=filename.replace(".meta", ".data"),
            metadata=metadata)


@dataclass()
class VarZ():
    """
    Zarr-style index of one variable assembled from MDS chunks.

    Translates Zarr store keys into either JSON metadata strings
    (".zattrs"/".zarray") or [filename, offset, length] byte references
    usable by fsspec's ReferenceFileSystem.
    """
    # Name of the variable this index describes.
    varname: str
    # time_step_number -> {spatial chunk index (e.g. "0.1") -> byte reference}
    data: dict
    # Shape of one spatial chunk (one MDS tile).
    chunks: tuple
    # numpy dtype string, e.g. "<f4".
    dtype: str
    # Value used for missing data.
    fill_value: float
    # Global spatial shape of the variable.
    shape: tuple

    def __getitem__(self, key):
        """Resolve a Zarr key of the form "<varname>/<suffix>"."""
        prefix = f"{self.varname}/"
        assert key[:len(prefix)] == prefix
        k = key[len(prefix):]
        if k == ".zattrs":
            return self._zattrs()
        elif k == ".zarray":
            return self._zarray()
        else:
            # Chunk key "<time index>.<spatial index>", e.g. "0.1.1".
            ti, xyi = k.split(".", 1)
            # Time indices are positions in the sorted list of time steps.
            ts = sorted(self.data.keys())[int(ti)]
            return self.data[ts][xyi]

    def __iter__(self):
        """Yield every Zarr key this variable exposes."""
        yield f"{self.varname}/.zattrs"
        yield f"{self.varname}/.zarray"
        # enumerate() replaces the O(n^2) list.index() lookup per time step.
        for t, ts in enumerate(sorted(self.data.keys())):
            yield from (f"{self.varname}/{t}.{i}" for i in self.data[ts])

    def _zattrs(self):
        """Return the JSON .zattrs document; dim names guessed from rank."""
        # NOTE(review): the lon/lat/depth naming and ordering is a guess —
        # confirm against the actual MDS axis layout.
        if len(self.shape) == 3:
            dims = ["time", "lon", "lat", "depth"]
        elif len(self.shape) == 2:
            dims = ["time", "lon", "lat"]
        else:
            # Previously fell through to a NameError; fail explicitly.
            raise ValueError(f"Unsupported number of dimensions: {len(self.shape)}")

        return json.dumps({
            "_ARRAY_DIMENSIONS": dims
        })

    def _zarray(self):
        """Return the JSON .zarray document (Zarr v2, uncompressed, C order)."""
        return json.dumps({
            "chunks": [1] + list(self.chunks),  # one time step per chunk
            "compressor": None,  # fixed: raw binary, no compression
            "dtype": self.dtype,
            "fill_value": self.fill_value,
            "filters": None,  # fixed
            "order": "C",  # fixed
            "shape": [len(self.data)] + list(self.shape),
            "zarr_format": 2  # fixed
        })

    def push(self, chunk):
        """Register one chunk; it must be consistent with this variable."""
        assert self.varname in chunk.varnames
        assert self.chunks == chunk.chunks
        assert self.dtype == chunk.dtype
        assert self.shape == chunk.shape
        assert self.fill_value == chunk.missing_value

        if chunk.time_step_number not in self.data:
            self.data[chunk.time_step_number] = {}
        self.data[chunk.time_step_number][chunk.index] = chunk[self.varname]

    def push_from_meta(self, filename):
        """Parse a .meta file and register the resulting chunk."""
        self.push(Chunk.from_meta(filename))

    @staticmethod
    def from_chunk(chunk, varname=None):
        """Build a VarZ seeded with *chunk*.

        *varname* defaults to the chunk's only variable.
        """
        if varname is None:
            assert len(chunk.varnames) == 1
            varname = chunk.varnames[0]

        assert varname in chunk.varnames

        v = VarZ(varname=varname, data={}, chunks=chunk.chunks,
                 shape=chunk.shape,
                 dtype=chunk.dtype, fill_value=chunk.missing_value)
        v.push(chunk)
        return v

    @staticmethod
    def from_meta(filename, varname=None):
        """Build a VarZ from a .meta file path."""
        # Bug fix: *varname* was previously dropped.
        return VarZ.from_chunk(Chunk.from_meta(filename), varname=varname)


class TileZ(UserDict):
    """
    Dict-like Zarr store assembled from MDS tiles.

    Keys are Zarr store paths (".zgroup", "<var>/.zarray",
    "<var>/<t>.<i>...") resolved to JSON metadata strings or
    [filename, offset, length] references, suitable for fsspec's
    ReferenceFileSystem.
    """

    def __getitem__(self, key):
        if key == ".zgroup":
            return self._zgroup()
        # Delegate variable-scoped keys to the matching VarZ.
        for v in self.data:
            if key.startswith(v):
                return self.data[v][key]

        return super().__getitem__(key)

    def __contains__(self, item):
        # Bug fix: `item in (".zgroup")` was a substring test against the
        # string ".zgroup" (no tuple), so e.g. "z" wrongly tested True.
        if item == ".zgroup":
            return True
        return any(item in self.data[v] for v in self.data)

    def __iter__(self):
        yield ".zgroup"
        for v in self.data:
            yield from self.data[v]

    def _zgroup(self):
        """Return the JSON .zgroup document (Zarr v2)."""
        return json.dumps({"zarr_format": 2})

    def as_dataset(self):
        """Open this store as an xarray Dataset via a reference filesystem."""
        fs = ReferenceFileSystem(fo=self, target_protocol='file')
        mapper = fs.get_mapper("")
        ds = xr.open_zarr(mapper, consolidated=False)
        return ds

    def get(self, key, default=None):
        # Bug fix: previously looked only at raw varname keys in self.data,
        # so synthetic keys (".zgroup", "<var>/.zarray", chunk references)
        # always returned *default*. Route through __getitem__ instead.
        try:
            return self[key]
        except KeyError:
            return default

    def len(self):
        """Number of exposed keys (kept for backward compatibility)."""
        return sum(1 for _ in self)

    def push(self, chunk):
        """Register *chunk* under every variable it carries."""
        for v in chunk.varnames:
            if v not in self.data:
                # from_chunk already pushes the seed chunk.
                self.data[v] = VarZ.from_chunk(chunk, varname=v)
            else:
                self.data[v].push(chunk)

    def push_from_meta(self, filename):
        """Parse a .meta file and register the resulting chunk."""
        self.push(Chunk.from_meta(filename))

    def scan(self, path):
        """Register every MDS .meta file found directly under *path*."""
        # Raw string, end-anchored, and generalized to tiled names such as
        # "S.0000000000.001.001.meta" (the old pattern matched only
        # single-index names like "S.0000000000.meta" and, being unanchored,
        # also matched e.g. "S.000.metadata").
        pattern = re.compile(r'\w+(\.\d+)+\.meta$')
        filenames = (m for m in os.listdir(path) if pattern.match(m))
        for mfilename in sorted(filenames):
            self.push_from_meta(os.path.join(path, mfilename))

    def values(self):
        yield self._zgroup()
        for v in self.data:
            yield from (self.data[v][k] for k in self.data[v])
36 changes: 36 additions & 0 deletions xmitgcm/test/test_mds_tilez.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

import json
import numpy as np
import pytest
import tempfile

from xmitgcm.mds_tilez import VarZ

@pytest.fixture(scope="session")
def var_lat(tmp_path_factory):
    # Session-scoped temporary directory named "lat". NOTE(review): it is
    # returned empty, yet the tests below expect VarZ to serve data from it —
    # fixture setup (writing .meta/.data content) appears to be missing;
    # confirm the intended fixture contents.
    return tmp_path_factory.mktemp("lat")


def test_var_zattrs(var_lat):
    # NOTE(review): VarZ in mds_tilez.py is a dataclass whose positional
    # fields are (varname, data, chunks, dtype, fill_value, shape);
    # VarZ(var_lat, 'lat') binds the path to `varname` and 'lat' to `data`,
    # so this looks like it targets an earlier/planned constructor
    # signature — confirm the intended VarZ API.
    v = VarZ(var_lat, 'lat')

    # '.zattrs' lacks the "<varname>/" prefix that VarZ.__getitem__ asserts
    # on — another sign of an API mismatch. TODO confirm.
    attrs = v['.zattrs']
    attrs = json.loads(attrs)
    assert "_ARRAY_DIMENSIONS" in attrs
    assert attrs["_ARRAY_DIMENSIONS"] == ["lat"]


def test_var_zarray(var_lat):
    # NOTE(review): same API mismatch as test_var_zattrs — VarZ's dataclass
    # signature is (varname, data, chunks, dtype, fill_value, shape), and
    # '.zarray' lacks the "<varname>/" prefix. Confirm the intended API.
    v = VarZ(var_lat, 'lat')

    attrs = v['.zarray']
    attrs = json.loads(attrs)
    assert "chunks" in attrs

def test_var_data(var_lat):
    # NOTE(review): same API mismatch as above. Also, VarZ.__getitem__
    # returns a [filename, offset, length] reference, not raw bytes, so
    # np.frombuffer(data) would not work against the current implementation;
    # presumably this expects a mapper/filesystem layer in between — confirm.
    v = VarZ(var_lat, 'lat')

    data = v['0']
    data = np.frombuffer(data)
    assert np.all(data == np.array([10., 11., 12.]))