Xarray Strategy causing significant MHKiT computation times (MHKiT slow) #331

Open
ryancoe opened this issue Jun 18, 2024 · 5 comments

ryancoe commented Jun 18, 2024

Not a bug, but I was trying to analyze 30 years of hourly data with xarray, dask, and mhkit and having a lot of trouble getting things to run in a timely manner. I found that by simply re-writing some of the functions from mhkit in pure numpy, I got speed improvements on the order of 3000x. Looking at the mhkit code, I think this can only be due to converting between types?

function                                       time
mhkit.wave.resource.significant_wave_height    2.4 s
numpy                                          0.8 ms
# %%
import xarray as xr
from mhkit.wave.resource import significant_wave_height
import numpy as np
import timeit
import matplotlib.pyplot as plt

# %% [markdown]
# # Load data

# %%
ds = xr.open_dataset('cape_hatteras_download_spectra_2000.nc')
ds = ds.rename({'time_index': 'time'})

ds['frequency'].attrs['units'] = 'Hz'
ds['frequency'].attrs['long_name'] = 'Frequency'

ds['direction'] = ds['direction']
ds['direction'].attrs['units'] = 'rad'
ds['direction'].attrs['long_name'] = 'Direction'

ds.attrs['gid'] = ds.gid.item()
ds = ds.drop_vars('gid').squeeze()

ds = ds.to_array()
ds = ds.drop_vars('variable').squeeze()
ds = ds / (1025*9.81)

dso = ds.integrate('direction')
dso.attrs['units'] = 'm$^2$/Hz'
dso.attrs['long_name'] = 'Spectral density'
dso.name = 'S'
dso

# %% [markdown]
# # Timing

# %%
time = {}
n = 20

# %% [markdown]
# ## Using MHKiT

# %%
time['mhkit'] = timeit.timeit(
    lambda: significant_wave_height(dso.to_pandas().transpose()), number=n)/n

# %% [markdown]
# ## Using numpy

# %%
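def moment(da, order=0):
    # Frequency bin widths; the first bin is assumed to span 0..f[0]
    df = np.insert(np.diff(da['frequency']), 0, da['frequency'][0])
    # Sum S * f**order * df over the frequency axis (axis=1 for time x frequency)
    m = np.sum(df*da.data*da.frequency.data**order, axis=1)
    return m

def sig_wave_height(da):
    # Hm0 = 4 * sqrt(m0)
    return 4*np.sqrt(moment(da, 0))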

# %%
time['numpy'] = timeit.timeit(
    lambda: sig_wave_height(dso), number=n)/n

# %%
time

# %%
time['mhkit']/time['numpy']

# %% [markdown]
# # Check that they agree

# %%
significant_wave_height(dso.to_pandas().transpose()).to_numpy().squeeze() - sig_wave_height(dso)

cape_hatteras_download_spectra_2000.nc.zip
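
For reference, the NumPy functions in this comment appear to evaluate the discrete spectral moment and the significant wave height as written below. This is a reading of the code rather than a statement about MHKiT's internals; the index i runs over the frequency bins, and taking the first bin width equal to f_1 mirrors the `np.insert(np.diff(...), 0, f[0])` line.

$$
m_n = \sum_{i=1}^{N} f_i^{\,n}\, S(f_i)\, \Delta f_i,
\qquad
\Delta f_1 = f_1,\quad \Delta f_i = f_i - f_{i-1}\ \ (i > 1),
\qquad
H_{m0} = 4\sqrt{m_0}
$$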

ssolson commented Jun 18, 2024

Thanks Ryan. I found the same, and we have discussed it in #327.

The issue is around our approach to xarray and not testing the time of functions. I am addressing this through the testing of notebooks, and we will be timing them going forward in #330.

With respect to xarray, we are adjusting our strategy, moving from a simplistic everything-in-xarray approach to a tailored one because of the slowness this introduced.

@akeeste This issue can serve as the foundation for our xarray strategy.

ssolson changed the title from "MHKiT is slow" to "Xarray Strategy causing significant MHKiT computation times (MHKiT slow)" on Jun 18, 2024

ryancoe commented Jun 18, 2024

I think the best approach is to write your functions to accept np.ndarrays as arguments. They will then generally work regardless of whether you pass np.ndarrays, xr.DataArrays, or (probably) pd.DataFrames, and they will be pretty fast.

# %%
import xarray as xr
from mhkit.wave.resource import significant_wave_height
import numpy as np
import timeit
import matplotlib.pyplot as plt

# %% [markdown]
# # Load data

# %%
ds = xr.open_dataset('cape_hatteras_download_spectra_2000.nc')
ds = ds.rename({'time_index': 'time'})

ds['frequency'].attrs['units'] = 'Hz'
ds['frequency'].attrs['long_name'] = 'Frequency'

ds['direction'] = ds['direction']
ds['direction'].attrs['units'] = 'rad'
ds['direction'].attrs['long_name'] = 'Direction'

ds.attrs['gid'] = ds.gid.item()
ds = ds.drop_vars('gid').squeeze()

ds = ds.to_array()
ds = ds.drop_vars('variable').squeeze()
ds = ds / (1025*9.81)

dso = ds.integrate('direction')
dso.attrs['units'] = 'm$^2$/Hz'
dso.attrs['long_name'] = 'Spectral density'
dso.name = 'S'
dso

# %% [markdown]
# # Timing

# %%
time = {}
n = 20

# %% [markdown]
# ## Using MHKiT

# %%
time['mhkit'] = timeit.timeit(
    lambda: significant_wave_height(dso.to_pandas().transpose()), number=n)/n

# %% [markdown]
# ## Using numpy

# %%
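def moment(S, f, order=0):
    # Frequency bin widths; the first bin is assumed to span 0..f[0]
    df = np.insert(np.diff(f), 0, f[0])
    # Sum S * f**order * df over the frequency axis (axis=1 for time x frequency)
    m = np.sum(df*S*f**order, axis=1)
    return m

def sig_wave_height(S, f):
    # Hm0 = 4 * sqrt(m0)
    return 4*np.sqrt(moment(S, f, 0))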

# %%
time['numpy'] = timeit.timeit(
    lambda: sig_wave_height(dso, dso['frequency']), number=n)/n

# %%
time

# %%
time['mhkit']/time['numpy']

# %% [markdown]
# # Check that they agree

# %%
significant_wave_height(dso.to_pandas().transpose()).to_numpy().squeeze() - sig_wave_height(dso, dso['frequency'])

akeeste commented Jun 18, 2024

Agreed, the complexity of xarray made things a lot slower than anticipated. I generally like this approach of simplifying internal functionality while maintaining flexible IO for the user's chosen data types.
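
One way to picture "simple internals, flexible IO" is a NumPy core that coerces whatever it receives with `np.asarray`. The following is only a minimal sketch, assuming spectra laid out as time x frequency as in Ryan's example; the names `frequency_moment_np` and `hm0_np` are hypothetical and are not part of MHKiT.

```python
import numpy as np


def frequency_moment_np(S, f, order=0):
    """Spectral moment of S (time x frequency) over frequencies f.

    Hypothetical helper for illustration; not an MHKiT function.
    """
    # np.asarray accepts ndarrays, nested lists, and objects that expose
    # __array__ (xr.DataArray and pd.DataFrame both do).
    S = np.asarray(S, dtype=float)
    f = np.asarray(f, dtype=float)
    # Bin widths; the first bin is taken to span 0..f[0], as in the
    # NumPy snippet earlier in this thread.
    df = np.insert(np.diff(f), 0, f[0])
    # Sum over the last (frequency) axis.
    return np.sum(S * df * f**order, axis=-1)


def hm0_np(S, f):
    """Significant wave height, Hm0 = 4 * sqrt(m0)."""
    return 4.0 * np.sqrt(frequency_moment_np(S, f, 0))


# Two time steps with a flat spectrum of 1 m^2/Hz on three frequencies.
f = np.array([0.1, 0.2, 0.3])
S = np.ones((2, 3))
print(hm0_np(S, f))
```

A thin wrapper could then re-attach coordinates or an index to the returned array when the caller passed xarray or pandas input, which keeps type handling at the edges of the library.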

akeeste commented Jun 18, 2024

I've been working through this issue more today, using the significant wave height function as an example. In this case, the core problem is that a large pandas DataFrame (with a frequency dimension and variables across time) is being converted to a Dataset with one dimension and 8000+ variables, instead of a Dataset with 1 variable (or a DataArray) with 2 dimensions (frequency x time). This is the default behavior when using native xarray/pandas functions to convert DataFrames to Datasets.

Converting back and forth with thousands of xr.Dataset variables is slow, and applying mathematical functions to xr.Datasets with thousands of variables is slow. It is likely similar to looping through all 8000+ time stamps and converting each one, or applying some mathematical function to each one, instead of treating them as a tensor. There's a lot of historical data that ends up like this in MHKiT (time x frequency, etc.). Often it should not actually contain multiple variables; it has multiple dimensions.
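
The difference between the two layouts can be seen on a small synthetic DataFrame. This is a sketch only: the data are random, the column labels `t0`, `t1`, ... stand in for time stamps, and the explicit `xr.DataArray(...)` construction is one possible way to get a 2-D array, not necessarily what MHKiT's type_handling module does.

```python
import numpy as np
import pandas as pd
import xarray as xr

n_time = 24
freq = np.linspace(0.05, 0.5, 10)

# Frequency on the index, one column per time stamp (synthetic data).
df = pd.DataFrame(
    np.random.default_rng(0).random((freq.size, n_time)),
    index=pd.Index(freq, name="frequency"),
    columns=[f"t{i}" for i in range(n_time)],
)

# Default conversion: every column becomes its own Dataset variable.
ds_many = df.to_xarray()
print(len(ds_many.data_vars), dict(ds_many.sizes))

# Alternative: a single 2-D DataArray with frequency and time dimensions.
da = xr.DataArray(
    df.to_numpy(),
    dims=("frequency", "time"),
    coords={"frequency": df.index.to_numpy(), "time": df.columns.to_numpy()},
    name="S",
)
print(da.dims, da.shape)
```

With 8000+ time stamps, the first form carries 8000+ separate variables, while the second stays a single array that vectorized operations can act on at once.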

In my testing I saw the following increases in computational expense versus the pure numpy implementation above:

  • input xarray directly in MHKiT - 10x expense
  • input pandas in MHKiT (converted to multivariate Dataset) - 1000x expense
  • input pandas in MHKiT (converted to DataArray / univariate Dataset) - 10x expense

My proposed solutions, pending more rigorous testing and assessment of other functions:

  • In the short term, use the type_handling module to convert pd.DataFrames to xr.DataArrays instead of many-variable xr.Datasets wherever possible. This will bring our expense down to what it was in v0.7.0
  • In the long term, moving internal functionality to pure numpy would be fastest. This is a lot more overhead and probably not feasible everywhere

akeeste commented Jun 18, 2024

@ryancoe an immediate fix in your case: MHKiT now allows you to pass xarray objects to any MHKiT function instead of converting to pandas first. This should speed up your MHKiT calls immediately. Just specify frequency_dimension or time_dimension as required:

significant_wave_height(dso,frequency_dimension="frequency")

Edit: I forgot that I changed a couple of lines locally in wave.resource.significant_wave_height() to make this work (see #336). I'll get this parameter added to some of the wave resource functions to fully enable passing xarray:

def significant_wave_height(S, frequency_dimension="", frequency_bins=None, to_pandas=True):

and

    m0 = frequency_moment(S, 0, frequency_bins=frequency_bins, to_pandas=False, frequency_dimension=frequency_dimension).rename(
        {"m0": "Hm0"}
    )
