From 1c6958f7072394ef0d6d1da1fd870c6a068907b7 Mon Sep 17 00:00:00 2001 From: Michal Zientkiewicz Date: Wed, 5 Jun 2024 13:11:46 +0200 Subject: [PATCH 1/2] Add default schema and use it in OpSpec argument queries. Signed-off-by: Michal Zientkiewicz --- dali/pipeline/operator/op_schema.cc | 4 ++++ dali/pipeline/operator/op_schema.h | 5 +++++ dali/pipeline/operator/op_spec.cc | 2 +- dali/pipeline/operator/op_spec.h | 23 +++++++++++++++++++---- dali/pipeline/operator/op_spec_test.cc | 5 +++++ 5 files changed, 34 insertions(+), 5 deletions(-) diff --git a/dali/pipeline/operator/op_schema.cc b/dali/pipeline/operator/op_schema.cc index 87977888c2..2b3fdff82c 100644 --- a/dali/pipeline/operator/op_schema.cc +++ b/dali/pipeline/operator/op_schema.cc @@ -54,6 +54,10 @@ const OpSchema *SchemaRegistry::TryGetSchema(const std::string &name) { return it != schema_map.end() ? &it->second : nullptr; } +const OpSchema &OpSchema::Default() { + static OpSchema default_schema(""); + return default_schema; +} OpSchema::OpSchema(const std::string &name) : name_(name) { // Process the module path and operator name diff --git a/dali/pipeline/operator/op_schema.h b/dali/pipeline/operator/op_schema.h index 0515363484..d3a0f0348d 100644 --- a/dali/pipeline/operator/op_schema.h +++ b/dali/pipeline/operator/op_schema.h @@ -81,6 +81,11 @@ class DLL_PUBLIC OpSchema { DLL_PUBLIC inline ~OpSchema() = default; + /** + * @brief Returns an empty schema, with only internal arguments + */ + DLL_PUBLIC static const OpSchema &Default(); + /** * @brief Returns the schema name of this operator. */ diff --git a/dali/pipeline/operator/op_spec.cc b/dali/pipeline/operator/op_spec.cc index 33c7ee7ec3..ed9b6cbc16 100644 --- a/dali/pipeline/operator/op_spec.cc +++ b/dali/pipeline/operator/op_spec.cc @@ -51,7 +51,7 @@ OpSpec& OpSpec::AddOutput(const string &name, const string &device) { OpSpec& OpSpec::AddArgumentInput(const string &arg_name, const string &inp_name) { DALI_ENFORCE(!this->HasArgument(arg_name), make_string( "Argument '", arg_name, "' is already specified.")); - const OpSchema& schema = GetSchema(); + const OpSchema& schema = GetSchemaOrDefault(); DALI_ENFORCE(schema.HasArgument(arg_name), make_string("Argument '", arg_name, "' is not supported by operator `", GetOpDisplayName(*this, true), "`.")); diff --git a/dali/pipeline/operator/op_spec.h b/dali/pipeline/operator/op_spec.h index dda4a2690c..30bcd53898 100644 --- a/dali/pipeline/operator/op_spec.h +++ b/dali/pipeline/operator/op_spec.h @@ -82,11 +82,26 @@ class DLL_PUBLIC OpSpec { schema_ = schema_name_.empty() ? nullptr : SchemaRegistry::TryGetSchema(schema_name_); } + /** + * @brief Sets the schema of the Operator. + */ + DLL_PUBLIC inline void SetSchema(OpSchema *schema) { + schema_ = schema; + if (schema) + schema_name_ = schema->name(); + else + schema_name_ = {}; + } + DLL_PUBLIC inline const OpSchema &GetSchema() const { DALI_ENFORCE(schema_ != nullptr, "No schema found for operator \"" + SchemaName() + "\""); return *schema_; } + DLL_PUBLIC inline const OpSchema &GetSchemaOrDefault() const { + return schema_ ? *schema_ : OpSchema::Default(); + } + /** * @brief Add an argument with the given name and value. */ @@ -469,7 +484,7 @@ inline T OpSpec::GetArgumentImpl( return static_cast(arg.Get()); } else { // Argument wasn't present locally, get the default from the associated schema - const OpSchema& schema = GetSchema(); + const OpSchema& schema = GetSchemaOrDefault(); return static_cast(schema.GetDefaultValueForArgument(name)); } } @@ -495,7 +510,7 @@ inline bool OpSpec::TryGetArgumentImpl( } // Search for the argument locally auto arg_it = argument_idxs_.find(name); - const OpSchema& schema = GetSchema(); + const OpSchema& schema = GetSchemaOrDefault(); if (arg_it != argument_idxs_.end()) { // Found locally - return Argument &arg = *arguments_[arg_it->second]; @@ -527,7 +542,7 @@ inline std::vector OpSpec::GetRepeatedArgumentImpl(const string &name) const return detail::convert_vector(arg.Get()); } else { // Argument wasn't present locally, get the default from the associated schema - const OpSchema& schema = GetSchema(); + const OpSchema& schema = GetSchemaOrDefault(); return detail::convert_vector(schema.GetDefaultValueForArgument(name)); } } @@ -537,7 +552,7 @@ inline bool OpSpec::TryGetRepeatedArgumentImpl(C &result, const string &name) co using V = std::vector; // Search for the argument locally auto arg_it = argument_idxs_.find(name); - const OpSchema& schema = GetSchema(); + const OpSchema& schema = GetSchemaOrDefault(); if (arg_it != argument_idxs_.end()) { // Found locally - return Argument &arg = *arguments_[arg_it->second]; diff --git a/dali/pipeline/operator/op_spec_test.cc b/dali/pipeline/operator/op_spec_test.cc index 3531603b08..5e3790738d 100644 --- a/dali/pipeline/operator/op_spec_test.cc +++ b/dali/pipeline/operator/op_spec_test.cc @@ -467,5 +467,10 @@ TEST(TestOpSpec, Lookup) { EXPECT_EQ(spec.ArgumentInputName(2), "zero"); } +TEST(TestOpSpec, EmptySchema) { + OpSpec spec("dummy"); + EXPECT_EQ(spec.GetArgument("device"), "cpu"); + EXPECT_EQ(spec.GetArgument("_module"), "nvidia.dali.ops"); +} } // namespace dali From 9e02019d229e5617c3d6c4da5e16c237b994d355 Mon Sep 17 00:00:00 2001 From: Michal Zientkiewicz Date: Wed, 5 Jun 2024 13:21:40 +0200 Subject: [PATCH 2/2] Improve test. Signed-off-by: Michal Zientkiewicz --- dali/pipeline/operator/op_spec_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dali/pipeline/operator/op_spec_test.cc b/dali/pipeline/operator/op_spec_test.cc index 5e3790738d..6f0728aeba 100644 --- a/dali/pipeline/operator/op_spec_test.cc +++ b/dali/pipeline/operator/op_spec_test.cc @@ -468,7 +468,8 @@ TEST(TestOpSpec, Lookup) { } TEST(TestOpSpec, EmptySchema) { - OpSpec spec("dummy"); + OpSpec spec("nonexistent_schema"); + EXPECT_THROW(spec.GetSchema(), std::runtime_error); EXPECT_EQ(spec.GetArgument("device"), "cpu"); EXPECT_EQ(spec.GetArgument("_module"), "nvidia.dali.ops"); }