Skip to content

Commit

Permalink
Keep separate per-pipeline operator counters. Error out when "stealin…
Browse files Browse the repository at this point in the history
…g" subgraphs from other pipelines results in duplicate names. (#5506)

Prior to this change constructing exactly the same pipeline (e.g. by calling a function decorated with @pipeline_def) multiple times produced pipelines with different operator instance names and differently named operator instances and DataNodes. This PR changes that so that pipelines with the same structure have the same node names.
This is achieved by:

* Using a separate operator counter in each pipeline. When all nodes all instantiated within pipeline scope, no further action is required. This happens for the vast majority of cases.
* If a name collision occurs (only possible when "stealing" a subgraph from another pipeline), an error is raised.

Pipelines that are defined without a "current" pipeline have distinct operator instance names.

Additionally, there were some problems with operator discovery. I rewrote it to a much simpler DFS.

----
Signed-off-by: Michal Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Jun 11, 2024
1 parent f850059 commit 9dfdeac
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 54 deletions.
10 changes: 10 additions & 0 deletions dali/pipeline/operator/op_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,16 @@ class DLL_PUBLIC OpSpec {
return outputs_[idx].device;
}

DLL_PUBLIC inline void RenameInput(int idx, std::string name) {
DALI_ENFORCE_VALID_INDEX(idx, NumInput());
inputs_[idx].name = std::move(name);
}

DLL_PUBLIC inline void RenameOutput(int idx, std::string name) {
DALI_ENFORCE_VALID_INDEX(idx, NumOutput());
outputs_[idx].name = std::move(name);
}

DLL_PUBLIC inline auto &ArgumentInputs() const {
return argument_inputs_;
}
Expand Down
4 changes: 2 additions & 2 deletions dali/pipeline/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ int Pipeline::AddOperatorImpl(const OpSpec &const_spec, const std::string &inst_

DALI_ENFORCE(it != edge_names_.end(),
make_string("Data node \"", input_name, "\" requested as ", FormatInput(spec, i),
" to the operator is not known to the pipeline."));
" to operator \"", inst_name, "\" is not known to the pipeline."));

// Table of possible scenarios:
// Op location / requested input type / data location
Expand Down Expand Up @@ -360,7 +360,7 @@ int Pipeline::AddOperatorImpl(const OpSpec &const_spec, const std::string &inst_
DALI_ENFORCE(
it != edge_names_.end(),
make_string("Data node \"", input_name, "\" requested as ", FormatArgument(spec, arg_name),
" to operator is not known to the pipeline."));
" to operator \"", inst_name, "\" is not known to the pipeline."));

if (!it->second.has_cpu) {
assert(it->second.has_gpu);
Expand Down
9 changes: 9 additions & 0 deletions dali/python/backend_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2152,6 +2152,15 @@ PYBIND11_MODULE(backend_impl, m) {
py::return_value_policy::reference_internal)
.def("AddOutput", &OpSpec::AddOutput,
py::return_value_policy::reference_internal)
.def("RenameInput", &OpSpec::RenameInput, "idx"_a, "name"_a)
.def("RenameOutput", &OpSpec::RenameOutput, "idx"_a, "name"_a)
.def("InputName", &OpSpec::InputName, "idx"_a)
.def("InputDevice", &OpSpec::InputDevice, "idx"_a)
.def("OutputName", &OpSpec::OutputName, "idx"_a)
.def("OutputDevice", &OpSpec::OutputDevice, "idx"_a)
.def("NumInput", &OpSpec::NumInput)
.def("NumRegularInput", &OpSpec::NumRegularInput)
.def("NumOutput", &OpSpec::NumOutput)
DALI_OPSPEC_ADDARG(std::string)
DALI_OPSPEC_ADDARG(bool)
DALI_OPSPEC_ADDARG(int64)
Expand Down
2 changes: 1 addition & 1 deletion dali/python/nvidia/dali/data_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self, name, device="cpu", source=None):
self.source = source

def __str__(self):
return f'DataNode(name="{self.name}", device="{self.device}")'
return f'DataNode(name="{self.name}", device="{self.device}, source="{self.source}")'

__repr__ = __str__

Expand Down
35 changes: 30 additions & 5 deletions dali/python/nvidia/dali/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import threading
import tree
import warnings
import weakref
from itertools import count

import nvidia.dali.python_function_plugin
Expand Down Expand Up @@ -369,11 +370,16 @@ def __init__(self, inputs, arg_inputs, arguments, _processed_arguments, op):
op : Operator class.
Operator class containing the schema, and spec filled with `processed_arguments`.
"""
self._counter = _OpCounter()

if _Pipeline.current():
self._pipeline = weakref.ref(_Pipeline.current())
else:
self._pipeline = None
self._id = None
self._outputs = []
self._op = op
self._spec = op.spec.copy()
self._relation_id = self._counter.id
self._relation_id = None

if _conditionals.conditionals_enabled():
inputs, arg_inputs = _conditionals.apply_conditional_split_to_args(inputs, arg_inputs)
Expand Down Expand Up @@ -412,8 +418,13 @@ def _process_instance_name(self, arguments):
name = arguments.pop("name", None)
if name is not None:
self._name = name
self._autoname = False
else:
self._name = "__" + type(self._op).__name__ + "_" + str(self._counter.id)
has_pipeline = self.pipeline is not None
# to avoid mixing up global and per-pipeline ids
infix = "_" if has_pipeline else "_detached_"
self._name = "__" + type(self._op).__name__ + infix + str(self.id)
self._autoname = True

def _process_trace(self, arguments):
from nvidia.dali._debug_mode import _PipelineDebug
Expand Down Expand Up @@ -463,9 +474,21 @@ def _generate_outputs(self):
pipeline.add_sink(t)
self.append_output(t)

@property
def pipeline(self):
return None if self._pipeline is None else self._pipeline()

@property
def id(self):
return self._counter.id
if self._id is None:
if self.pipeline is None and _Pipeline.current():
self._pipeline = weakref.ref(_Pipeline.current())
if self.pipeline:
self._id = self.pipeline._next_op_id()
else:
self._id = _OpCounter().id

return self._id

@property
def inputs(self):
Expand All @@ -492,6 +515,8 @@ def name(self):

@property
def relation_id(self):
if self._relation_id is None:
self._relation_id = id(self)
return self._relation_id

@relation_id.setter
Expand Down Expand Up @@ -629,7 +654,7 @@ def __call__(self, *inputs, **kwargs):
)

# Tie the instances together
relation_id = op_instances[0].id
relation_id = op_instances[0].relation_id
for op in op_instances:
op.relation_id = relation_id

Expand Down
102 changes: 71 additions & 31 deletions dali/python/nvidia/dali/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

# pylint: disable=no-member
from typing import Any, List, Tuple, Callable, Optional, Union, TypeVar, overload
from collections import deque
from nvidia.dali import backend as b
from nvidia.dali import types
from nvidia.dali import internal
Expand Down Expand Up @@ -235,6 +234,7 @@ def __init__(
self._num_threads = num_threads
self._device_id = device_id
self._seed = seed
self._next_op_id_counter = 0
self._exec_pipelined = exec_pipelined
# When initializing DALI, we do the following in order:
# * Discover the ops specified in Python, group the ExternalSources (_build_graph())
Expand Down Expand Up @@ -719,6 +719,48 @@ def __exit__(self, type, value, traceback):

return api_checker(self)

def _require_unique_names(self):
ops_by_name = {}
for op in self._ops:
ops = ops_by_name.get(op.name, [])
ops.append(op)
duplicate = {}
foreign = False
for name, ops in ops_by_name.items():
if len(ops) > 1:
duplicate[name] = ops
for op in ops:
if op.pipeline is not self:
foreign = True

if duplicate:
message = (
f"The pipeline is invalid because it contains operators with non-unique names:\n"
f"{duplicate}"
)
if foreign:
message += (
"\nThe likely cause is that the pipeline contains a subgraph "
"instantiated while a different pipeline was set as the current "
"pipeline (e.g. inside another pipeline's graph definition function).\n"
)
raise RuntimeError(message)

def _require_no_foreign_ops(self, message):
foreign = []
for op in self._ops:
if op.pipeline is not self:
foreign.append(op)
if foreign:
raise RuntimeError(
f"{message} because it contains operator(s) "
f"that were defined outside the pipeline scope:\n"
f"{[o.name for o in foreign]}\n"
f"All operators should be defined while the pipeline is set as the current "
f"pipeline. This happens automatically when defining the pipeline in a "
f"function decorated with `@pipeline_def`."
)

# Graph is constructed by backtracking from the output edges and the edges marked as sinks
def _build_graph(self, define_graph=None):
if define_graph is not None:
Expand Down Expand Up @@ -776,6 +818,10 @@ def contains_nested_datanode(nested):
_data_node._check(outputs[i])

self._ops = _collect_ops(list(outputs) + self._sinks)
self._require_unique_names()
if self._enable_checkpointing:
self._require_no_foreign_ops("The pipeline does not support checkpointing")

self._graph_outputs = outputs
self._setup_input_callbacks()
self._disable_pruned_external_source_instances()
Expand Down Expand Up @@ -960,6 +1006,11 @@ def _restore_state_from_checkpoint(self):
self._iterator_data = external_ctx_cpt.iterator_data
self._is_restored_from_checkpoint = True

def _next_op_id(self):
i = self._next_op_id_counter
self._next_op_id_counter += 1
return i

def build(self):
"""Build the pipeline.
Expand Down Expand Up @@ -2101,37 +2152,26 @@ def get_op_input_edges(op) -> List[DataNode]:
else:
yield inp

def get_op_outputs_num():
# BSF traverse the graph first to learn, for each reachable operator in the graph,
# how many data-nodes/edges the operator contributes to
# (i.e. the number of outputs of the operator instance)
op_outputs_num = {}
edges = deque(output_nodes)
while edges:
current_edge = edges.popleft()
source_op = get_source_op(current_edge)
if source_op.id in op_outputs_num:
op_outputs_num[source_op.id] += 1
else:
op_outputs_num[source_op.id] = 1
source_op.check_args()
edges.extend(get_op_input_edges(source_op))
return op_outputs_num

visited = set()
ops = []
edges = deque(output_nodes)
op_total_outputs_num = get_op_outputs_num()
op_visited_outputs_num = {op_id: 0 for op_id in op_total_outputs_num}
while edges:
current_edge = edges.popleft()
source_op = get_source_op(current_edge)
op_visited_outputs_num[source_op.id] += 1
# Actually visit the operator only when all the nodes it contributes to
# were already processed
if op_visited_outputs_num[source_op.id] == op_total_outputs_num[source_op.id]:
ops.append(source_op)
edges.extend(get_op_input_edges(source_op))
ops.reverse()

# Depth-first search returns the graph topologically sorted.
# We go over each operator's inputs before adding it to the list.

def visit_op(op):
if id(op) in visited:
return
visited.add(id(op))
op.check_args()
# visit conttributing inputs
for edge in get_op_input_edges(op):
visit_op(get_source_op(edge))
# add the operator to the list of contributing ops
ops.append(op)

for edge in output_nodes:
visit_op(get_source_op(edge))

return ops


Expand Down
19 changes: 8 additions & 11 deletions dali/test/python/auto_aug/test_auto_augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -210,8 +210,8 @@ def pipeline():
@params(
(False, "cpu", 256),
(False, "gpu", 512),
(True, "cpu", 400),
(True, "gpu", 348),
(True, "cpu", 2000),
(True, "gpu", 2000),
)
def test_sub_policy(randomly_negate, dev, batch_size):
num_magnitude_bins = 10
Expand Down Expand Up @@ -305,9 +305,8 @@ def third(data, op_id_mag_id):
expected_counts.append(expected)
stat = chisquare(counts, expected_counts)
# assert that the magnitudes negation looks independently enough
# (0.05 <=), but also that it is not too ideal (i.e. like all
# cases happening exactly the expected number of times)
assert 0.05 <= stat.pvalue <= 0.95, f"{stat}"
# (0.01 <=)
assert 0.01 <= stat.pvalue, f"{stat}"


@params(("cpu",), ("gpu",))
Expand Down Expand Up @@ -397,7 +396,7 @@ def second_stage_only(data, op_id_mag_id):
)

policy = Policy("MyPolicy", num_magnitude_bins=num_magnitude_bins, sub_policies=sub_policies)
p = concat_aug_pipeline(batch_size=batch_size, dev=dev, policy=policy)
p = concat_aug_pipeline(batch_size=batch_size, dev=dev, policy=policy, seed=1234)
p.build()

for _ in range(5):
Expand All @@ -415,10 +414,8 @@ def second_stage_only(data, op_id_mag_id):
actual.append(actual_counts[mags])
expected.append(expected_counts[mags])
stat = chisquare(actual, expected)
# assert that the magnitudes negation looks independently enough
# (0.05 <=), but also that it is not too ideal (i.e. like all
# cases happening exactly the expected number of times)
assert 0.05 <= stat.pvalue <= 0.95, f"{stat}"
# assert that the magnitudes negation looks independently enough (0.01 <=)
assert 0.01 <= stat.pvalue, f"{stat}"


def test_policy_presentation():
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/auto_aug/test_rand_augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -312,7 +312,7 @@ def pipeline():
actual.append(actual_count[out])
expected.append(expected_counts[out])
stat = chisquare(actual, expected)
assert 0.01 <= stat.pvalue <= 0.99, f"{stat} {actual} {expected}"
assert 0.01 <= stat.pvalue, f"{stat} {actual} {expected}"


def test_wrong_params_fail():
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/auto_aug/test_trivial_augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -236,4 +236,4 @@ def pipeline():
stat = chisquare(actual, expected)
stats.append(stat)
mean_p_val = sum(stat.pvalue for stat in stats) / len(stats)
assert 0.05 <= mean_p_val <= 0.95, f"{mean_p_val} {stat} {actual} {expected}"
assert 0.01 <= mean_p_val, f"{mean_p_val} {stat} {actual} {expected}"
Loading

0 comments on commit 9dfdeac

Please sign in to comment.