From a0f9683dec091247578d02558b5af8c7d21ad0da Mon Sep 17 00:00:00 2001 From: Rafal Banas Date: Fri, 28 Jun 2024 12:53:27 +0200 Subject: [PATCH] Add warp_perspective operator Signed-off-by: Rafal Banas --- cmake/Dependencies.common.cmake | 2 +- dali/operators/image/remap/CMakeLists.txt | 6 +- .../image/remap/cvcuda/CMakeLists.txt | 18 + .../image/remap/cvcuda/warp_perspective.cc | 213 ++++++++++ dali/operators/nvcvop/nvcvop.cc | 32 +- dali/operators/nvcvop/nvcvop.h | 27 +- .../test_dali_stateless_operators.py | 6 + .../operator_2/test_warp_perspective.py | 372 ++++++++++++++++++ dali/test/python/test_dali_cpu_only.py | 1 + .../python/test_dali_variable_batch_size.py | 2 + 10 files changed, 670 insertions(+), 9 deletions(-) create mode 100644 dali/operators/image/remap/cvcuda/CMakeLists.txt create mode 100644 dali/operators/image/remap/cvcuda/warp_perspective.cc create mode 100644 dali/test/python/operator_2/test_warp_perspective.py diff --git a/cmake/Dependencies.common.cmake b/cmake/Dependencies.common.cmake index 8dd91fd541..effded9b90 100644 --- a/cmake/Dependencies.common.cmake +++ b/cmake/Dependencies.common.cmake @@ -264,7 +264,7 @@ if (BUILD_CVCUDA) set(DALI_BUILD_PYTHON ${BUILD_PYTHON}) set(BUILD_PYTHON OFF) # for now we use only median blur from CV-CUDA - set(CV_CUDA_SRC_PATERN medianblur median_blur morphology) + set(CV_CUDA_SRC_PATERN medianblur median_blur morphology warp) check_and_add_cmake_submodule(${PROJECT_SOURCE_DIR}/third_party/cvcuda) set(BUILD_PYTHON ${DALI_BUILD_PYTHON}) endif() diff --git a/dali/operators/image/remap/CMakeLists.txt b/dali/operators/image/remap/CMakeLists.txt index 4a89bcdd69..97660d26dd 100644 --- a/dali/operators/image/remap/CMakeLists.txt +++ b/dali/operators/image/remap/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,3 +16,7 @@ collect_headers(DALI_INST_HDRS PARENT_SCOPE) collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE) collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE) + +if (BUILD_CVCUDA) + add_subdirectory(cvcuda) +endif() diff --git a/dali/operators/image/remap/cvcuda/CMakeLists.txt b/dali/operators/image/remap/cvcuda/CMakeLists.txt new file mode 100644 index 0000000000..74804739b2 --- /dev/null +++ b/dali/operators/image/remap/cvcuda/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Get all the source files and dump test files +collect_headers(DALI_INST_HDRS PARENT_SCOPE) +collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE) +collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE) diff --git a/dali/operators/image/remap/cvcuda/warp_perspective.cc b/dali/operators/image/remap/cvcuda/warp_perspective.cc new file mode 100644 index 0000000000..0460b87dc3 --- /dev/null +++ b/dali/operators/image/remap/cvcuda/warp_perspective.cc @@ -0,0 +1,213 @@ +// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "dali/core/dev_buffer.h" +#include "dali/core/static_switch.h" +#include "dali/kernels/common/utils.h" +#include "dali/kernels/dynamic_scratchpad.h" +#include "dali/pipeline/operator/arg_helper.h" +#include "dali/pipeline/operator/checkpointing/stateless_operator.h" +#include "dali/pipeline/operator/operator.h" + +#include "dali/operators/nvcvop/nvcvop.h" + +namespace dali { + + +DALI_SCHEMA(experimental__WarpPerspective) + .DocStr(R"doc( +TODO + )doc") + .NumInput(1, 2) + .InputDox(0, "input", "TensorList of uint8, uint16, int16 or float", + "Input data. Must be images in HWC or CHW layout, or a sequence of those.") + .InputDox(1, "matrix_data", "1D TensorList of float", + "Transformation matrix data. Should be used to pass the GPU data. " + "For CPU data, the `matrix` argument should be used.") + .NumOutput(1) + .InputLayout(0, {"HW", "HWC", "FHWC", "CHW", "FCHW"}) + .AddOptionalArg("size", + R"code(Output size, in pixels/points. + +The channel dimension should be excluded (for example, for RGB images, +specify ``(480,640)``, not ``(480,640,3)``. +)code", + std::vector({}), true) + .AddOptionalArg("matrix", + R"doc( + Perspective transform mapping of destination to source coordinates. + If `inverse_map` argument is set to true, the matrix is interpreted + as a source to destination coordinates mapping. + +It is equivalent to OpenCV's ``warpAffine`` operation with the ``inverse_map`` argument being +analog to the ``WARP_INVERSE_MAP`` flag. + +.. note:: + Instead of this argument, the operator can take a second positional input, in which + case the matrix can be placed on the GPU.)doc", + std::vector({}), true, true) + .AddOptionalArg("border_mode", + "Border mode to be used when accessing elements outside input image.", + "constant") + .AddOptionalArg("interp_type", "Interpolation method.", "nearest") + .AddOptionalArg("fill_value", + "Value used to fill areas that are outside the source image when the " + "\"constant\" border_mode is chosen.", + std::vector({})) + .AddOptionalArg("inverse_map", "Inverse perspective transform matrix", false); + + +class WarpPerspective : public nvcvop::NVCVSequenceOperator { + public: + explicit WarpPerspective(const OpSpec &spec) + : nvcvop::NVCVSequenceOperator(spec), + shape_arg_(spec.GetArgument>("size")), + border_mode_(nvcvop::GetBorderMode(spec.GetArgument("border_mode"))), + interp_type_(nvcvop::GetInterpolationType(spec.GetArgument("interp_type"))), + inverse_map_(spec.GetArgument("inverse_map")), + fill_value_arg_(spec.GetArgument>("fill_value")) { + matrix_data_.SetContiguity(BatchContiguity::Contiguous); + } + + bool ShouldExpandChannels(int input_idx) const override { + return true; + } + + bool CanInferOutputs() const override { + return true; + } + + float4 GetFillValue(int channels) const { + if (fill_value_arg_.size() > 1) { + if (channels > 0) { + if (channels == static_cast(fill_value_arg_.size())) { + float4 fill_value{0, 0, 0, 0}; + memcpy(&fill_value, fill_value_arg_.data(), fill_value_arg_.size() * sizeof(float)); + return fill_value; + } else { + DALI_FAIL(make_string( + "Number of values provided as a fill_value should match the number of channels.\n" + "Number of channels: ", + channels, ". Number of values provided: ", fill_value_arg_.size(), ".")); + } + } else { + DALI_FAIL("Only scalar fill_value can be provided when processing data in planar layout."); + } + } else if (fill_value_arg_.size() == 1) { + auto fv = fill_value_arg_[0]; + float4 fill_value{fv, fv, fv, fv}; + return fill_value; + } else { + return float4{0, 0, 0, 0}; + } + } + + void ValidateTypes(const Workspace &ws) const { + auto inp_type = ws.Input(0).type(); + DALI_ENFORCE(inp_type == DALI_UINT8 || inp_type == DALI_INT16 || inp_type == DALI_UINT16 || + inp_type == DALI_FLOAT, + "The operator accepts the following input types: " + "uint8, int16, uint16, float."); + if (ws.NumInput() > 1) { + auto mat_type = ws.Input(1).type(); + DALI_ENFORCE(mat_type == DALI_FLOAT, + "Transformation matrix can be provided only as float values"); + } + } + + bool SetupImpl(std::vector &output_desc, const Workspace &ws) override { + ValidateTypes(ws); + const auto &input = ws.Input(0); + auto input_shape = input.shape(); + auto input_layout = input.GetLayout(); + output_desc.resize(1); + + auto output_shape = input_shape; + int channels = (input_layout.find('C') != -1) ? input_shape[0][input_layout.find('C')] : -1; + fill_value_ = GetFillValue(channels); + if (!shape_arg_.empty()) { + auto height = std::max(std::roundf(shape_arg_[0]), 1); + auto width = std::max(std::roundf(shape_arg_[1]), 1); + auto out_sample_shape = (channels != -1) ? TensorShape<>({height, width, channels}) : + TensorShape<>({height, width}); + output_shape = TensorListShape<>::make_uniform(input.num_samples(), out_sample_shape); + } + + output_desc[0] = {output_shape, input.type()}; + return true; + } + + void RunImpl(Workspace &ws) override { + const auto &input = ws.Input(0); + auto &output = ws.Output(0); + output.SetLayout(input.GetLayout()); + + kernels::DynamicScratchpad scratchpad({}, AccessOrder(ws.stream())); + + nvcv::Tensor matrix{}; + if (ws.NumInput() > 1) { + DALI_ENFORCE(!matrix_arg_.HasExplicitValue(), + "Matrix input and `matrix` argument should not be provided at the same time."); + Tensor matrix_tensor; + auto &matrix_input = ws.Input(1); + DALI_ENFORCE(is_uniform(matrix_input.shape()), + "Matrix data has to be a uniformly shaped batch."); + if (matrix_input.IsDenseTensor()) { + matrix_tensor = const_cast &>(matrix_input).AsTensor(); + } else { + matrix_data_.Copy(matrix_input, AccessOrder(ws.stream())); + matrix_tensor = matrix_data_.AsTensor(); + } + matrix = nvcvop::AsTensor(matrix_tensor, "NW"); + } else { + matrix = AcquireTensorArgument(ws, scratchpad, matrix_arg_, TensorShape<1>(9), + nvcvop::GetDataType(), "W"); + } + + auto input_images = GetInputBatch(ws, 0); + auto output_images = GetOutputBatch(ws, 0); + if (!warp_perspective_ || input.num_samples() > op_batch_size_) { + op_batch_size_ = std::max(op_batch_size_ * 2, input.num_samples()); + warp_perspective_.emplace(op_batch_size_); + } + int32_t flags = interp_type_; + if (inverse_map_) { + flags |= NVCV_WARP_INVERSE_MAP; + } + (*warp_perspective_)(ws.stream(), input_images, output_images, matrix, flags, border_mode_, + fill_value_); + } + + private: + USE_OPERATOR_MEMBERS(); + ArgValue matrix_arg_{"matrix", spec_}; + int op_batch_size_ = 0; + std::vector shape_arg_; + NVCVBorderType border_mode_{NVCV_BORDER_CONSTANT}; + NVCVInterpolationType interp_type_{NVCV_INTERP_NEAREST}; + bool inverse_map_{false}; + std::vector fill_value_arg_{0, 0, 0, 0}; + float4 fill_value_{0, 0, 0, 0}; + std::optional warp_perspective_{}; + TensorList matrix_data_{}; +}; + +DALI_REGISTER_OPERATOR(experimental__WarpPerspective, WarpPerspective, GPU); + +} // namespace dali diff --git a/dali/operators/nvcvop/nvcvop.cc b/dali/operators/nvcvop/nvcvop.cc index 4b63d5db2d..8abc0f9a19 100644 --- a/dali/operators/nvcvop/nvcvop.cc +++ b/dali/operators/nvcvop/nvcvop.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include "dali/operators/nvcvop/nvcvop.h" +#include + namespace dali::nvcvop { NVCVBorderType GetBorderMode(std::string_view border_mode) { @@ -34,6 +34,18 @@ NVCVBorderType GetBorderMode(std::string_view border_mode) { } } +NVCVInterpolationType GetInterpolationType(std::string_view interpolation_type) { + if (interpolation_type == "nearest") { + return NVCV_INTERP_NEAREST; + } else if (interpolation_type == "linear") { + return NVCV_INTERP_LINEAR; + } else if (interpolation_type == "cubic") { + return NVCV_INTERP_CUBIC; + } else { + DALI_FAIL("Unknown interpolation type: " + std::string(interpolation_type)); + } +} + nvcv::DataKind GetDataKind(DALIDataType dtype) { switch (dtype) { case DALI_UINT8: @@ -147,4 +159,20 @@ void PushImagesToBatch(nvcv::ImageBatchVarShape &batch, const TensorList &tensor, TensorLayout layout = "") { + auto shape = tensor.shape(); + auto dtype = GetDataType(tensor.type(), 1); + nvcv::TensorDataStridedCuda::Buffer inBuf; + inBuf.basePtr = reinterpret_cast(const_cast(tensor.raw_data())); + inBuf.strides[shape.size() - 1] = dtype.strideBytes(); + for (int d = shape.size() - 2; d >= 0; --d) { + inBuf.strides[d] = shape[d + 1] * inBuf.strides[d + 1]; + } + TensorLayout out_layout = layout.empty() ? tensor.GetLayout() : layout; + nvcv::TensorShape out_shape(shape.data(), shape.size(), + nvcv::TensorLayout(out_layout.c_str())); + nvcv::TensorDataStridedCuda inData(out_shape, dtype, inBuf); + return nvcv::TensorWrapData(inData); +} + } // namespace dali::nvcvop diff --git a/dali/operators/nvcvop/nvcvop.h b/dali/operators/nvcvop/nvcvop.h index c864b2dab9..a38cb8a7b8 100644 --- a/dali/operators/nvcvop/nvcvop.h +++ b/dali/operators/nvcvop/nvcvop.h @@ -15,14 +15,15 @@ #ifndef DALI_OPERATORS_NVCVOP_NVCVOP_H_ #define DALI_OPERATORS_NVCVOP_NVCVOP_H_ - +#include #include #include -#include -#include #include #include +#include +#include + #include "dali/core/call_at_exit.h" #include "dali/kernels/dynamic_scratchpad.h" #include "dali/pipeline/operator/arg_helper.h" @@ -38,6 +39,13 @@ namespace dali::nvcvop { */ NVCVBorderType GetBorderMode(std::string_view border_mode); +/** + * @brief Get the nvcv interpolation type from name + * + * @param interpolation_type interpolation type name + */ +NVCVInterpolationType GetInterpolationType(std::string_view interpolation_type); + /** * @brief Get nvcv data kind of a given data type */ @@ -66,6 +74,15 @@ nvcv::Image AsImage(SampleView sample, const nvcv::ImageFormat &form */ nvcv::Image AsImage(ConstSampleView sample, const nvcv::ImageFormat &format); +/** + * @brief Wrap a DALI tensor as an NVCV Tensor + * + * @param tensor DALI tensor + * @param layout layout of the resulting nvcv::Tensor. + * If not provided, layout of the DALI tensor is used. + */ +nvcv::Tensor AsTensor(const Tensor &tensor, TensorLayout layout); + /** * @brief Allocates an image batch using a dynamic scratchpad. * Allocated images have the shape and data type matching the samples in the given TensorList. @@ -120,9 +137,9 @@ class NVCVOperator: public BaseOp { * @tparam DTYPE expected nvcv Tensor data type * @param arg_shape shape of the DALI operator argument. */ - template + template nvcv::Tensor AcquireTensorArgument(Workspace &ws, kernels::Scratchpad &scratchpad, - ArgValue &arg, const TensorShape<> &arg_shape, + ArgValue &arg, const TensorShape<> &arg_shape, nvcv::DataType dtype = GetDataType(), TensorLayout layout = "") { int num_samples = ws.GetInputBatchSize(0); diff --git a/dali/test/python/checkpointing/test_dali_stateless_operators.py b/dali/test/python/checkpointing/test_dali_stateless_operators.py index 9d206f8515..421bb7aba3 100644 --- a/dali/test/python/checkpointing/test_dali_stateless_operators.py +++ b/dali/test/python/checkpointing/test_dali_stateless_operators.py @@ -997,6 +997,12 @@ def test_erode_stateless(device): check_single_input(fn.experimental.erode, device) +@params("gpu") +@stateless_signed_off("experimental.warp_perspective") +def test_warp_perspective_stateless(device): + check_single_input(fn.experimental.warp_perspective, device) + + @params("cpu") @stateless_signed_off("zeros", "ones", "full", "zeros_like", "ones_like", "full_like") def test_full_operator_family(device): diff --git a/dali/test/python/operator_2/test_warp_perspective.py b/dali/test/python/operator_2/test_warp_perspective.py new file mode 100644 index 0000000000..147e2bc58f --- /dev/null +++ b/dali/test/python/operator_2/test_warp_perspective.py @@ -0,0 +1,372 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import nvidia.dali as dali +import nvidia.dali.fn as fn +import cv2 +import numpy as np +import multiprocessing as mp +import test_utils +from nose2.tools import params + +NUM_THREADS = mp.cpu_count() +DEV_ID = 0 +SEED = 1313 + + +def ocv_border_mode(border_mode): + if border_mode == "constant": + return cv2.BORDER_CONSTANT + elif border_mode == "replicate": + return cv2.BORDER_REPLICATE + elif border_mode == "reflect": + return cv2.BORDER_REFLECT + elif border_mode == "reflect_101": + return cv2.BORDER_REFLECT_101 + elif border_mode == "wrap": + return cv2.BORDER_WRAP + else: + raise ValueError("Invalid border mode") + + +def ocv_interp_type(interp_type): + if interp_type == "nearest": + return cv2.INTER_NEAREST + elif interp_type == "linear": + return cv2.INTER_LINEAR + elif interp_type == "cubic": + return cv2.INTER_CUBIC + else: + raise ValueError("Invalid interpolation type") + + +def cv2_warp_perspective( + dst, img, matrix, layout, border_mode, interp_type, inverse_map, fill_value=None +): + border_mode = ocv_border_mode(border_mode) + interp_type = ocv_interp_type(interp_type) + flags = interp_type + + if fill_value is None: + fill_value = 0 + + if inverse_map: + flags = flags | cv2.WARP_INVERSE_MAP + if layout[-1] == "C": + dsize = (dst.shape[1], dst.shape[0]) + if not isinstance(fill_value, tuple): + fill_value = tuple([fill_value for c in range(dst.shape[2])]) + dst[:, :, :] = cv2.warpPerspective( + img, M=matrix, dsize=dsize, flags=flags, borderMode=border_mode, borderValue=fill_value + ).reshape(dst.shape) + else: + dsize = (dst.shape[2], dst.shape[1]) + for c in range(img.shape[0]): + dst[c, :, :] = cv2.warpPerspective( + img[c, :, :], + M=matrix, + dsize=dsize, + flags=flags, + borderMode=border_mode, + borderValue=fill_value, + ).reshape(dst.shape[1:]) + + +def ref_func(img, matrix, size, layout, border_mode, interp_type, inverse_map, fill_value): + out_shape = list(img.shape) + if size is not None: + out_shape[layout.find("H")] = size[0] + out_shape[layout.find("W")] = size[1] + dst = np.zeros(out_shape, dtype=img.dtype) + if layout[0] == "F": + for f in range(0, img.shape[0]): + cv2_warp_perspective( + dst[f, :, :, :], + img[f, :, :, :], + matrix, + layout, + border_mode, + interp_type, + inverse_map, + fill_value, + ) + else: + cv2_warp_perspective( + dst, img, matrix, layout, border_mode, interp_type, inverse_map, fill_value + ) + return dst + + +@dali.pipeline_def(num_threads=NUM_THREADS, device_id=DEV_ID) +def reference_pipe( + data_src, matrix_src, size, layout, border_mode, interp_type, inverse_map, fill_value +): + img = fn.external_source(source=data_src, batch=True, layout=layout) + matrix = fn.external_source(source=matrix_src) + return fn.python_function( + img, + matrix, + function=lambda im, mx: ref_func( + im, + mx, + size=size, + layout=layout, + border_mode=border_mode, + interp_type=interp_type, + inverse_map=inverse_map, + fill_value=fill_value, + ), + batch_processing=False, + output_layouts=layout, + ) + + +def matrix_source(batch_size, seed, constant=False): + np_rng = np.random.default_rng(seed=seed) + + def gen_matrix(): + dst_pts = np.array([[0, 0], [0, 100], [100, 0], [100, 100]], dtype=np.float32) + src_offsets = np_rng.random((4, 2), dtype=np.float32) * 50 - 25 + src_pts = dst_pts - src_offsets + return cv2.getPerspectiveTransform(src_pts, dst_pts).astype(np.float32) + + if constant: + m = gen_matrix() + while True: + yield [m for _ in range(batch_size)] + else: + while True: + yield [gen_matrix() for _ in range(batch_size)] + + +@dali.pipeline_def(num_threads=NUM_THREADS, device_id=DEV_ID) +def warp_perspective_pipe_const_matrix( + data_src, layout, matrix, size, border_mode, interp_type, inverse_map, fill_value +): + img = fn.external_source(source=data_src, batch=True, layout=layout, device="gpu") + return fn.experimental.warp_perspective( + img, + matrix=matrix.reshape((-1)), + size=size, + border_mode=border_mode, + interp_type=interp_type, + inverse_map=inverse_map, + fill_value=fill_value, + ) + + +@dali.pipeline_def(num_threads=NUM_THREADS, device_id=DEV_ID) +def warp_perspective_pipe_arg_inp_matrix( + data_src, matrix_src, layout, size, border_mode, interp_type, inverse_map, fill_value +): + img = fn.external_source(source=data_src, batch=True, layout=layout, device="gpu") + matrix = fn.external_source(source=matrix_src, batch=True) + matrix = fn.reshape(matrix, shape=(-1)) + return fn.experimental.warp_perspective( + img, + matrix=matrix, + size=size, + border_mode=border_mode, + interp_type=interp_type, + inverse_map=inverse_map, + fill_value=fill_value, + ) + + +@dali.pipeline_def(num_threads=NUM_THREADS, device_id=DEV_ID) +def warp_perspective_pipe_gpu_inp_matrix( + data_src, matrix_src, layout, size, border_mode, interp_type, inverse_map, fill_value +): + img = fn.external_source(source=data_src, batch=True, layout=layout, device="gpu") + matrix = fn.external_source(source=matrix_src, batch=True, device="gpu") + matrix = fn.reshape(matrix, shape=(-1)) + return fn.experimental.warp_perspective( + img, + matrix, + size=size, + border_mode=border_mode, + interp_type=interp_type, + inverse_map=inverse_map, + fill_value=fill_value, + ) + + +def input_iterator(bs, layout, dtype, channels): + cdim = layout.find("C") + min_shape = [64 for d in layout] + min_shape[cdim] = channels + max_shape = [256 for d in layout] + max_shape[cdim] = channels + if layout[0] == "F": + min_shape[0] = 8 + max_shape[0] = 32 + + return test_utils.RandomlyShapedDataIterator( + batch_size=bs, min_shape=min_shape, max_shape=max_shape, dtype=dtype, seed=SEED + ) + + +def compare_pipelines(pipe1, pipe2, bs, dtype): + if dtype == np.float32: + eps = 0.05 + elif dtype == np.uint8: + eps = 1 + elif dtype == np.int16 or dtype == np.uint16: + eps = 5 + + test_utils.compare_pipelines(pipe1, pipe2, batch_size=bs, N_iterations=10, eps=eps) + + +counter = 1 + + +@params( + (32, "HWC", np.uint8, 3, None, "constant", "nearest", False, (100, 50, 25)), + (32, "HWC", np.uint8, 3, None, "constant", "linear", False, 77), + (4, "FCHW", np.uint8, 1, (200, 300), "reflect", "nearest", True, None), + (32, "CHW", np.float32, 1, (300, 300), "replicate", "cubic", False, None), + (32, "CHW", np.float32, 4, (150, 300), "constant", "nearest", False, 55), + (8, "FHWC", np.int16, 3, None, "reflect_101", "linear", True, None), + (32, "HWC", np.uint16, 4, None, "constant", "cubic", False, None), +) +def test_warp_perspective_const_matrix_vs_ocv( + bs, layout, dtype, channels, size, border_mode, interp_type, inverse_map, fill_value +): + data1 = input_iterator(bs, layout, dtype, channels) + data2 = input_iterator(bs, layout, dtype, channels) + + global counter + matrix_src = matrix_source(bs, SEED + counter, True) + matrix = matrix_src.__next__()[0] + pipe1 = warp_perspective_pipe_const_matrix( + data1, + layout, + matrix, + size, + border_mode, + interp_type, + inverse_map, + fill_value, + batch_size=bs, + prefetch_queue_depth=1, + ) + + pipe2 = reference_pipe( + data2, + matrix_src, + size, + layout, + border_mode, + interp_type, + inverse_map, + fill_value, + batch_size=bs, + prefetch_queue_depth=1, + ) + compare_pipelines(pipe1, pipe2, bs, dtype) + counter = counter + 1 + + +@params( + (32, "HWC", np.uint8, 3, None, "constant", "nearest", False, (100, 50, 25)), + (32, "HWC", np.uint8, 3, None, "constant", "linear", False, 77), + (32, "CHW", np.float32, 1, (300, 300), "replicate", "cubic", False, None), + (4, "FCHW", np.uint8, 1, (150, 150), "reflect", "linear", True, None), + (32, "CHW", np.float32, 4, (150, 300), "constant", "nearest", False, 55), + (8, "FHWC", np.int16, 3, None, "reflect_101", "linear", True, None), + (32, "HWC", np.uint16, 4, None, "constant", "cubic", False, None), +) +def test_warp_perspective_arg_inp_matrix_vs_ocv( + bs, layout, dtype, channels, size, border_mode, interp_type, inverse_map, fill_value +): + data1 = input_iterator(bs, layout, dtype, channels) + data2 = input_iterator(bs, layout, dtype, channels) + + global counter + matrix_src1 = matrix_source(bs, SEED + counter, False) + pipe1 = warp_perspective_pipe_arg_inp_matrix( + data1, + matrix_src1, + layout, + size, + border_mode, + interp_type, + inverse_map, + fill_value, + batch_size=bs, + prefetch_queue_depth=1, + ) + + matrix_src2 = matrix_source(bs, SEED + counter, False) + pipe2 = reference_pipe( + data2, + matrix_src2, + size, + layout, + border_mode, + interp_type, + inverse_map, + fill_value, + batch_size=bs, + prefetch_queue_depth=1, + ) + compare_pipelines(pipe1, pipe2, bs, dtype) + counter = counter + 1 + + +@params( + (32, "HWC", np.uint8, 3, None, "constant", "nearest", False, (100, 50, 25)), + (32, "HWC", np.uint8, 3, None, "constant", "linear", False, 77), + (32, "CHW", np.float32, 1, (300, 300), "replicate", "cubic", False, None), + (4, "FCHW", np.uint8, 1, (150, 150), "reflect", "linear", True, None), + (32, "CHW", np.float32, 4, (150, 300), "constant", "nearest", False, 55), + (8, "FHWC", np.int16, 3, None, "reflect_101", "linear", True, None), + (32, "HWC", np.uint16, 4, None, "constant", "cubic", False, None), +) +def test_warp_perspective_gpu_inp_matrix_vs_ocv( + bs, layout, dtype, channels, size, border_mode, interp_type, inverse_map, fill_value +): + data1 = input_iterator(bs, layout, dtype, channels) + data2 = input_iterator(bs, layout, dtype, channels) + + global counter + matrix_src1 = matrix_source(bs, SEED + counter, False) + pipe1 = warp_perspective_pipe_gpu_inp_matrix( + data1, + matrix_src1, + layout, + size, + border_mode, + interp_type, + inverse_map, + fill_value, + batch_size=bs, + prefetch_queue_depth=1, + ) + + matrix_src2 = matrix_source(bs, SEED + counter, False) + pipe2 = reference_pipe( + data2, + matrix_src2, + size, + layout, + border_mode, + interp_type, + inverse_map, + fill_value, + batch_size=bs, + prefetch_queue_depth=1, + ) + compare_pipelines(pipe1, pipe2, bs, dtype) + counter = counter + 1 diff --git a/dali/test/python/test_dali_cpu_only.py b/dali/test/python/test_dali_cpu_only.py index 330ac9d201..06836894a3 100644 --- a/dali/test/python/test_dali_cpu_only.py +++ b/dali/test/python/test_dali_cpu_only.py @@ -1615,6 +1615,7 @@ def full_like_pipe(): "experimental.median_blur", # not supported for CPU "experimental.dilate", # not supported for CPU "experimental.erode", # not supported for CPU + "experimental.warp_perspective", # not supported for CPU "plugin.video.decoder", # not supported for CPU ] diff --git a/dali/test/python/test_dali_variable_batch_size.py b/dali/test/python/test_dali_variable_batch_size.py index ce52cb722c..d4b5927c21 100644 --- a/dali/test/python/test_dali_variable_batch_size.py +++ b/dali/test/python/test_dali_variable_batch_size.py @@ -374,6 +374,7 @@ def numba_setup_out_shape(out_shape, in_shape): (fn.experimental.median_blur, {"devices": ["gpu"]}), (fn.experimental.dilate, {"devices": ["gpu"]}), (fn.experimental.erode, {"devices": ["gpu"]}), + (fn.experimental.warp_perspective, {"devices": ["gpu"]}), (fn.zeros_like, {"devices": ["cpu"]}), (fn.ones_like, {"devices": ["cpu"]}), ] @@ -1643,6 +1644,7 @@ def pipe(max_batch_size, input_data, device): "experimental.inflate", "experimental.median_blur", "experimental.remap", + "experimental.warp_perspective", "external_source", "fast_resize_crop_mirror", "flip",