Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default schema and use it in OpSpec argument queries. #5500

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions dali/pipeline/operator/op_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ const OpSchema *SchemaRegistry::TryGetSchema(const std::string &name) {
return it != schema_map.end() ? &it->second : nullptr;
}

const OpSchema &OpSchema::Default() {
static OpSchema default_schema("");
return default_schema;
}

OpSchema::OpSchema(const std::string &name) : name_(name) {
// Process the module path and operator name
Expand Down
5 changes: 5 additions & 0 deletions dali/pipeline/operator/op_schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ class DLL_PUBLIC OpSchema {

DLL_PUBLIC inline ~OpSchema() = default;

/**
* @brief Returns an empty schema, with only internal arguments
*/
DLL_PUBLIC static const OpSchema &Default();

/**
* @brief Returns the schema name of this operator.
*/
Expand Down
2 changes: 1 addition & 1 deletion dali/pipeline/operator/op_spec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ OpSpec& OpSpec::AddOutput(const string &name, const string &device) {
OpSpec& OpSpec::AddArgumentInput(const string &arg_name, const string &inp_name) {
DALI_ENFORCE(!this->HasArgument(arg_name), make_string(
"Argument '", arg_name, "' is already specified."));
const OpSchema& schema = GetSchema();
const OpSchema& schema = GetSchemaOrDefault();
DALI_ENFORCE(schema.HasArgument(arg_name),
make_string("Argument '", arg_name, "' is not supported by operator `",
GetOpDisplayName(*this, true), "`."));
Expand Down
23 changes: 19 additions & 4 deletions dali/pipeline/operator/op_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,26 @@ class DLL_PUBLIC OpSpec {
schema_ = schema_name_.empty() ? nullptr : SchemaRegistry::TryGetSchema(schema_name_);
}

/**
* @brief Sets the schema of the Operator.
*/
DLL_PUBLIC inline void SetSchema(OpSchema *schema) {
schema_ = schema;
if (schema)
schema_name_ = schema->name();
else
schema_name_ = {};
}

DLL_PUBLIC inline const OpSchema &GetSchema() const {
DALI_ENFORCE(schema_ != nullptr, "No schema found for operator \"" + SchemaName() + "\"");
return *schema_;
}

DLL_PUBLIC inline const OpSchema &GetSchemaOrDefault() const {
return schema_ ? *schema_ : OpSchema::Default();
}

/**
* @brief Add an argument with the given name and value.
*/
Expand Down Expand Up @@ -469,7 +484,7 @@ inline T OpSpec::GetArgumentImpl(
return static_cast<T>(arg.Get<S>());
} else {
// Argument wasn't present locally, get the default from the associated schema
const OpSchema& schema = GetSchema();
const OpSchema& schema = GetSchemaOrDefault();
return static_cast<T>(schema.GetDefaultValueForArgument<S>(name));
}
}
Expand All @@ -495,7 +510,7 @@ inline bool OpSpec::TryGetArgumentImpl(
}
// Search for the argument locally
auto arg_it = argument_idxs_.find(name);
const OpSchema& schema = GetSchema();
const OpSchema& schema = GetSchemaOrDefault();
if (arg_it != argument_idxs_.end()) {
// Found locally - return
Argument &arg = *arguments_[arg_it->second];
Expand Down Expand Up @@ -527,7 +542,7 @@ inline std::vector<T> OpSpec::GetRepeatedArgumentImpl(const string &name) const
return detail::convert_vector<T>(arg.Get<V>());
} else {
// Argument wasn't present locally, get the default from the associated schema
const OpSchema& schema = GetSchema();
const OpSchema& schema = GetSchemaOrDefault();
return detail::convert_vector<T>(schema.GetDefaultValueForArgument<V>(name));
}
}
Expand All @@ -537,7 +552,7 @@ inline bool OpSpec::TryGetRepeatedArgumentImpl(C &result, const string &name) co
using V = std::vector<S>;
// Search for the argument locally
auto arg_it = argument_idxs_.find(name);
const OpSchema& schema = GetSchema();
const OpSchema& schema = GetSchemaOrDefault();
if (arg_it != argument_idxs_.end()) {
// Found locally - return
Argument &arg = *arguments_[arg_it->second];
Expand Down
6 changes: 6 additions & 0 deletions dali/pipeline/operator/op_spec_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -467,5 +467,11 @@ TEST(TestOpSpec, Lookup) {
EXPECT_EQ(spec.ArgumentInputName(2), "zero");
}

TEST(TestOpSpec, EmptySchema) {
OpSpec spec("nonexistent_schema");
EXPECT_THROW(spec.GetSchema(), std::runtime_error);
EXPECT_EQ(spec.GetArgument<std::string>("device"), "cpu");
EXPECT_EQ(spec.GetArgument<std::string>("_module"), "nvidia.dali.ops");
}

} // namespace dali
Loading