From ff0bd1c7aac10e3d3cb89be114e3154dc4f48962 Mon Sep 17 00:00:00 2001 From: Junrou Nishida Date: Tue, 2 Jan 2024 22:50:44 +0900 Subject: [PATCH] feat: marshal DetectionResult (#1089) * feat: marshal DetectionResult * refactor: init list if null --- .../Runtime/Scripts/Marshal/NativeCategory.cs | 38 +++++++++++ .../Scripts/Marshal/NativeCategory.cs.meta | 11 +++ .../Scripts/Marshal/NativeDetectionResult.cs | 67 +++++++++++++++++++ .../Marshal/NativeDetectionResult.cs.meta | 11 +++ .../Runtime/Scripts/Marshal/NativeKeypoint.cs | 23 +++++++ .../Scripts/Marshal/NativeKeypoint.cs.meta | 11 +++ .../Runtime/Scripts/Marshal/NativeRect.cs | 28 ++++++++ .../Scripts/Marshal/NativeRect.cs.meta | 11 +++ .../Tasks/Components/Containers_Unsafe.cs | 6 ++ .../Tasks/Components/Containers/Category.cs | 8 +++ .../Components/Containers/DetectionResult.cs | 50 ++++++++++---- .../Tasks/Components/Containers/Keypoint.cs | 10 +++ .../Components/Containers/PacketExtension.cs | 7 ++ .../Tasks/Components/Containers/Rect.cs | 12 ++++ mediapipe_api/BUILD | 1 + .../components/containers/detection_result.cc | 15 +++++ .../components/containers/detection_result.h | 23 +++++++ 17 files changed, 320 insertions(+), 12 deletions(-) create mode 100644 Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeCategory.cs create mode 100644 Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeCategory.cs.meta create mode 100644 Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeDetectionResult.cs create mode 100644 Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeDetectionResult.cs.meta create mode 100644 Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeKeypoint.cs create mode 100644 Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeKeypoint.cs.meta create mode 100644 Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeRect.cs create mode 100644 Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeRect.cs.meta create mode 100644 mediapipe_api/tasks/c/components/containers/detection_result.cc create mode 100644 mediapipe_api/tasks/c/components/containers/detection_result.h diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeCategory.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeCategory.cs new file mode 100644 index 000000000..eca8dca37 --- /dev/null +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeCategory.cs @@ -0,0 +1,38 @@ +// Copyright (c) 2023 homuler +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +using System; +using System.Runtime.InteropServices; + +namespace Mediapipe +{ + [StructLayout(LayoutKind.Sequential)] + internal readonly struct NativeCategory + { + public readonly int index; + public readonly float score; + private readonly IntPtr _categoryName; + private readonly IntPtr _displayName; + + public string categoryName => Marshal.PtrToStringAnsi(_categoryName); + public string displayName => Marshal.PtrToStringAnsi(_displayName); + } + + [StructLayout(LayoutKind.Sequential)] + internal readonly struct NativeCategories + { + private readonly IntPtr _categories; + public readonly uint categoriesCount; + + public ReadOnlySpan AsReadOnlySpan() + { + unsafe + { + return new ReadOnlySpan((NativeCategory*)_categories, (int)categoriesCount); + } + } + } +} diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeCategory.cs.meta b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeCategory.cs.meta new file mode 100644 index 000000000..779888b73 --- /dev/null +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeCategory.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: eca575da1f52761b39bed46467a81acb +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeDetectionResult.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeDetectionResult.cs new file mode 100644 index 000000000..efb9256f6 --- /dev/null +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeDetectionResult.cs @@ -0,0 +1,67 @@ +// Copyright (c) 2023 homuler +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +using System; +using System.Runtime.InteropServices; + +namespace Mediapipe +{ + [StructLayout(LayoutKind.Sequential)] + internal readonly struct NativeDetection + { + private readonly IntPtr _categories; + + public readonly uint categoriesCount; + + public readonly NativeRect boundingBox; + + private readonly IntPtr _keypoints; + + public readonly uint keypointsCount; + + public ReadOnlySpan categories + { + get + { + unsafe + { + return new ReadOnlySpan((NativeCategory*)_categories, (int)categoriesCount); + } + } + } + + public ReadOnlySpan keypoints + { + get + { + unsafe + { + return new ReadOnlySpan((NativeNormalizedKeypoint*)_keypoints, (int)keypointsCount); + } + } + } + } + + [StructLayout(LayoutKind.Sequential)] + internal readonly struct NativeDetectionResult + { + private readonly IntPtr _detections; + public readonly uint detectionsCount; + + public ReadOnlySpan AsReadOnlySpan() + { + unsafe + { + return new ReadOnlySpan((NativeDetection*)_detections, (int)detectionsCount); + } + } + + public void Dispose() + { + UnsafeNativeMethods.mp_tasks_c_components_containers_CppCloseDetectionResult(this); + } + } +} diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeDetectionResult.cs.meta b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeDetectionResult.cs.meta new file mode 100644 index 000000000..9b1daa49f --- /dev/null +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeDetectionResult.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 606373bac36c7f322a0afca07b96d8ab +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeKeypoint.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeKeypoint.cs new file mode 100644 index 000000000..b78fd74ad --- /dev/null +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeKeypoint.cs @@ -0,0 +1,23 @@ +// Copyright (c) 2023 homuler +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +using System; +using System.Runtime.InteropServices; + +namespace Mediapipe +{ + [StructLayout(LayoutKind.Sequential)] + internal readonly struct NativeNormalizedKeypoint + { + public readonly float x; + public readonly float y; + private readonly IntPtr _label; + public readonly float score; + public readonly bool hasScore; + + public string label => Marshal.PtrToStringAnsi(_label); + } +} diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeKeypoint.cs.meta b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeKeypoint.cs.meta new file mode 100644 index 000000000..63e48ed04 --- /dev/null +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeKeypoint.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 6eff818affa034535867ef24c7f6fe39 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeRect.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeRect.cs new file mode 100644 index 000000000..064eeba86 --- /dev/null +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeRect.cs @@ -0,0 +1,28 @@ +// Copyright (c) 2023 homuler +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +using System.Runtime.InteropServices; + +namespace Mediapipe +{ + [StructLayout(LayoutKind.Sequential)] + internal readonly struct NativeRect + { + public readonly int left; + public readonly int top; + public readonly int bottom; + public readonly int right; + } + + [StructLayout(LayoutKind.Sequential)] + internal readonly struct NativeRectF + { + public readonly float left; + public readonly float top; + public readonly float bottom; + public readonly float right; + } +} diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeRect.cs.meta b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeRect.cs.meta new file mode 100644 index 000000000..4908c12fc --- /dev/null +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeRect.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 85baeb5f69276357fb7fe9c3868aec6b +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Tasks/Components/Containers_Unsafe.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Tasks/Components/Containers_Unsafe.cs index 7182d3bef..cdc08caef 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Tasks/Components/Containers_Unsafe.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Tasks/Components/Containers_Unsafe.cs @@ -11,6 +11,12 @@ namespace Mediapipe { internal static partial class UnsafeNativeMethods { + [DllImport(MediaPipeLibrary, ExactSpelling = true)] + public static extern MpReturnCode mp_Packet__GetDetectionResult(IntPtr packet, out NativeDetectionResult value); + + [DllImport(MediaPipeLibrary, ExactSpelling = true)] + public static extern void mp_tasks_c_components_containers_CppCloseDetectionResult(NativeDetectionResult data); + [DllImport(MediaPipeLibrary, ExactSpelling = true)] public static extern MpReturnCode mp_Packet__GetNormalizedLandmarksVector(IntPtr packet, out NativeNormalizedLandmarksArray value); diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/Category.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/Category.cs index d370fcdbd..1a17b193a 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/Category.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/Category.cs @@ -43,6 +43,14 @@ internal Category(int index, float score, string categoryName, string displayNam this.displayName = displayName; } + internal Category(NativeCategory nativeCategory) : this( + nativeCategory.index, + nativeCategory.score, + nativeCategory.categoryName, + nativeCategory.displayName) + { + } + public static Category CreateFrom(Classification proto) { var categoryName = proto.HasLabel ? proto.Label : null; diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/DetectionResult.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/DetectionResult.cs index e07686cb1..0ed592666 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/DetectionResult.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/DetectionResult.cs @@ -4,7 +4,6 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -using System; using System.Collections.Generic; namespace Mediapipe.Tasks.Components.Containers @@ -53,11 +52,11 @@ public static Detection CreateFrom(Mediapipe.Detection proto) public static void Copy(Mediapipe.Detection proto, ref Detection destination) { - var categories = destination.categories; + var categories = destination.categories ?? new List(proto.Score.Count); categories.Clear(); for (var idx = 0; idx < proto.Score.Count; idx++) { - destination.categories.Add(new Category( + categories.Add(new Category( proto.LabelId.Count > idx ? proto.LabelId[idx] : _DefaultCategoryIndex, proto.Score[idx], proto.Label.Count > idx ? proto.Label[idx] : "", @@ -96,6 +95,27 @@ public static void Copy(Mediapipe.Detection proto, ref Detection destination) destination = new Detection(categories, boundingBox, keypoints); } + internal static void Copy(in NativeDetection source, ref Detection destination) + { + var categories = destination.categories ?? new List((int)source.categoriesCount); + categories.Clear(); + foreach (var nativeCategory in source.categories) + { + categories.Add(new Category(nativeCategory)); + } + + var boundingBox = new Rect(source.boundingBox); + + var keypoints = destination.keypoints ?? new List((int)source.keypointsCount); + keypoints.Clear(); + foreach (var nativeKeypoint in source.keypoints) + { + keypoints.Add(new NormalizedKeypoint(nativeKeypoint)); + } + + destination = new Detection(categories, boundingBox, keypoints); + } + public override string ToString() => $"{{ \"categories\": {Util.Format(categories)}, \"boundingBox\": {boundingBox}, \"keypoints\": {Util.Format(keypoints)} }}"; } @@ -130,22 +150,28 @@ internal static DetectionResult CreateFrom(List detectionsP internal static void Copy(List source, ref DetectionResult destination) { - var detections = destination.detections; - if (source.Count < detections.Count) - { - detections.RemoveRange(source.Count, detections.Count - source.Count); - } - var copyCount = Math.Min(source.Count, detections.Count); - for (var i = 0; i < copyCount; i++) + var detections = destination.detections ?? new List(source.Count); + detections.ResizeTo(source.Count); + + for (var i = 0; i < source.Count; i++) { var detection = detections[i]; Detection.Copy(source[i], ref detection); detections[i] = detection; } + } - for (var i = copyCount; i < source.Count; i++) + internal static void Copy(NativeDetectionResult source, ref DetectionResult destination) + { + var detections = destination.detections ?? new List((int)source.detectionsCount); + detections.ResizeTo((int)source.detectionsCount); + + var i = 0; + foreach (var nativeDetection in source.AsReadOnlySpan()) { - detections.Add(Detection.CreateFrom(source[i])); + var detection = detections[i]; + Detection.Copy(nativeDetection, ref detection); + detections[i++] = detection; } } diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/Keypoint.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/Keypoint.cs index 4f3f7af48..5a3b9da43 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/Keypoint.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/Keypoint.cs @@ -36,6 +36,16 @@ internal NormalizedKeypoint(float x, float y, string label, float? score) this.score = score; } + internal NormalizedKeypoint(NativeNormalizedKeypoint nativeKeypoint) : this( + nativeKeypoint.x, + nativeKeypoint.y, + nativeKeypoint.label, +#pragma warning disable IDE0004 // for Unity 2020.3.x + nativeKeypoint.hasScore ? (float?)nativeKeypoint.score : null) +#pragma warning restore IDE0004 // for Unity 2020.3.x + { + } + public override string ToString() => $"{{ \"x\": {x}, \"y\": {y}, \"label\": \"{label}\", \"score\": {Util.Format(score)} }}"; } } diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/PacketExtension.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/PacketExtension.cs index e416634c5..57543c8ed 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/PacketExtension.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/PacketExtension.cs @@ -10,6 +10,13 @@ namespace Mediapipe.Tasks.Components.Containers { public static class PacketExtension { + public static void GetDetectionResult(this Packet packet, ref DetectionResult value) + { + UnsafeNativeMethods.mp_Packet__GetDetectionResult(packet.mpPtr, out var detectionResult).Assert(); + DetectionResult.Copy(detectionResult, ref value); + detectionResult.Dispose(); + } + public static void GetNormalizedLandmarksList(this Packet packet, List outs) { UnsafeNativeMethods.mp_Packet__GetNormalizedLandmarksVector(packet.mpPtr, out var landmarksArray).Assert(); diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/Rect.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/Rect.cs index d8c2b5f08..0a6b5ba8f 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/Rect.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/Rect.cs @@ -29,6 +29,8 @@ internal Rect(int left, int top, int right, int bottom) this.bottom = bottom; } + internal Rect(NativeRect nativeRect) : this(nativeRect.left, nativeRect.top, nativeRect.right, nativeRect.bottom) { } + public override string ToString() => $"{{ \"left\": {left}, \"top\": {top}, \"right\": {right}, \"bottom\": {bottom} }}"; } @@ -48,6 +50,16 @@ internal Rect(int left, int top, int right, int bottom) public readonly float right; public readonly float bottom; + internal RectF(float left, float top, float right, float bottom) + { + this.left = left; + this.top = top; + this.right = right; + this.bottom = bottom; + } + + internal RectF(NativeRectF nativeRect) : this(nativeRect.left, nativeRect.top, nativeRect.right, nativeRect.bottom) { } + #nullable enable public override bool Equals(object? obj) => obj is RectF other && Equals(other); #nullable disable diff --git a/mediapipe_api/BUILD b/mediapipe_api/BUILD index 66f8af24d..2c0624812 100644 --- a/mediapipe_api/BUILD +++ b/mediapipe_api/BUILD @@ -164,6 +164,7 @@ cc_library( "//mediapipe_api/framework/formats:matrix_data", "//mediapipe_api/framework/formats:rect", "//mediapipe_api/framework/port:logging", + "//mediapipe_api/tasks/c/components/containers:detection_result", "//mediapipe_api/tasks/c/components/containers:landmark", "//mediapipe_api/tasks/cc/vision/face_geometry/proto:face_geometry", "//mediapipe_api/tasks/cc/core:task_runner", diff --git a/mediapipe_api/tasks/c/components/containers/detection_result.cc b/mediapipe_api/tasks/c/components/containers/detection_result.cc new file mode 100644 index 000000000..60bcf5cb4 --- /dev/null +++ b/mediapipe_api/tasks/c/components/containers/detection_result.cc @@ -0,0 +1,15 @@ +#include "mediapipe_api/tasks/c/components/containers/detection_result.h" + +MpReturnCode mp_Packet__GetDetectionResult(mediapipe::Packet* packet, DetectionResult* value_out) { + TRY_ALL + // get std::vector and convert it to DetectionResult* + auto detections = packet->Get>(); + auto detection_result = mediapipe::tasks::components::containers::ConvertToDetectionResult(detections); + mediapipe::tasks::c::components::containers::CppConvertToDetectionResult(detection_result, value_out); + RETURN_CODE(MpReturnCode::Success); + CATCH_ALL +} + +void mp_tasks_c_components_containers_CppCloseDetectionResult(DetectionResult data) { + mediapipe::tasks::c::components::containers::CppCloseDetectionResult(&data); +} diff --git a/mediapipe_api/tasks/c/components/containers/detection_result.h b/mediapipe_api/tasks/c/components/containers/detection_result.h new file mode 100644 index 000000000..06d19e8e3 --- /dev/null +++ b/mediapipe_api/tasks/c/components/containers/detection_result.h @@ -0,0 +1,23 @@ +// Copyright (c) 2023 homuler +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +#ifndef MEDIAPIPE_API_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_H_ +#define MEDIAPIPE_API_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_H_ + +#include "mediapipe/framework/packet.h" +#include "mediapipe/tasks/c/components/containers/detection_result.h" +#include "mediapipe/tasks/c/components/containers/detection_result_converter.h" +#include "mediapipe/tasks/cc/components/containers/detection_result.h" +#include "mediapipe_api/common.h" + +extern "C" { + +MP_CAPI(MpReturnCode) mp_Packet__GetDetectionResult(mediapipe::Packet* packet, DetectionResult* value_out); +MP_CAPI(void) mp_tasks_c_components_containers_CppCloseDetectionResult(DetectionResult data); + +} // extern "C" + +#endif // MEDIAPIPE_API_TASKS_C_COMPONENTS_CONTAINERS_DETECTION_H_