Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add first-class check for nested datanodes in math/arithmetic ops. #5466

Merged
merged 1 commit into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions dali/python/nvidia/dali/ops/_operators/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,53 @@ def _generate_input_desc(categories_idx, integers, reals):
return input_desc


def _has_nested_datanodes(value, visited):
i = id(value)
if i in visited:
return False
visited.add(i)
for x in value:
if isinstance(x, _DataNode):
return True
if isinstance(x, (list, tuple)):
if _has_nested_datanodes(value, visited):
return True
return False


def _check_nested_datanode(op, arg, value):
if isinstance(value, (list, tuple)):
if _has_nested_datanodes(value, set()):
input_keyword = "argument" if isinstance(arg, str) else "input"
raise TypeError(
f"The {input_keyword} {repr(arg)} of operator `{op}` must be a `DataNode` or "
f"a compatible constant. "
f"Got a `{type(value).__name__}` with nested `DataNode`(s).\n"
f"Did you pass a return value of an operator producing multiple outputs?"
)


_op_display_name = {
"add": "+",
"sub": "-",
"plus": "(unary) +",
"minus": "(unary) -",
"mul": "*",
"pow": "**",
"div": "//",
"fdiv": "/",
"bitand": "&",
"bitor": "|",
"bitxor": "^",
"eq": "==",
"neq": "!=",
"lt": "<",
"leq": "<=",
"gt": ">",
"geq": ">=",
}


def _arithm_op(name, *inputs, definition_frame_end=None):
"""
Create arguments for ArithmeticGenericOp and call it with supplied inputs.
Expand All @@ -204,6 +251,10 @@ def _arithm_op(name, *inputs, definition_frame_end=None):
"""
import nvidia.dali.ops # Allow for late binding of the ArithmeticGenericOp from parent module.

display_name = _op_display_name.get(name, name)
for i, inp in enumerate(inputs):
_check_nested_datanode(display_name, i, inp)

categories_idxs, edges, integers, reals = _group_inputs(inputs)
input_desc = _generate_input_desc(categories_idxs, integers, reals)
expression_desc = "{}({})".format(name, input_desc)
Expand Down
34 changes: 34 additions & 0 deletions dali/test/python/operator_1/test_arithmetic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,3 +1189,37 @@
for device in ["cpu", "gpu"]:
with assert_raises(RuntimeError, glob=error_msg2):
impl(device, shape_a2, shape_b2)


def test_nested_datanode_error_math():
@pipeline_def(device_id=None, batch_size=1, num_threads=4)
def err_pipe():
u = fn.random.uniform(range=[0, 1])
v = fn.random.uniform(range=[0, 1])
return math.max([u, v], 5)

with assert_raises(
TypeError, glob="input 0 of operator `max` must be*" "Got a `list` with nested *DataNode"
):
_ = err_pipe()


@params(
*(
(x,)
for x in ("+", "-", "*", "/", "//", "**", "&", "|", "^", "==", "!=", "<", ">", "<=", ">=")
)
)
def test_nested_datanode_error_arithm(op):
print(op)

@pipeline_def(device_id=None, batch_size=1, num_threads=4)
def err_pipe():
u = fn.random.uniform(range=[0, 1]) # noqa(F841)
Dismissed Show dismissed Hide dismissed
v = fn.random.uniform(range=[0, 1]) # noqa(F841)
Dismissed Show dismissed Hide dismissed
return eval(f"u {op} [v]")

with assert_raises(
TypeError, glob=f"input 1 of operator `{op}` must be*" "Got a `list` with nested *DataNode"
):
_ = err_pipe()
Loading