Skip to content

Commit

Permalink
Add default schema and use it in OpSpec argument queries. (#5500)
Browse files Browse the repository at this point in the history
Querying an OpSpec for arguments no longer requires a valid schema.

---------

Signed-off-by: Michal Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Jun 6, 2024
1 parent 50715e1 commit df89b7e
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 5 deletions.
4 changes: 4 additions & 0 deletions dali/pipeline/operator/op_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ const OpSchema *SchemaRegistry::TryGetSchema(const std::string &name) {
return it != schema_map.end() ? &it->second : nullptr;
}

const OpSchema &OpSchema::Default() {
static OpSchema default_schema("");
return default_schema;
}

OpSchema::OpSchema(const std::string &name) : name_(name) {
// Process the module path and operator name
Expand Down
5 changes: 5 additions & 0 deletions dali/pipeline/operator/op_schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ class DLL_PUBLIC OpSchema {

DLL_PUBLIC inline ~OpSchema() = default;

/**
* @brief Returns an empty schema, with only internal arguments
*/
DLL_PUBLIC static const OpSchema &Default();

/**
* @brief Returns the schema name of this operator.
*/
Expand Down
2 changes: 1 addition & 1 deletion dali/pipeline/operator/op_spec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ OpSpec& OpSpec::AddOutput(const string &name, const string &device) {
OpSpec& OpSpec::AddArgumentInput(const string &arg_name, const string &inp_name) {
DALI_ENFORCE(!this->HasArgument(arg_name), make_string(
"Argument '", arg_name, "' is already specified."));
const OpSchema& schema = GetSchema();
const OpSchema& schema = GetSchemaOrDefault();
DALI_ENFORCE(schema.HasArgument(arg_name),
make_string("Argument '", arg_name, "' is not supported by operator `",
GetOpDisplayName(*this, true), "`."));
Expand Down
23 changes: 19 additions & 4 deletions dali/pipeline/operator/op_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,26 @@ class DLL_PUBLIC OpSpec {
schema_ = schema_name_.empty() ? nullptr : SchemaRegistry::TryGetSchema(schema_name_);
}

/**
* @brief Sets the schema of the Operator.
*/
DLL_PUBLIC inline void SetSchema(OpSchema *schema) {
schema_ = schema;
if (schema)
schema_name_ = schema->name();
else
schema_name_ = {};
}

DLL_PUBLIC inline const OpSchema &GetSchema() const {
DALI_ENFORCE(schema_ != nullptr, "No schema found for operator \"" + SchemaName() + "\"");
return *schema_;
}

DLL_PUBLIC inline const OpSchema &GetSchemaOrDefault() const {
return schema_ ? *schema_ : OpSchema::Default();
}

/**
* @brief Add an argument with the given name and value.
*/
Expand Down Expand Up @@ -469,7 +484,7 @@ inline T OpSpec::GetArgumentImpl(
return static_cast<T>(arg.Get<S>());
} else {
// Argument wasn't present locally, get the default from the associated schema
const OpSchema& schema = GetSchema();
const OpSchema& schema = GetSchemaOrDefault();
return static_cast<T>(schema.GetDefaultValueForArgument<S>(name));
}
}
Expand All @@ -495,7 +510,7 @@ inline bool OpSpec::TryGetArgumentImpl(
}
// Search for the argument locally
auto arg_it = argument_idxs_.find(name);
const OpSchema& schema = GetSchema();
const OpSchema& schema = GetSchemaOrDefault();
if (arg_it != argument_idxs_.end()) {
// Found locally - return
Argument &arg = *arguments_[arg_it->second];
Expand Down Expand Up @@ -527,7 +542,7 @@ inline std::vector<T> OpSpec::GetRepeatedArgumentImpl(const string &name) const
return detail::convert_vector<T>(arg.Get<V>());
} else {
// Argument wasn't present locally, get the default from the associated schema
const OpSchema& schema = GetSchema();
const OpSchema& schema = GetSchemaOrDefault();
return detail::convert_vector<T>(schema.GetDefaultValueForArgument<V>(name));
}
}
Expand All @@ -537,7 +552,7 @@ inline bool OpSpec::TryGetRepeatedArgumentImpl(C &result, const string &name) co
using V = std::vector<S>;
// Search for the argument locally
auto arg_it = argument_idxs_.find(name);
const OpSchema& schema = GetSchema();
const OpSchema& schema = GetSchemaOrDefault();
if (arg_it != argument_idxs_.end()) {
// Found locally - return
Argument &arg = *arguments_[arg_it->second];
Expand Down
6 changes: 6 additions & 0 deletions dali/pipeline/operator/op_spec_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -467,5 +467,11 @@ TEST(TestOpSpec, Lookup) {
EXPECT_EQ(spec.ArgumentInputName(2), "zero");
}

TEST(TestOpSpec, EmptySchema) {
OpSpec spec("nonexistent_schema");
EXPECT_THROW(spec.GetSchema(), std::runtime_error);
EXPECT_EQ(spec.GetArgument<std::string>("device"), "cpu");
EXPECT_EQ(spec.GetArgument<std::string>("_module"), "nvidia.dali.ops");
}

} // namespace dali

0 comments on commit df89b7e

Please sign in to comment.