Commit 70948e1: rm nvtabular

cwharris committed Jul 26, 2024
1 parent 62a81b9 commit 70948e1
Showing 22 changed files with 58 additions and 2,325 deletions.
9 changes: 0 additions & 9 deletions morpheus/utils/column_info.py
@@ -17,20 +17,13 @@
 import logging
 import re
 import typing
-import warnings
 from datetime import datetime
 from functools import partial
 
 import pandas as pd
 
 import cudf
 
-if (typing.TYPE_CHECKING):
-    with warnings.catch_warnings():
-        # Ignore warning regarding tensorflow not being installed
-        warnings.filterwarnings("ignore", message=".*No module named 'tensorflow'", category=UserWarning)
-        import nvtabular as nvt
-
 logger = logging.getLogger(f"morpheus.{__name__}")
 
 DEFAULT_DATE = '1970-01-01T00:00:00.000000+00:00'
@@ -749,7 +742,6 @@ class DataFrameInputSchema:
     input_columns: typing.Dict[str, str] = dataclasses.field(init=False, repr=False)
     output_columns: typing.List[tuple[str, str]] = dataclasses.field(init=False, repr=False)
 
-    nvt_workflow: "nvt.Workflow" = dataclasses.field(init=False, repr=False)
     prep_dataframe: typing.Callable[[pd.DataFrame], typing.List[str]] = dataclasses.field(init=False, repr=False)
 
     def __post_init__(self):
@@ -797,4 +789,3 @@ def __post_init__(self):
                                       json_cols=self.json_columns,
                                       preserve_re=self.preserve_columns)
 
-        self.nvt_workflow = None

43 changes: 20 additions & 23 deletions morpheus/utils/downloader.py
@@ -26,7 +26,6 @@
 
 import fsspec
 import pandas as pd
-from merlin.core.utils import Distributed
 
 logger = logging.getLogger(__name__)

@@ -69,7 +68,6 @@ def __init__(self,
                  download_method: typing.Union[DownloadMethods, str] = DownloadMethods.DASK_THREAD,
                  dask_heartbeat_interval: str = "30s"):
 
-        self._merlin_distributed = None
         self._dask_heartbeat_interval = dask_heartbeat_interval
 
         download_method = os.environ.get("MORPHEUS_FILE_DOWNLOAD_TYPE", download_method)
@@ -99,20 +97,25 @@ def get_dask_cluster(self):
         Returns
         -------
-        dask_cuda.LocalCUDACluster
+        dask.distributed.LocalCluster
         """
 
         with Downloader._mutex:
             if Downloader._dask_cluster is None:
                 import dask
                 import dask.distributed
-                import dask_cuda.utils
 
                 logger.debug("Creating dask cluster...")
 
-                n_workers = dask_cuda.utils.get_n_gpus()
-                threads_per_worker = mp.cpu_count() // n_workers
+                Downloader._dask_cluster = dask.distributed.LocalCluster(start=True,
+                                                                         processes=self.download_method != "dask_thread")
 
-                Downloader._dask_cluster = dask_cuda.LocalCUDACluster(n_workers=n_workers,
-                                                                      threads_per_worker=threads_per_worker)
+                # n_workers = dask_cuda.utils.get_n_gpus()
+                # threads_per_worker = mp.cpu_count() // n_workers
+
+                # Downloader._dask_cluster = dask_cuda.LocalCUDACluster(n_workers=n_workers,
+                #                                                       threads_per_worker=threads_per_worker)
 
                 logger.debug("Creating dask cluster... Done. Dashboard: %s", Downloader._dask_cluster.dashboard_link)
@@ -127,24 +130,18 @@ def get_dask_client(self):
         dask.distributed.Client
         """
         import dask.distributed
+        return dask.distributed.Client(self.get_dask_cluster())
 
-        # Up the heartbeat interval which can get violated with long download times
-        dask.config.set({"distributed.client.heartbeat": self._dask_heartbeat_interval})
+    def close(self):
+        """Close the dask cluster if it exists."""
+        if (self._dask_cluster is not None):
+            logger.debug("Stopping dask cluster...")
 
-        if (self._merlin_distributed is None):
-            with warnings.catch_warnings():
-                # Merlin.Distributed will warn if a client already exists, the client in question is the one created
-                # and are explicitly passing to it in the constructor.
-                warnings.filterwarnings("ignore",
-                                        message="Existing Dask-client object detected in the current context.*",
-                                        category=UserWarning)
-                self._merlin_distributed = Distributed(client=dask.distributed.Client(self.get_dask_cluster()))
+            self._dask_cluster.close()
 
-        return self._merlin_distributed
+            self._dask_cluster = None
 
-    def close(self):
-        """Cluster management is handled by Merlin.Distributed"""
-        pass
+            logger.debug("Stopping dask cluster... Done.")
 
     def download(self,
                  download_buckets: fsspec.core.OpenFiles,
@@ -169,8 +166,8 @@ def download(self,
         if (self._download_method.startswith("dask")):
            # Create the client each time to ensure all connections to the cluster are closed (they can time out)
            with self.get_dask_client() as dist:
-                dfs = dist.client.map(download_fn, download_buckets)
-                dfs = dist.client.gather(dfs)
+                dfs = dist.map(download_fn, download_buckets)
+                dfs = dist.gather(dfs)
 
         else:
             # Simply loop
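
For context, a minimal sketch (not part of the commit) of the CPU-only Dask pattern the reworked downloader adopts: one shared local cluster plus a short-lived client per batch of work. Only LocalCluster, Client, dashboard_link, map, gather, and close appear in the diff above; the processes flag mirrors the new get_dask_cluster logic (the commit also passes start=True, omitted here), and everything else is illustrative.

    import dask.distributed

    # Thread-based local cluster, analogous to the "dask_thread" download method above.
    cluster = dask.distributed.LocalCluster(processes=False)
    print("Dashboard:", cluster.dashboard_link)

    # A short-lived client per batch, as in Downloader.download(); closing the client
    # leaves the shared cluster running for reuse.
    with dask.distributed.Client(cluster) as client:
        futures = client.map(lambda x: x * 2, [1, 2, 3])  # placeholder work
        print(client.gather(futures))

    # Explicit teardown, mirroring what the new Downloader.close() does.
    cluster.close()
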
13 changes: 0 additions & 13 deletions morpheus/utils/nvt/__init__.py

This file was deleted.

123 changes: 0 additions & 123 deletions morpheus/utils/nvt/decorators.py

This file was deleted.

17 changes: 0 additions & 17 deletions morpheus/utils/nvt/extensions/__init__.py

This file was deleted.

27 changes: 0 additions & 27 deletions morpheus/utils/nvt/extensions/morpheus_ext.py

This file was deleted.

16 additional changed files not shown.
