-
Notifications
You must be signed in to change notification settings - Fork 0
/
nc_writer.py
96 lines (78 loc) · 3.08 KB
/
nc_writer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy as np
import warnings
from shenfun import MixedTensorProductSpace
# https://github.com/Unidata/netcdf4-python/blob/master/examples/mpi_example.py
# netCDF4 is an optional dependency. Only a failed import counts as
# "not installed"; anything else (e.g. KeyboardInterrupt) must propagate.
# NCWriter calls Dataset, so it cannot be used when this import fails.
try:
    from netCDF4 import Dataset
except ImportError:
    warnings.warn('netcdf not installed')
__all__ = ('NCWriter',)
class NCWriter(object):
    """Class for writing data in shenfun to netcdf format

    The file is opened with netCDF4 in parallel mode on ``T.comm``. It holds
    an unlimited time dimension ``t``, one coordinate dimension per spatial
    axis (``x``, ``y``, ``z``), and one 'f8' (float64) variable over
    ``('t', 'x'[, 'y'[, 'z']])`` for every name in ``names``.

    args:
        ncname   string              Name of netcdf file to be created
        names    list of strings     Names of fields to be stored
        T        TensorProductSpace  Instance of a TensorProductSpace
                                     Must be the same as the space used
                                     for storing with 'write_tstep'
        kwargs                       Extra keyword arguments forwarded to
                                     netCDF4.Dataset, e.g. clobber (boolean)

    The writer can also be used as a context manager; the file is closed
    when the ``with`` block exits.
    """

    def __init__(self, ncname, names, T, **kwargs):
        self.f = Dataset(ncname, "w", parallel=True, comm=T.comm, **kwargs)
        self.T = T
        self.N = T.shape()
        self.names = names
        self._dtype = 'f8'

        # Unlimited time axis; write_tstep appends one record per call.
        self.f.createDimension('t', None)
        self.dims = ['t']
        self.nc_t = self.f.createVariable('t', self._dtype, ('t',))
        self.nc_t.set_collective(True)

        x = T.mesh()
        s = self.T.local_slice(False)
        for i, xi in enumerate(x):
            xyz = {0: 'x', 1: 'y', 2: 'z'}[i]
            # xi may carry singleton (broadcast) axes -- hence the squeeze in
            # the sizing. Flatten it to 1-D first so that s[i] slices the
            # coordinate axis itself; indexing the raw array (xi[s[i]])
            # would always slice axis 0.
            coords = np.ravel(xi)
            self.f.createDimension(xyz, coords.size)
            nc_xyz = self.f.createVariable(xyz, self._dtype, (xyz,))
            self.dims.append(xyz)
            # s[i] is this rank's local slice along axis i.
            nc_xyz[s[i]] = coords[s[i]]

        # One variable per field, keyed by its position in ``names``.
        self.handles = dict()
        for i, name in enumerate(names):
            self.handles[i] = self.f.createVariable(name, self._dtype, self.dims)
            # Collective parallel access: every rank takes part in each write.
            self.handles[i].set_collective(True)
        self.f.sync()

    def write_tstep(self, tstep, u):
        """Write field u to netcdf format at a given time step

        A new record is appended along the unlimited 't' dimension and
        ``tstep`` is stored as its time coordinate. Each rank writes only its
        local part of the field, given by ``T.local_slice(False)``.

        args:
            tstep   int              Time step
            u       Function/Array   The field to be stored. For a
                                     MixedTensorProductSpace, u[i] is stored
                                     under names[i]; otherwise the single
                                     field is stored under names[0].

        raises:
            NotImplementedError  if the space is not 1-, 2- or 3-dimensional
        """
        assert isinstance(u, np.ndarray)
        ndim = self.T.ndim()
        if ndim not in (1, 2, 3):
            raise NotImplementedError

        # Validate everything before touching the file, so a rejected call
        # does not leave a time record without matching field data.
        if isinstance(self.T, MixedTensorProductSpace):
            assert ndim == len(u.shape[1:])
            assert len(self.names) == u.shape[0]
            fields = [(i, u[i]) for i in range(u.shape[0])]
        else:
            assert len(self.names) == 1
            fields = [(0, u[:])]

        s = self.T.local_slice(False)
        local = tuple(s[d] for d in range(ndim))

        # The current length of the time variable is the next record index.
        it = self.nc_t.size
        self.nc_t[it] = tstep
        for i, data in fields:
            self.handles[i][(it,) + local] = data
        self.f.sync()

    def close(self):
        """Close the underlying netCDF dataset."""
        self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False