From eeeec337739b4a96912d0f3a57ab3d2822e31f28 Mon Sep 17 00:00:00 2001 From: Junrou Nishida Date: Sat, 29 Jul 2023 19:32:40 +0900 Subject: [PATCH] feat: implement PacketsCallbackTable (#971) * refactor: TaskRunner passes callback_id to the callback * feat: implement PacketsCallbackTable * refactor: dispose PacketMap when Process / Send called * fix: packets_callback must be null in some cases --- .../NativeMethods/Tasks/TaskRunner_Unsafe.cs | 2 +- .../Tasks/Core/PacketsCallbackTable.cs | 50 ++++++++++++++++ .../Tasks/Core/PacketsCallbackTable.cs.meta | 11 ++++ .../Runtime/Scripts/Tasks/Core/TaskRunner.cs | 30 +++++----- .../Tasks/Vision/Core/BaseVisionTaskApi.cs | 14 ++--- .../EditMode/Tasks/Core/TaskRunnerTest.cs | 57 +++++++++---------- mediapipe_api/tasks/cc/core/task_runner.cc | 13 +++-- mediapipe_api/tasks/cc/core/task_runner.h | 5 +- 8 files changed, 118 insertions(+), 64 deletions(-) create mode 100644 Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Core/PacketsCallbackTable.cs create mode 100644 Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Core/PacketsCallbackTable.cs.meta diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Tasks/TaskRunner_Unsafe.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Tasks/TaskRunner_Unsafe.cs index 412cc0fe1..37ef5aa51 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Tasks/TaskRunner_Unsafe.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Tasks/TaskRunner_Unsafe.cs @@ -13,7 +13,7 @@ internal static partial class UnsafeNativeMethods { [DllImport(MediaPipeLibrary, ExactSpelling = true)] public static extern MpReturnCode mp_tasks_core_TaskRunner_Create__PKc_i_PF(byte[] serializedConfig, int size, - [MarshalAs(UnmanagedType.FunctionPtr)] Tasks.Core.TaskRunner.NativePacketsCallback packetsCallback, + int callbackId, [MarshalAs(UnmanagedType.FunctionPtr)] Tasks.Core.TaskRunner.NativePacketsCallback packetsCallback, out IntPtr status, out IntPtr taskRunner); [DllImport(MediaPipeLibrary, ExactSpelling = true)] diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Core/PacketsCallbackTable.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Core/PacketsCallbackTable.cs new file mode 100644 index 000000000..19870e0f1 --- /dev/null +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Core/PacketsCallbackTable.cs @@ -0,0 +1,50 @@ +// Copyright (c) 2023 homuler +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +using System; +using UnityEngine; + +namespace Mediapipe.Tasks.Core +{ + internal class PacketsCallbackTable + { + private const int _MaxSize = 20; + + private static int _Counter = 0; + private static readonly GlobalInstanceTable _Table = new GlobalInstanceTable(_MaxSize); + + public static (int, TaskRunner.NativePacketsCallback) Add(TaskRunner.PacketsCallback callback) + { + if (callback == null) + { + return (-1, null); + } + + var callbackId = _Counter++; + _Table.Add(callbackId, callback); + return (callbackId, InvokeCallbackIfFound); + } + + public static bool TryGetValue(int id, out TaskRunner.PacketsCallback callback) => _Table.TryGetValue(id, out callback); + + [AOT.MonoPInvokeCallback(typeof(TaskRunner.NativePacketsCallback))] + private static void InvokeCallbackIfFound(int callbackId, IntPtr statusPtr, IntPtr packetMapPtr) + { + // NOTE: if status is not OK, packetMap will be nullptr + if (packetMapPtr == IntPtr.Zero) + { + var status = new Status(statusPtr, false); + Debug.LogError(status.ToString()); + return; + } + + if (TryGetValue(callbackId, out var callback)) + { + callback(new PacketMap(packetMapPtr, false)); + } + } + } +} diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Core/PacketsCallbackTable.cs.meta b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Core/PacketsCallbackTable.cs.meta new file mode 100644 index 000000000..c81d58208 --- /dev/null +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Core/PacketsCallbackTable.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: f55cdb6359fc58c229fcc611deab5d1d +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Core/TaskRunner.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Core/TaskRunner.cs index 7901c1198..0b50d3bc5 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Core/TaskRunner.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Core/TaskRunner.cs @@ -11,16 +11,15 @@ namespace Mediapipe.Tasks.Core { public class TaskRunner : MpResourceHandle { - public delegate void NativePacketsCallback(IntPtr status, IntPtr packetMap); + public delegate void NativePacketsCallback(int name, IntPtr status, IntPtr packetMap); + public delegate void PacketsCallback(PacketMap packetMap); - public static TaskRunner Create(CalculatorGraphConfig config, NativePacketsCallback packetsCallback = null) + public static TaskRunner Create(CalculatorGraphConfig config, int callbackId = -1, NativePacketsCallback packetsCallback = null) { var bytes = config.ToByteArray(); - UnsafeNativeMethods.mp_tasks_core_TaskRunner_Create__PKc_i_PF(bytes, bytes.Length, packetsCallback, out var statusPtr, out var taskRunnerPtr).Assert(); - - var status = new Status(statusPtr); - status.AssertOk(); + UnsafeNativeMethods.mp_tasks_core_TaskRunner_Create__PKc_i_PF(bytes, bytes.Length, callbackId, packetsCallback, out var statusPtr, out var taskRunnerPtr).Assert(); + AssertStatusOk(statusPtr); return new TaskRunner(taskRunnerPtr); } @@ -34,21 +33,20 @@ protected override void DeleteMpPtr() public PacketMap Process(PacketMap inputs) { UnsafeNativeMethods.mp_tasks_core_TaskRunner__Process__Ppm(mpPtr, inputs.mpPtr, out var statusPtr, out var packetMapPtr).Assert(); - GC.KeepAlive(this); - - var status = new Status(statusPtr); - status.AssertOk(); + inputs.Dispose(); // respect move semantics + GC.KeepAlive(this); + AssertStatusOk(statusPtr); return new PacketMap(packetMapPtr, true); } public void Send(PacketMap inputs) { UnsafeNativeMethods.mp_tasks_core_TaskRunner__Send__Ppm(mpPtr, inputs.mpPtr, out var statusPtr).Assert(); - GC.KeepAlive(this); + inputs.Dispose(); // respect move semantics - var status = new Status(statusPtr); - status.AssertOk(); + GC.KeepAlive(this); + AssertStatusOk(statusPtr); } public void Close() @@ -56,8 +54,7 @@ public void Close() UnsafeNativeMethods.mp_tasks_core_TaskRunner__Close(mpPtr, out var statusPtr).Assert(); GC.KeepAlive(this); - var status = new Status(statusPtr); - status.AssertOk(); + AssertStatusOk(statusPtr); } public void Restart() @@ -65,8 +62,7 @@ public void Restart() UnsafeNativeMethods.mp_tasks_core_TaskRunner__Restart(mpPtr, out var statusPtr).Assert(); GC.KeepAlive(this); - var status = new Status(statusPtr); - status.AssertOk(); + AssertStatusOk(statusPtr); } public CalculatorGraphConfig GetGraphConfig(ExtensionRegistry extensionRegistry = null) diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Vision/Core/BaseVisionTaskApi.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Vision/Core/BaseVisionTaskApi.cs index 24583f750..d77af2df2 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Vision/Core/BaseVisionTaskApi.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Tasks/Vision/Core/BaseVisionTaskApi.cs @@ -26,21 +26,22 @@ public class BaseVisionTaskApi : IDisposable protected BaseVisionTaskApi( CalculatorGraphConfig graphConfig, RunningMode runningMode, - Tasks.Core.TaskRunner.NativePacketsCallback packetCallback) + Tasks.Core.TaskRunner.PacketsCallback packetsCallback) { if (runningMode == RunningMode.LIVE_STREAM) { - if (packetCallback == null) + if (packetsCallback == null) { throw new ArgumentException("The vision task is in live stream mode, a user-defined result callback must be provided."); } } - else if (packetCallback != null) + else if (packetsCallback != null) { throw new ArgumentException("The vision task is in image or video mode, a user-defined result callback should not be provided."); } - _taskRunner = Tasks.Core.TaskRunner.Create(graphConfig, packetCallback); + var (callbackId, nativePacketsCallback) = Tasks.Core.PacketsCallbackTable.Add(packetsCallback); + _taskRunner = Tasks.Core.TaskRunner.Create(graphConfig, callbackId, nativePacketsCallback); _runningMode = runningMode; } @@ -175,10 +176,7 @@ public void Close() /// /// Returns the canonicalized CalculatorGraphConfig of the underlying graph. /// - public CalculatorGraphConfig GetGraphConfig() - { - return _taskRunner.GetGraphConfig(); - } + public CalculatorGraphConfig GetGraphConfig() => _taskRunner.GetGraphConfig(); void IDisposable.Dispose() { diff --git a/Packages/com.github.homuler.mediapipe/Tests/EditMode/Tasks/Core/TaskRunnerTest.cs b/Packages/com.github.homuler.mediapipe/Tests/EditMode/Tasks/Core/TaskRunnerTest.cs index c82841598..4efcc961e 100644 --- a/Packages/com.github.homuler.mediapipe/Tests/EditMode/Tasks/Core/TaskRunnerTest.cs +++ b/Packages/com.github.homuler.mediapipe/Tests/EditMode/Tasks/Core/TaskRunnerTest.cs @@ -73,11 +73,10 @@ public void Process_ShouldThrowException_When_InputIsInvalid() { using (var taskRunner = TaskRunner.Create(passThroughConfig)) { - using (var packetMap = new PacketMap()) - { - var exception = Assert.Throws(() => taskRunner.Process(packetMap)); - Assert.AreEqual(StatusCode.InvalidArgument, exception.statusCode); - } + var packetMap = new PacketMap(); + var exception = Assert.Throws(() => taskRunner.Process(packetMap)); + Assert.AreEqual(StatusCode.InvalidArgument, exception.statusCode); + Assert.True(packetMap.isDisposed); } } @@ -86,12 +85,12 @@ public void Process_ShouldReturnOutput_When_InputIsValid() { using (var taskRunner = TaskRunner.Create(passThroughConfig)) { - using (var packetMap = new PacketMap()) - { - packetMap.Emplace("in", new IntPacket(1)); - var outputMap = taskRunner.Process(packetMap); - Assert.AreEqual(1, outputMap.At("out").Get()); - } + var packetMap = new PacketMap(); + packetMap.Emplace("in", new IntPacket(1)); + + var outputMap = taskRunner.Process(packetMap); + Assert.AreEqual(1, outputMap.At("out").Get()); + Assert.True(packetMap.isDisposed); } } #endregion @@ -102,38 +101,36 @@ public void Send_ShouldThrowException_When_CallbackIsNotSet() { using (var taskRunner = TaskRunner.Create(passThroughConfig)) { - using (var packetMap = new PacketMap()) - { - packetMap.Emplace("in", new IntPacket(1, new Timestamp(1))); - var exception = Assert.Throws(() => taskRunner.Send(packetMap)); - Assert.AreEqual(StatusCode.InvalidArgument, exception.statusCode); - } + var packetMap = new PacketMap(); + packetMap.Emplace("in", new IntPacket(1, new Timestamp(1))); + + var exception = Assert.Throws(() => taskRunner.Send(packetMap)); + Assert.AreEqual(StatusCode.InvalidArgument, exception.statusCode); + Assert.True(packetMap.isDisposed); } } [Test] public void Send_ShouldThrowException_When_InputIsInvalid() { - using (var taskRunner = TaskRunner.Create(passThroughConfig, HandlePassThroughResult)) + using (var taskRunner = TaskRunner.Create(passThroughConfig, 0, HandlePassThroughResult)) { - using (var packetMap = new PacketMap()) - { - var exception = Assert.Throws(() => taskRunner.Send(packetMap)); - Assert.AreEqual(StatusCode.InvalidArgument, exception.statusCode); - } + var packetMap = new PacketMap(); + var exception = Assert.Throws(() => taskRunner.Send(packetMap)); + Assert.AreEqual(StatusCode.InvalidArgument, exception.statusCode); + Assert.True(packetMap.isDisposed); } } [Test] public void Send_ShouldNotThrowException_When_InputIsValid() { - using (var taskRunner = TaskRunner.Create(passThroughConfig, HandlePassThroughResult)) + using (var taskRunner = TaskRunner.Create(passThroughConfig, 0, HandlePassThroughResult)) { - using (var packetMap = new PacketMap()) - { - packetMap.Emplace("in", new IntPacket(1, new Timestamp(1))); - Assert.DoesNotThrow(() => taskRunner.Send(packetMap)); - } + var packetMap = new PacketMap(); + packetMap.Emplace("in", new IntPacket(1, new Timestamp(1))); + Assert.DoesNotThrow(() => taskRunner.Send(packetMap)); + Assert.True(packetMap.isDisposed); } } #endregion @@ -195,7 +192,7 @@ public void GetGraphConfig_ShouldReturnCanonicalizedConfig() #endregion [AOT.MonoPInvokeCallback(typeof(TaskRunner.NativePacketsCallback))] - private static void HandlePassThroughResult(IntPtr statusPtr, IntPtr packetMapPtr) + private static void HandlePassThroughResult(int callbackId, IntPtr statusPtr, IntPtr packetMapPtr) { // Do nothing } diff --git a/mediapipe_api/tasks/cc/core/task_runner.cc b/mediapipe_api/tasks/cc/core/task_runner.cc index 44f257fd6..12cac050c 100644 --- a/mediapipe_api/tasks/cc/core/task_runner.cc +++ b/mediapipe_api/tasks/cc/core/task_runner.cc @@ -2,20 +2,21 @@ #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" -MpReturnCode mp_tasks_core_TaskRunner_Create__PKc_i_PF(const char* serialized_config, int size, NativePacketsCallback* packets_callback, +MpReturnCode mp_tasks_core_TaskRunner_Create__PKc_i_PF(const char* serialized_config, int size, + int callback_id, NativePacketsCallback* packets_callback, absl::Status** status_out, TaskRunner** task_runner_out) { TRY auto config = ParseFromStringAsProto(serialized_config, size); mediapipe::tasks::core::PacketsCallback callback = nullptr; if (packets_callback) { - callback = [packets_callback](absl::StatusOr status_or_packet_map) -> void { + callback = [callback_id, packets_callback](absl::StatusOr status_or_packet_map) -> void { auto status = status_or_packet_map.status(); if (!status.ok()) { - packets_callback(&status, nullptr); + packets_callback(callback_id, &status, nullptr); return; } auto value = status_or_packet_map.value(); - packets_callback(&status, &value); + packets_callback(callback_id, &status, &value); }; } @@ -41,7 +42,7 @@ void mp_tasks_core_TaskRunner__delete(TaskRunner* task_runner) { MpReturnCode mp_tasks_core_TaskRunner__Process__Ppm(TaskRunner* task_runner, PacketMap* inputs, absl::Status** status_out, PacketMap** value_out) { TRY - auto status_or_packet_map = task_runner->Process(*inputs); + auto status_or_packet_map = task_runner->Process(std::move(*inputs)); *status_out = new absl::Status{status_or_packet_map.status()}; if (status_or_packet_map.ok()) { *value_out = new PacketMap{status_or_packet_map.value()}; @@ -54,7 +55,7 @@ MpReturnCode mp_tasks_core_TaskRunner__Process__Ppm(TaskRunner* task_runner, Pac MpReturnCode mp_tasks_core_TaskRunner__Send__Ppm(TaskRunner* task_runner, PacketMap* inputs, absl::Status** status_out) { TRY - *status_out = new absl::Status{task_runner->Send(*inputs)}; + *status_out = new absl::Status{task_runner->Send(std::move(*inputs))}; RETURN_CODE(MpReturnCode::Success); CATCH_EXCEPTION } diff --git a/mediapipe_api/tasks/cc/core/task_runner.h b/mediapipe_api/tasks/cc/core/task_runner.h index fc05d0ff2..18ef8e199 100644 --- a/mediapipe_api/tasks/cc/core/task_runner.h +++ b/mediapipe_api/tasks/cc/core/task_runner.h @@ -19,9 +19,10 @@ using TaskRunner = mediapipe::tasks::core::TaskRunner; extern "C" { typedef std::map PacketMap; -typedef void NativePacketsCallback(absl::Status*, PacketMap*); +typedef void NativePacketsCallback(int, absl::Status*, PacketMap*); -MP_CAPI(MpReturnCode) mp_tasks_core_TaskRunner_Create__PKc_i_PF(const char* serialized_config, int size, NativePacketsCallback* packets_callback, +MP_CAPI(MpReturnCode) mp_tasks_core_TaskRunner_Create__PKc_i_PF(const char* serialized_config, int size, + int callback_id, NativePacketsCallback* packets_callback, absl::Status** status_out, TaskRunner** task_runner_out); MP_CAPI(void) mp_tasks_core_TaskRunner__delete(TaskRunner* task_runner);