diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeClassificationResult.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeClassificationResult.cs new file mode 100644 index 000000000..3bc46740f --- /dev/null +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeClassificationResult.cs @@ -0,0 +1,58 @@ +// Copyright (c) 2023 homuler +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +using System; +using System.Runtime.InteropServices; + +namespace Mediapipe +{ + [StructLayout(LayoutKind.Sequential)] + internal readonly struct NativeClassifications + { + private readonly IntPtr _categories; + public readonly uint categoriesCount; + public readonly int headIndex; + private readonly IntPtr _headName; + + public ReadOnlySpan categories + { + get + { + unsafe + { + return new ReadOnlySpan((NativeCategory*)_categories, (int)categoriesCount); + } + } + } + + public string headName => Marshal.PtrToStringAnsi(_headName); + } + + [StructLayout(LayoutKind.Sequential)] + internal readonly struct NativeClassificationResult + { + private readonly IntPtr _classifications; + public readonly uint classificationsCount; + public readonly long timestampMs; + public readonly bool hasTimestampMs; + + public ReadOnlySpan classifications + { + get + { + unsafe + { + return new ReadOnlySpan((NativeClassifications*)_classifications, (int)classificationsCount); + } + } + } + + public void Dispose() + { + UnsafeNativeMethods.mp_tasks_c_components_containers_CppCloseClassificationResult(this); + } + } +} diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeClassificationResult.cs.meta b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeClassificationResult.cs.meta new file mode 100644 index 000000000..751bb0ac1 --- /dev/null +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Marshal/NativeClassificationResult.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: ad3e5c6d1c8049810b48387e9cacfad6 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Tasks/Components/Containers_Unsafe.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Tasks/Components/Containers_Unsafe.cs index cdc08caef..4f46209f2 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Tasks/Components/Containers_Unsafe.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Tasks/Components/Containers_Unsafe.cs @@ -11,6 +11,12 @@ namespace Mediapipe { internal static partial class UnsafeNativeMethods { + [DllImport(MediaPipeLibrary, ExactSpelling = true)] + public static extern MpReturnCode mp_Packet__GetClassificationsVector(IntPtr packet, out NativeClassificationResult value); + + [DllImport(MediaPipeLibrary, ExactSpelling = true)] + public static extern void mp_tasks_c_components_containers_CppCloseClassificationResult(NativeClassificationResult data); + [DllImport(MediaPipeLibrary, ExactSpelling = true)] public static extern MpReturnCode mp_Packet__GetDetectionResult(IntPtr packet, out NativeDetectionResult value); diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/ClassificationResult.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/ClassificationResult.cs index ddd9b3861..1f3cbec9a 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/ClassificationResult.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/ClassificationResult.cs @@ -17,7 +17,7 @@ public readonly struct Classifications /// The array of predicted categories, usually sorted by descending scores, /// e.g. from high to low probability. /// - public readonly IReadOnlyList categories; + public readonly List categories; /// /// The index of the classifier head (i.e. output tensor) these categories /// refer to. This is useful for multi-head models. @@ -31,7 +31,7 @@ public readonly struct Classifications /// public readonly string headName; - internal Classifications(IReadOnlyList categories, int headIndex, string headName) + internal Classifications(List categories, int headIndex, string headName) { this.categories = categories; this.headIndex = headIndex; @@ -58,6 +58,17 @@ public static Classifications CreateFrom(ClassificationList proto, int headIndex return new Classifications(categories, headIndex, headName); } + internal static void Copy(NativeClassifications source, ref Classifications destination) + { + var categories = destination.categories ?? new List((int)source.categoriesCount); + categories.Clear(); + foreach (var nativeCategory in source.categories) + { + categories.Add(new Category(nativeCategory)); + } + destination = new Classifications(categories, source.headIndex, source.headName); + } + public override string ToString() => $"{{ \"categories\": {Util.Format(categories)}, \"headIndex\": {headIndex}, \"headName\": {Util.Format(headName)} }}"; } @@ -70,7 +81,7 @@ public readonly struct ClassificationResult /// /// The classification results for each head of the model. /// - public readonly IReadOnlyList classifications; + public readonly List classifications; /// /// The optional timestamp (in milliseconds) of the start of the chunk of data @@ -83,7 +94,7 @@ public readonly struct ClassificationResult /// public readonly long? timestampMs; - internal ClassificationResult(IReadOnlyList classifications, long? timestampMs) + internal ClassificationResult(List classifications, long? timestampMs) { this.classifications = classifications; this.timestampMs = timestampMs; @@ -101,6 +112,22 @@ public static ClassificationResult CreateFrom(Proto.ClassificationResult proto) #pragma warning restore IDE0004 // for Unity 2020.3.x } + internal static void Copy(NativeClassificationResult source, ref ClassificationResult destination) + { + var classificationsList = destination.classifications ?? new List((int)source.classificationsCount); + classificationsList.ResizeTo((int)source.classificationsCount); + + var i = 0; + foreach (var nativeClassifications in source.classifications) + { + var classifications = classificationsList[i]; + Classifications.Copy(nativeClassifications, ref classifications); + classificationsList[i++] = classifications; + } + + destination = new ClassificationResult(classificationsList, source.hasTimestampMs ? (long?)source.timestampMs : null); + } + public override string ToString() => $"{{ \"classifications\": {Util.Format(classifications)}, \"timestampMs\": {Util.Format(timestampMs)} }}"; } } diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/PacketExtension.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/PacketExtension.cs index 57543c8ed..e27c3ee09 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/PacketExtension.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Components/Containers/PacketExtension.cs @@ -10,6 +10,14 @@ namespace Mediapipe.Tasks.Components.Containers { public static class PacketExtension { + public static void GetClassificationsVector(this Packet packet, List outs) + { + UnsafeNativeMethods.mp_Packet__GetClassificationsVector(packet.mpPtr, out var classificationResult).Assert(); + var tmp = new ClassificationResult(outs, null); + ClassificationResult.Copy(classificationResult, ref tmp); + classificationResult.Dispose(); + } + public static void GetDetectionResult(this Packet packet, ref DetectionResult value) { UnsafeNativeMethods.mp_Packet__GetDetectionResult(packet.mpPtr, out var detectionResult).Assert(); diff --git a/mediapipe_api/BUILD b/mediapipe_api/BUILD index 2c0624812..94b42a45c 100644 --- a/mediapipe_api/BUILD +++ b/mediapipe_api/BUILD @@ -164,6 +164,7 @@ cc_library( "//mediapipe_api/framework/formats:matrix_data", "//mediapipe_api/framework/formats:rect", "//mediapipe_api/framework/port:logging", + "//mediapipe_api/tasks/c/components/containers:classification_result", "//mediapipe_api/tasks/c/components/containers:detection_result", "//mediapipe_api/tasks/c/components/containers:landmark", "//mediapipe_api/tasks/cc/vision/face_geometry/proto:face_geometry", diff --git a/mediapipe_api/tasks/c/components/containers/BUILD b/mediapipe_api/tasks/c/components/containers/BUILD index de56a1f8d..1313912c6 100644 --- a/mediapipe_api/tasks/c/components/containers/BUILD +++ b/mediapipe_api/tasks/c/components/containers/BUILD @@ -8,6 +8,21 @@ package( default_visibility = ["//visibility:public"], ) +cc_library( + name = "classification_result", + srcs = ["classification_result.cc"], + hdrs = ["classification_result.h"], + deps = [ + "//mediapipe_api:common", + "@com_google_mediapipe//mediapipe/framework:packet", + "@com_google_mediapipe//mediapipe/framework/formats:classification_cc_proto", + "@com_google_mediapipe//mediapipe/tasks/c/components/containers:classification_result", + "@com_google_mediapipe//mediapipe/tasks/c/components/containers:classification_result_converter", + "@com_google_mediapipe//mediapipe/tasks/cc/components/containers:classification_result", + ], + alwayslink = True, +) + cc_library( name = "detection_result", srcs = ["detection_result.cc"], diff --git a/mediapipe_api/tasks/c/components/containers/classification_result.cc b/mediapipe_api/tasks/c/components/containers/classification_result.cc new file mode 100644 index 000000000..fe8f9fcc7 --- /dev/null +++ b/mediapipe_api/tasks/c/components/containers/classification_result.cc @@ -0,0 +1,25 @@ +#include "mediapipe_api/tasks/c/components/containers/classification_result.h" + +MpReturnCode mp_Packet__GetClassificationsVector(mediapipe::Packet* packet, ClassificationResult* value_out) { + TRY_ALL + // get std::vector and convert it to ClassificationResult + auto proto_vec = packet->Get>(); + auto vec_size = proto_vec.size(); + + auto classifications_vec = std::vector(vec_size); + for (auto i = 0; i < vec_size; ++i) { + auto classifications = mediapipe::tasks::components::containers::ConvertToClassifications(proto_vec[i]); + classifications_vec.push_back(classifications); + } + + mediapipe::tasks::components::containers::ClassificationResult result; + result.classifications = std::move(classifications_vec); + + mediapipe::tasks::c::components::containers::CppConvertToClassificationResult(result, value_out); + RETURN_CODE(MpReturnCode::Success); + CATCH_ALL +} + +void mp_tasks_c_components_containers_CppCloseClassificationResult(ClassificationResult data) { + mediapipe::tasks::c::components::containers::CppCloseClassificationResult(&data); +} diff --git a/mediapipe_api/tasks/c/components/containers/classification_result.h b/mediapipe_api/tasks/c/components/containers/classification_result.h new file mode 100644 index 000000000..ef8901f97 --- /dev/null +++ b/mediapipe_api/tasks/c/components/containers/classification_result.h @@ -0,0 +1,23 @@ +// Copyright (c) 2023 homuler +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +#ifndef MEDIAPIPE_API_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_ +#define MEDIAPIPE_API_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_ + +#include "mediapipe/framework/packet.h" +#include "mediapipe/tasks/c/components/containers/classification_result.h" +#include "mediapipe/tasks/c/components/containers/classification_result_converter.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe_api/common.h" + +extern "C" { + +MP_CAPI(MpReturnCode) mp_Packet__GetClassificationsVector(mediapipe::Packet* packet, ClassificationResult* value_out); +MP_CAPI(void) mp_tasks_c_components_containers_CppCloseClassificationResult(ClassificationResult data); + +} // extern "C" + +#endif // MEDIAPIPE_API_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_