Skip to content

Commit

Permalink
Add first-class check for nested datanodes in math/arithmetic ops. (#…
Browse files Browse the repository at this point in the history
…5466)

Signed-off-by: Michal Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed May 15, 2024
1 parent 91f89d9 commit 17458b7
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
51 changes: 51 additions & 0 deletions dali/python/nvidia/dali/ops/_operators/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,53 @@ def _generate_input_desc(categories_idx, integers, reals):
return input_desc


def _has_nested_datanodes(value, visited):
i = id(value)
if i in visited:
return False
visited.add(i)
for x in value:
if isinstance(x, _DataNode):
return True
if isinstance(x, (list, tuple)):
if _has_nested_datanodes(value, visited):
return True
return False


def _check_nested_datanode(op, arg, value):
if isinstance(value, (list, tuple)):
if _has_nested_datanodes(value, set()):
input_keyword = "argument" if isinstance(arg, str) else "input"
raise TypeError(
f"The {input_keyword} {repr(arg)} of operator `{op}` must be a `DataNode` or "
f"a compatible constant. "
f"Got a `{type(value).__name__}` with nested `DataNode`(s).\n"
f"Did you pass a return value of an operator producing multiple outputs?"
)


_op_display_name = {
"add": "+",
"sub": "-",
"plus": "(unary) +",
"minus": "(unary) -",
"mul": "*",
"pow": "**",
"div": "//",
"fdiv": "/",
"bitand": "&",
"bitor": "|",
"bitxor": "^",
"eq": "==",
"neq": "!=",
"lt": "<",
"leq": "<=",
"gt": ">",
"geq": ">=",
}


def _arithm_op(name, *inputs, definition_frame_end=None):
"""
Create arguments for ArithmeticGenericOp and call it with supplied inputs.
Expand All @@ -204,6 +251,10 @@ def _arithm_op(name, *inputs, definition_frame_end=None):
"""
import nvidia.dali.ops # Allow for late binding of the ArithmeticGenericOp from parent module.

display_name = _op_display_name.get(name, name)
for i, inp in enumerate(inputs):
_check_nested_datanode(display_name, i, inp)

categories_idxs, edges, integers, reals = _group_inputs(inputs)
input_desc = _generate_input_desc(categories_idxs, integers, reals)
expression_desc = "{}({})".format(name, input_desc)
Expand Down
34 changes: 34 additions & 0 deletions dali/test/python/operator_1/test_arithmetic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,3 +1189,37 @@ def pipe():
for device in ["cpu", "gpu"]:
with assert_raises(RuntimeError, glob=error_msg2):
impl(device, shape_a2, shape_b2)


def test_nested_datanode_error_math():
@pipeline_def(device_id=None, batch_size=1, num_threads=4)
def err_pipe():
u = fn.random.uniform(range=[0, 1])
v = fn.random.uniform(range=[0, 1])
return math.max([u, v], 5)

with assert_raises(
TypeError, glob="input 0 of operator `max` must be*" "Got a `list` with nested *DataNode"
):
_ = err_pipe()


@params(
*(
(x,)
for x in ("+", "-", "*", "/", "//", "**", "&", "|", "^", "==", "!=", "<", ">", "<=", ">=")
)
)
def test_nested_datanode_error_arithm(op):
print(op)

@pipeline_def(device_id=None, batch_size=1, num_threads=4)
def err_pipe():
u = fn.random.uniform(range=[0, 1]) # noqa(F841)
v = fn.random.uniform(range=[0, 1]) # noqa(F841)
return eval(f"u {op} [v]")

with assert_raises(
TypeError, glob=f"input 1 of operator `{op}` must be*" "Got a `list` with nested *DataNode"
):
_ = err_pipe()

0 comments on commit 17458b7

Please sign in to comment.