diff --git a/dali/python/nvidia/dali/ops/_operators/math.py b/dali/python/nvidia/dali/ops/_operators/math.py index cee0d84d08..94d99f6961 100644 --- a/dali/python/nvidia/dali/ops/_operators/math.py +++ b/dali/python/nvidia/dali/ops/_operators/math.py @@ -190,6 +190,53 @@ def _generate_input_desc(categories_idx, integers, reals): return input_desc +def _has_nested_datanodes(value, visited): + i = id(value) + if i in visited: + return False + visited.add(i) + for x in value: + if isinstance(x, _DataNode): + return True + if isinstance(x, (list, tuple)): + if _has_nested_datanodes(value, visited): + return True + return False + + +def _check_nested_datanode(op, arg, value): + if isinstance(value, (list, tuple)): + if _has_nested_datanodes(value, set()): + input_keyword = "argument" if isinstance(arg, str) else "input" + raise TypeError( + f"The {input_keyword} {repr(arg)} of operator `{op}` must be a `DataNode` or " + f"a compatible constant. " + f"Got a `{type(value).__name__}` with nested `DataNode`(s).\n" + f"Did you pass a return value of an operator producing multiple outputs?" + ) + + +_op_display_name = { + "add": "+", + "sub": "-", + "plus": "(unary) +", + "minus": "(unary) -", + "mul": "*", + "pow": "**", + "div": "//", + "fdiv": "/", + "bitand": "&", + "bitor": "|", + "bitxor": "^", + "eq": "==", + "neq": "!=", + "lt": "<", + "leq": "<=", + "gt": ">", + "geq": ">=", +} + + def _arithm_op(name, *inputs, definition_frame_end=None): """ Create arguments for ArithmeticGenericOp and call it with supplied inputs. @@ -204,6 +251,10 @@ def _arithm_op(name, *inputs, definition_frame_end=None): """ import nvidia.dali.ops # Allow for late binding of the ArithmeticGenericOp from parent module. + display_name = _op_display_name.get(name, name) + for i, inp in enumerate(inputs): + _check_nested_datanode(display_name, i, inp) + categories_idxs, edges, integers, reals = _group_inputs(inputs) input_desc = _generate_input_desc(categories_idxs, integers, reals) expression_desc = "{}({})".format(name, input_desc) diff --git a/dali/test/python/operator_1/test_arithmetic_ops.py b/dali/test/python/operator_1/test_arithmetic_ops.py index 7a189845d3..55a99fb878 100644 --- a/dali/test/python/operator_1/test_arithmetic_ops.py +++ b/dali/test/python/operator_1/test_arithmetic_ops.py @@ -1189,3 +1189,37 @@ def pipe(): for device in ["cpu", "gpu"]: with assert_raises(RuntimeError, glob=error_msg2): impl(device, shape_a2, shape_b2) + + +def test_nested_datanode_error_math(): + @pipeline_def(device_id=None, batch_size=1, num_threads=4) + def err_pipe(): + u = fn.random.uniform(range=[0, 1]) + v = fn.random.uniform(range=[0, 1]) + return math.max([u, v], 5) + + with assert_raises( + TypeError, glob="input 0 of operator `max` must be*" "Got a `list` with nested *DataNode" + ): + _ = err_pipe() + + +@params( + *( + (x,) + for x in ("+", "-", "*", "/", "//", "**", "&", "|", "^", "==", "!=", "<", ">", "<=", ">=") + ) +) +def test_nested_datanode_error_arithm(op): + print(op) + + @pipeline_def(device_id=None, batch_size=1, num_threads=4) + def err_pipe(): + u = fn.random.uniform(range=[0, 1]) # noqa(F841) + v = fn.random.uniform(range=[0, 1]) # noqa(F841) + return eval(f"u {op} [v]") + + with assert_raises( + TypeError, glob=f"input 1 of operator `{op}` must be*" "Got a `list` with nested *DataNode" + ): + _ = err_pipe()