Skip to content

Commit

Permalink
Add fn.zeros, fn.zeros_like, fn.ones, fn.ones_like, fn.full and fn.fu…
Browse files Browse the repository at this point in the history
…ll_like (#5505)

Creates a set of operators analogous to np.zeros, np.zeros_like, np.ones, np.ones_like, np.full, and np.full_like

Signed-off-by: Joaquin Anton <[email protected]>
  • Loading branch information
jantonguirao committed Jun 11, 2024
1 parent d6b29c7 commit 2a1c2ac
Show file tree
Hide file tree
Showing 7 changed files with 517 additions and 0 deletions.
125 changes: 125 additions & 0 deletions dali/operators/generic/constant_value.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dali/operators/generic/constant_value.h"
#include <vector>
#include "dali/core/convert.h"
#include "dali/pipeline/operator/op_schema.h"

namespace dali {

// CPU implementation: fills every sample of output 0 with a constant.
//
// Two modes, selected by has_fill_value_:
//  * fill-value mode - the value is read from the input at kValueInputIdx and converted
//    element by element to the output type with ConvertSat;
//  * constant mode   - the integer const_value_ is converted once to the output type and
//    written to every element.
//
// One work item per sample is queued on the workspace thread pool and then executed with
// RunAll(). The work items capture the locals below by reference, so they rely on RunAll()
// finishing the work before this function returns - presumably it blocks; TODO confirm.
//
// NOTE(review): in fill-value mode the source element is chosen with `i % fill_value_sz`, i.e.
// the flattened fill value is repeated cyclically over the flattened output. That looks
// equivalent to NumPy-style broadcasting only when the fill value shape (ignoring leading
// 1-extents) is a suffix of the output shape, and it divides by zero when the fill value sample
// is empty while the output sample is not. Both cases presumably have to be excluded by the
// shape validation done before RunImpl - verify.
template <>
void ConstantValue<CPUBackend>::RunImpl(Workspace &ws) {
auto &output = ws.Output<CPUBackend>(0);
const auto& out_shape = output.shape();
int nsamples = out_shape.size();
auto dtype = output.type();
auto &tp = ws.GetThreadPool();
if (has_fill_value_) {
// Fill-value mode: dispatch on (fill value type) x (output type).
auto &fill_value = ws.Input<CPUBackend>(kValueInputIdx);
const auto &fill_value_sh = fill_value.shape();
TYPE_SWITCH(fill_value.type(), type2id, FillValueType, (DALI_CONSTANT_VALUE_TYPES), (
TYPE_SWITCH(dtype, type2id, OutputType, (DALI_CONSTANT_VALUE_TYPES), (
for (int s = 0; s < nsamples; s++) {
// `s` is captured by value because it changes between iterations; everything else is
// captured by reference.
tp.AddWork([&, s](int thread_idx) {
const auto* fill_value_data = fill_value.tensor<FillValueType>(s);
auto* out_data = output.mutable_tensor<OutputType>(s);
auto out_sz = out_shape.tensor_size(s);
auto fill_value_sz = fill_value_sh.tensor_size(s);
for (int64_t i = 0; i < out_sz; i++) {
// Cyclic read of the flattened fill value (see the NOTE(review) above this function).
out_data[i] = ConvertSat<OutputType>(fill_value_data[i % fill_value_sz]);
}
});
}
tp.RunAll();
), ( // NOLINT
// Default branch of the inner switch: unsupported output type.
DALI_FAIL(
make_string("Data type ", dtype, " is currently not supported. "
"Supported types are : ", ListTypeNames<DALI_CONSTANT_VALUE_TYPES>()));
)); // NOLINT
), ( // NOLINT
// Default branch of the outer switch: unsupported fill value type.
DALI_FAIL(
make_string("Data type ", fill_value.type(), " is currently not supported. "
"Supported types are : ", ListTypeNames<DALI_CONSTANT_VALUE_TYPES>()));
)); // NOLINT
} else {
// Constant mode: dispatch on the output type only.
TYPE_SWITCH(dtype, type2id, T, (DALI_CONSTANT_VALUE_TYPES), (
// Converted once, outside the per-sample work, and captured by value below.
T value = ConvertSat<T>(const_value_);
for (int s = 0; s < nsamples; s++) {
tp.AddWork([&, value, s](int thread_idx) {
auto* out_data = output.mutable_tensor<T>(s);
auto out_sz = out_shape.tensor_size(s);
std::fill(out_data, out_data + out_sz, value);
});
}
tp.RunAll();
), ( // NOLINT
// Default branch: unsupported output type.
DALI_FAIL(make_string("Data type ", dtype, " is currently not supported. "
"Supported types are : ", ListTypeNames<DALI_CONSTANT_VALUE_TYPES>()));
)); // NOLINT
}
}

// Operator schema for `Full`: one input (the fill value), one output, and an optional `shape`
// argument. No `dtype` argument is declared for this operator.
// NOTE(review): the trailing `nullptr, true` passed to AddOptionalArg presumably mean
// "no default value" and "may be provided as a per-sample tensor argument" - confirm against
// the OpSchema API.
DALI_SCHEMA(Full)
.DocStr(R"code(Returns new data of given shape and type, filled with a fill value.)code")
.NumInput(1)
.InputDox(0, "fill_value", "TensorList", R"code(The fill value.)code")
.NumOutput(1)
.AddOptionalArg<std::vector<int>>("shape", R"code(Shape of the output data.)code", nullptr,
true);
// Registers Full<CPUBackend> as the CPU implementation of the `Full` schema.
DALI_REGISTER_OPERATOR(Full, Full<CPUBackend>, CPU);

// Operator schema for `FullLike`: two inputs - input 0 provides the shape and type to copy,
// input 1 provides the fill value - and one output. No optional arguments are declared.
DALI_SCHEMA(FullLike)
.DocStr(R"code(Returns new data with the same shape and type as the input data, filled with a `fill_value`.)code")
.NumInput(2)
.InputDox(0, "data_like", "TensorList", R"code(The input data value to copy the shape and type from.)code")
.InputDox(1, "fill_value", "TensorList", R"code(The fill value.)code")
.NumOutput(1);
// Registers FullLike<CPUBackend> as the CPU implementation of the `FullLike` schema.
DALI_REGISTER_OPERATOR(FullLike, FullLike<CPUBackend>, CPU);

// Operator schema for `Zeros`: no inputs, one output, optional `shape` and `dtype` arguments
// (`dtype` is declared with DALI_INT32 as its schema default).
// NOTE(review): the trailing `nullptr, true` passed to AddOptionalArg presumably mean
// "no default value" and "may be provided as a per-sample tensor argument" - confirm against
// the OpSchema API.
DALI_SCHEMA(Zeros)
.DocStr(R"code(Returns new data of given shape and type, filled with zeros.)code")
.NumInput(0)
.NumOutput(1)
.AddOptionalArg<std::vector<int>>("shape", R"code(Shape of the output data.)code", nullptr,
true)
.AddOptionalTypeArg("dtype", R"code(Output data type.)code", DALI_INT32);
// Registers Zeros<CPUBackend> as the CPU implementation of the `Zeros` schema.
DALI_REGISTER_OPERATOR(Zeros, Zeros<CPUBackend>, CPU);

// Operator schema for `ZerosLike`: one input that provides the shape and type to copy, one
// output, and an optional `dtype` override.
// NOTE(review): `dtype` is declared with DALI_INT32 as its schema default, but the operator
// appears to use the input's type whenever `dtype` is not explicitly specified - the default
// advertised by the schema may therefore be misleading; confirm and consider declaring the
// argument without a default.
DALI_SCHEMA(ZerosLike)
.DocStr(R"code(Returns new data with the same shape and type as the input array, filled with zeros.)code")
.NumInput(1)
.InputDox(0, "data_like", "TensorList", R"code(The input data value to copy the shape and type from.)code")
.NumOutput(1)
.AddOptionalTypeArg("dtype", R"code(Overrides the output data type.)code", DALI_INT32);
// Registers ZerosLike<CPUBackend> as the CPU implementation of the `ZerosLike` schema.
DALI_REGISTER_OPERATOR(ZerosLike, ZerosLike<CPUBackend>, CPU);

// Operator schema for `Ones`: no inputs, one output, optional `shape` and `dtype` arguments
// (`dtype` is declared with DALI_INT32 as its schema default).
// NOTE(review): the trailing `nullptr, true` passed to AddOptionalArg presumably mean
// "no default value" and "may be provided as a per-sample tensor argument" - confirm against
// the OpSchema API.
DALI_SCHEMA(Ones)
.DocStr(R"code(Returns new data of given shape and type, filled with ones.)code")
.NumInput(0)
.NumOutput(1)
.AddOptionalArg<std::vector<int>>("shape", R"code(Shape of the output data.)code", nullptr,
true)
.AddOptionalTypeArg("dtype", R"code(Output data type.)code", DALI_INT32);
// Registers Ones<CPUBackend> as the CPU implementation of the `Ones` schema.
DALI_REGISTER_OPERATOR(Ones, Ones<CPUBackend>, CPU);

// Operator schema for `OnesLike`: one input that provides the shape and type to copy, one
// output, and an optional `dtype` override.
// NOTE(review): `dtype` is declared with DALI_INT32 as its schema default, but the operator
// appears to use the input's type whenever `dtype` is not explicitly specified - the default
// advertised by the schema may therefore be misleading; confirm and consider declaring the
// argument without a default.
DALI_SCHEMA(OnesLike)
.DocStr(R"code(Returns new data with the same shape and type as the input array, filled with ones.)code")
.NumInput(1)
.InputDox(0, "data_like", "TensorList", R"code(The input data value to copy the shape and type from.)code")
.NumOutput(1)
.AddOptionalTypeArg("dtype", R"code(Overrides the output data type.)code", DALI_INT32);
// Registers OnesLike<CPUBackend> as the CPU implementation of the `OnesLike` schema.
DALI_REGISTER_OPERATOR(OnesLike, OnesLike<CPUBackend>, CPU);

} // namespace dali
173 changes: 173 additions & 0 deletions dali/operators/generic/constant_value.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_OPERATORS_GENERIC_CONSTANT_VALUE_H_
#define DALI_OPERATORS_GENERIC_CONSTANT_VALUE_H_

#include <algorithm>
#include <vector>
#include "dali/core/float16.h"
#include "dali/core/static_switch.h"
#include "dali/core/tensor_shape_print.h"
#include "dali/pipeline/operator/checkpointing/stateless_operator.h"

// Comma-separated list of the element types handled by the constant-value operators.
// It is meant to be pasted into a type list, e.g. as the parenthesized set of types of a
// type switch, so the entries must stay separated by commas with no trailing comma.
#define DALI_CONSTANT_VALUE_TYPES \
uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t, float, float16, \
double, bool

namespace dali {

/**
 * @brief Common base of the constant-filling operators (Full, FullLike, Zeros, ZerosLike,
 *        Ones, OnesLike).
 *
 * Output shape, in order of precedence:
 *  - the shape of the "shape-like" input, when `is_shape_like` is set,
 *  - the `shape` argument, when it is specified,
 *  - a scalar (0-D) shape for every sample otherwise.
 *
 * Output type, in order of precedence:
 *  - the `dtype` argument, when it is specified,
 *  - the type of the "shape-like" input, when `is_shape_like` is set,
 *  - the type of the fill value input, when `has_fill_value` is set,
 *  - DALI_INT32 otherwise.
 *
 * The value written to the output is either read from the fill value input
 * (`has_fill_value`) or is the integer constant set with SetConstValue().
 */
template <typename Backend>
class ConstantValue : public StatelessOperator<Backend> {
 public:
  /**
   * @param spec            Operator specification; the optional `shape` and `dtype` arguments
   *                        are looked up here.
   * @param has_fill_value  If true, the value to write comes from an input
   *                        (see kValueInputIdx).
   * @param is_shape_like   If true, input 0 defines the output shape (and the output type,
   *                        unless `dtype` is specified).
   */
  explicit ConstantValue(const OpSpec &spec, bool has_fill_value = false,
                         bool is_shape_like = false)
      : StatelessOperator<Backend>(spec),
        has_fill_value_(has_fill_value),
        is_shape_like_(is_shape_like),
        has_shape_(spec.ArgumentDefined("shape")),
        has_dtype_(spec.ArgumentDefined("dtype")),
        // has_dtype_ is declared before dtype_, so it is already initialized here.
        dtype_(has_dtype_ ? spec.GetArgument<DALIDataType>("dtype") : DALI_INT32) {}

  /**
   * @brief Number of samples to produce: the batch size of the shape-like input if there is
   *        one, the batch size requested from the workspace for output 0 otherwise.
   */
  int GetBatchSize(const Workspace &ws) const {
    if (is_shape_like_)
      return ws.Input<Backend>(kShapeLikeInputIdx).shape().size();
    else
      return ws.GetRequestedBatchSize(0);
  }

  bool CanInferOutputs() const override { return true; }

  /**
   * @brief Checks whether two shapes are compatible in the NumPy broadcasting sense.
   *
   * The shapes are aligned at their last dimension; missing leading dimensions count as 1.
   * Every pair of aligned extents must be equal, or at least one of them must be 1.
   * The check is symmetric - it does not tell which of the shapes is the broadcast result.
   *
   * @return true if all aligned extents are compatible, false otherwise.
   */
  bool CanBroadcastShapes(span<int64_t> shape1, span<int64_t> shape2) {
    size_t len1 = shape1.size();
    size_t len2 = shape2.size();
    size_t max_len = std::max(len1, len2);
    for (size_t i = 0; i < max_len; ++i) {
      // Walk from the innermost dimension outwards, defaulting to 1 when a shape has run out
      // of dimensions. Extents are kept as int64_t: narrowing them to int could turn a large
      // extent (e.g. 2^32 + 1) into a different value and falsely accept or reject the pair.
      int64_t dim1 = (i < len1) ? shape1[len1 - 1 - i] : 1;
      int64_t dim2 = (i < len2) ? shape2[len2 - 1 - i] : 1;
      if (dim1 != dim2 && dim1 != 1 && dim2 != 1) {
        return false;
      }
    }
    return true;
  }

  /**
   * @brief Determines the type and the per-sample shapes of the single output.
   *
   * When a fill value input is used, its shape is validated against the output shape for
   * every sample.
   *
   * @throws via DALI_FAIL when a fill value sample can't be broadcast with the corresponding
   *         output shape, or when it is empty while the output sample is not.
   */
  bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
    int nsamples = GetBatchSize(ws);
    output_desc.resize(1);
    auto &dtype = output_desc[0].type;
    auto &shape = output_desc[0].shape;
    dtype = is_shape_like_ && !has_dtype_ ? ws.Input<Backend>(kShapeLikeInputIdx).type() : dtype_;

    if (is_shape_like_) {
      shape = ws.Input<Backend>(kShapeLikeInputIdx).shape();
    } else if (has_shape_) {
      GetShapeArgument(shape, spec_, "shape", ws, nsamples);
    } else {
      // Neither a shape-like input nor a `shape` argument: produce scalars.
      shape = uniform_list_shape(nsamples, TensorShape<0>{});
    }

    if (has_fill_value_) {
      auto &fill_value = ws.Input<Backend>(kValueInputIdx);
      // Deliberately a mutable copy: CanBroadcastShapes takes spans of non-const extents.
      auto fill_value_shape = fill_value.shape();
      for (int i = 0; i < nsamples; i++) {
        auto orig_shape = shape.tensor_shape_span(i);
        auto fill_value_sh = fill_value_shape.tensor_shape_span(i);
        if (!CanBroadcastShapes(orig_shape, fill_value_sh)) {
          DALI_FAIL(make_string("Shapes ", shape.tensor_shape(i), " and ",
                                fill_value_shape.tensor_shape(i), " can't be broadcast."));
        }
        // An extent of 0 is "compatible" with an extent of 1 in the check above, but there is
        // no element to copy from; the element-wise fill would index the fill value modulo 0.
        if (shape.tensor_size(i) > 0 && fill_value_shape.tensor_size(i) == 0) {
          DALI_FAIL(make_string("The fill value for sample ", i, " is empty (shape ",
                                fill_value_shape.tensor_shape(i),
                                ") and can't be used to fill an output of shape ",
                                shape.tensor_shape(i), "."));
        }
      }
      if (!has_dtype_ && !is_shape_like_) {
        // No explicit dtype and nothing to copy the type from: follow the fill value.
        dtype = fill_value.type();
      }
    }
    return true;
  }

  /**
   * @brief Sets the integer constant written to the output when no fill value input is used.
   */
  void SetConstValue(int value) {
    has_const_value_ = true;
    const_value_ = value;
  }

  void RunImpl(Workspace &ws) override;

 protected:
  using Operator<Backend>::spec_;
  using Operator<Backend>::max_batch_size_;
  const bool has_fill_value_;
  const bool is_shape_like_;
  bool has_shape_;
  bool has_dtype_;
  DALIDataType dtype_;

  bool has_const_value_ = false;
  int const_value_ = 0;

  // Input indices; they depend on is_shape_like_, which is declared (and therefore
  // initialized) before them. -1 means "there is no such input".
  const int kShapeLikeInputIdx = is_shape_like_ ? 0 : -1;
  const int kValueInputIdx = is_shape_like_ ? 1 : 0;
};

/**
 * @brief Constant-value operator configured to take the fill value from an input and not to
 *        use a shape-like input.
 */
template <typename Backend>
class Full : public ConstantValue<Backend> {
 public:
  explicit Full(const OpSpec &spec)
      : ConstantValue<Backend>(spec, /* has_fill_value */ true, /* is_shape_like */ false) {}
};

/**
 * @brief Constant-value operator configured to take the fill value from an input and to use
 *        a shape-like input.
 */
template <typename Backend>
class FullLike : public ConstantValue<Backend> {
 public:
  explicit FullLike(const OpSpec &spec)
      : ConstantValue<Backend>(spec, /* has_fill_value */ true, /* is_shape_like */ true) {}
};

/**
 * @brief Constant-value operator with the constant set to 0, without a fill value input and
 *        without a shape-like input.
 */
template <typename Backend>
class Zeros : public ConstantValue<Backend> {
 public:
  explicit Zeros(const OpSpec &spec)
      : ConstantValue<Backend>(spec, /* has_fill_value */ false, /* is_shape_like */ false) {
    // `this->` is needed to find the member of the dependent base class.
    this->SetConstValue(0);
  }
};

/**
 * @brief Constant-value operator with the constant set to 0, without a fill value input and
 *        with a shape-like input.
 */
template <typename Backend>
class ZerosLike : public ConstantValue<Backend> {
 public:
  explicit ZerosLike(const OpSpec &spec)
      : ConstantValue<Backend>(spec, /* has_fill_value */ false, /* is_shape_like */ true) {
    // `this->` is needed to find the member of the dependent base class.
    this->SetConstValue(0);
  }
};

/**
 * @brief Constant-value operator with the constant set to 1, without a fill value input and
 *        without a shape-like input.
 */
template <typename Backend>
class Ones : public ConstantValue<Backend> {
 public:
  explicit Ones(const OpSpec &spec)
      : ConstantValue<Backend>(spec, /* has_fill_value */ false, /* is_shape_like */ false) {
    // `this->` is needed to find the member of the dependent base class.
    this->SetConstValue(1);
  }
};

/**
 * @brief Constant-value operator with the constant set to 1, without a fill value input and
 *        with a shape-like input.
 */
template <typename Backend>
class OnesLike : public ConstantValue<Backend> {
 public:
  explicit OnesLike(const OpSpec &spec)
      : ConstantValue<Backend>(spec, /* has_fill_value */ false, /* is_shape_like */ true) {
    // `this->` is needed to find the member of the dependent base class.
    this->SetConstValue(1);
  }
};

} // namespace dali

#endif // DALI_OPERATORS_GENERIC_CONSTANT_VALUE_H_
18 changes: 18 additions & 0 deletions dali/test/python/checkpointing/test_dali_stateless_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,3 +995,21 @@ def test_dilate_stateless(device):
@stateless_signed_off("experimental.erode")
def test_erode_stateless(device):
check_single_input(fn.experimental.erode, device)


# Builds one checkpointing-enabled pipeline that uses all six constant-value operators and
# passes it to the pipeline statelessness check.
@params("cpu")
@stateless_signed_off("zeros", "ones", "full", "zeros_like", "ones_like", "full_like")
def test_full_operator_family(device):
    @pipeline_def(enable_checkpointing=True)
    def pipeline_factory():
        out_shape = np.array([2, 3], dtype=np.int32)
        fill = np.array([1.0, 0.4, 3.0], dtype=np.float32)

        # Operators that take an explicit shape.
        zeros_out = fn.zeros(shape=out_shape)
        ones_out = fn.ones(shape=out_shape)
        full_out = fn.full(fill, shape=out_shape)

        # "*_like" operators reuse the shape of the `zeros` output.
        zeros_like_out = fn.zeros_like(zeros_out)
        ones_like_out = fn.ones_like(zeros_out)
        full_like_out = fn.full_like(zeros_out, fill)

        return (
            zeros_out,
            ones_out,
            full_out,
            zeros_like_out,
            ones_like_out,
            full_like_out,
        )

    check_is_pipeline_stateless(pipeline_factory)
67 changes: 67 additions & 0 deletions dali/test/python/operator_1/test_constant_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from nvidia.dali import pipeline_def, fn


# Test helper: wraps `op` in a pipeline with batch size 1, builds and runs the pipeline once,
# and returns the first sample of the first output converted to a NumPy array.
def run(op):
    @pipeline_def(batch_size=1, num_threads=3, device_id=0)
    def pipe0():
        return op

    pipeline = pipe0()
    pipeline.build()
    outputs = pipeline.run()
    first_output = outputs[0]
    return np.array(first_output[0])


# fn.zeros with an explicit shape must give the same values as np.zeros.
def test_zeros():
    shape = (2, 3)
    actual = run(fn.zeros(shape=shape))
    expected = np.zeros(shape=shape)
    np.testing.assert_array_equal(actual, expected)


# fn.zeros_like must give the same values as np.zeros_like for the same input array.
def test_zeros_like():
    template = np.ones((2, 3))
    actual = run(fn.zeros_like(template))
    expected = np.zeros_like(template)
    np.testing.assert_array_almost_equal(actual, expected)


# fn.ones with an explicit shape must give the same values as np.ones.
def test_ones():
    shape = (2, 3)
    actual = run(fn.ones(shape=shape))
    expected = np.ones(shape=shape)
    np.testing.assert_array_almost_equal(actual, expected)


# fn.ones_like must give the same values as np.ones_like for the same input array.
def test_ones_like():
    template = np.ones((2, 3))
    actual = run(fn.ones_like(template))
    expected = np.ones_like(template)
    np.testing.assert_array_almost_equal(actual, expected)


# fn.full with a (3, 4) fill value and a (2, 3, 4) output shape must match np.full.
def test_full():
    shape = (2, 3, 4)
    fill = np.random.uniform(size=(3, 4))
    actual = run(fn.full(fill, shape=shape))
    expected = np.full(shape, fill)
    np.testing.assert_array_almost_equal(actual, expected)


# fn.full_like with a (2, 3, 4) input and a (3, 4) fill value must match np.full_like.
def test_full_like():
    template = np.random.uniform(size=(2, 3, 4))
    fill = np.random.uniform(size=(3, 4))
    actual = run(fn.full_like(template, fill))
    expected = np.full_like(template, fill)
    np.testing.assert_array_almost_equal(actual, expected)
Loading

0 comments on commit 2a1c2ac

Please sign in to comment.