Skip to content

Commit

Permalink
feat: implement PacketsCallbackTable (#971)
Browse files Browse the repository at this point in the history
* refactor: TaskRunner passes callback_id to the callback

* feat: implement PacketsCallbackTable

* refactor: dispose PacketMap when Process / Send called

* fix: packets_callback must be null in some cases
  • Loading branch information
homuler committed Jul 29, 2023
1 parent d8593f7 commit eeeec33
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ internal static partial class UnsafeNativeMethods
{
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_tasks_core_TaskRunner_Create__PKc_i_PF(byte[] serializedConfig, int size,
[MarshalAs(UnmanagedType.FunctionPtr)] Tasks.Core.TaskRunner.NativePacketsCallback packetsCallback,
int callbackId, [MarshalAs(UnmanagedType.FunctionPtr)] Tasks.Core.TaskRunner.NativePacketsCallback packetsCallback,
out IntPtr status, out IntPtr taskRunner);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) 2023 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System;
using UnityEngine;

namespace Mediapipe.Tasks.Core
{
internal class PacketsCallbackTable
{
private const int _MaxSize = 20;

private static int _Counter = 0;
private static readonly GlobalInstanceTable<int, TaskRunner.PacketsCallback> _Table = new GlobalInstanceTable<int, TaskRunner.PacketsCallback>(_MaxSize);

public static (int, TaskRunner.NativePacketsCallback) Add(TaskRunner.PacketsCallback callback)
{
if (callback == null)
{
return (-1, null);
}

var callbackId = _Counter++;
_Table.Add(callbackId, callback);
return (callbackId, InvokeCallbackIfFound);
}

public static bool TryGetValue(int id, out TaskRunner.PacketsCallback callback) => _Table.TryGetValue(id, out callback);

[AOT.MonoPInvokeCallback(typeof(TaskRunner.NativePacketsCallback))]
private static void InvokeCallbackIfFound(int callbackId, IntPtr statusPtr, IntPtr packetMapPtr)
{
// NOTE: if status is not OK, packetMap will be nullptr
if (packetMapPtr == IntPtr.Zero)
{
var status = new Status(statusPtr, false);
Debug.LogError(status.ToString());
return;
}

if (TryGetValue(callbackId, out var callback))
{
callback(new PacketMap(packetMapPtr, false));
}
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@ namespace Mediapipe.Tasks.Core
{
public class TaskRunner : MpResourceHandle
{
public delegate void NativePacketsCallback(IntPtr status, IntPtr packetMap);
public delegate void NativePacketsCallback(int name, IntPtr status, IntPtr packetMap);
public delegate void PacketsCallback(PacketMap packetMap);

public static TaskRunner Create(CalculatorGraphConfig config, NativePacketsCallback packetsCallback = null)
public static TaskRunner Create(CalculatorGraphConfig config, int callbackId = -1, NativePacketsCallback packetsCallback = null)
{
var bytes = config.ToByteArray();
UnsafeNativeMethods.mp_tasks_core_TaskRunner_Create__PKc_i_PF(bytes, bytes.Length, packetsCallback, out var statusPtr, out var taskRunnerPtr).Assert();

var status = new Status(statusPtr);
status.AssertOk();
UnsafeNativeMethods.mp_tasks_core_TaskRunner_Create__PKc_i_PF(bytes, bytes.Length, callbackId, packetsCallback, out var statusPtr, out var taskRunnerPtr).Assert();

AssertStatusOk(statusPtr);
return new TaskRunner(taskRunnerPtr);
}

Expand All @@ -34,39 +33,36 @@ protected override void DeleteMpPtr()
public PacketMap Process(PacketMap inputs)
{
UnsafeNativeMethods.mp_tasks_core_TaskRunner__Process__Ppm(mpPtr, inputs.mpPtr, out var statusPtr, out var packetMapPtr).Assert();
GC.KeepAlive(this);

var status = new Status(statusPtr);
status.AssertOk();
inputs.Dispose(); // respect move semantics

GC.KeepAlive(this);
AssertStatusOk(statusPtr);
return new PacketMap(packetMapPtr, true);
}

public void Send(PacketMap inputs)
{
UnsafeNativeMethods.mp_tasks_core_TaskRunner__Send__Ppm(mpPtr, inputs.mpPtr, out var statusPtr).Assert();
GC.KeepAlive(this);
inputs.Dispose(); // respect move semantics

var status = new Status(statusPtr);
status.AssertOk();
GC.KeepAlive(this);
AssertStatusOk(statusPtr);
}

public void Close()
{
UnsafeNativeMethods.mp_tasks_core_TaskRunner__Close(mpPtr, out var statusPtr).Assert();
GC.KeepAlive(this);

var status = new Status(statusPtr);
status.AssertOk();
AssertStatusOk(statusPtr);
}

public void Restart()
{
UnsafeNativeMethods.mp_tasks_core_TaskRunner__Restart(mpPtr, out var statusPtr).Assert();
GC.KeepAlive(this);

var status = new Status(statusPtr);
status.AssertOk();
AssertStatusOk(statusPtr);
}

public CalculatorGraphConfig GetGraphConfig(ExtensionRegistry extensionRegistry = null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,22 @@ public class BaseVisionTaskApi : IDisposable
protected BaseVisionTaskApi(
CalculatorGraphConfig graphConfig,
RunningMode runningMode,
Tasks.Core.TaskRunner.NativePacketsCallback packetCallback)
Tasks.Core.TaskRunner.PacketsCallback packetsCallback)
{
if (runningMode == RunningMode.LIVE_STREAM)
{
if (packetCallback == null)
if (packetsCallback == null)
{
throw new ArgumentException("The vision task is in live stream mode, a user-defined result callback must be provided.");
}
}
else if (packetCallback != null)
else if (packetsCallback != null)
{
throw new ArgumentException("The vision task is in image or video mode, a user-defined result callback should not be provided.");
}

_taskRunner = Tasks.Core.TaskRunner.Create(graphConfig, packetCallback);
var (callbackId, nativePacketsCallback) = Tasks.Core.PacketsCallbackTable.Add(packetsCallback);
_taskRunner = Tasks.Core.TaskRunner.Create(graphConfig, callbackId, nativePacketsCallback);
_runningMode = runningMode;
}

Expand Down Expand Up @@ -175,10 +176,7 @@ public void Close()
/// <summary>
/// Returns the canonicalized CalculatorGraphConfig of the underlying graph.
/// </summary>
public CalculatorGraphConfig GetGraphConfig()
{
return _taskRunner.GetGraphConfig();
}
public CalculatorGraphConfig GetGraphConfig() => _taskRunner.GetGraphConfig();

void IDisposable.Dispose()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,10 @@ public void Process_ShouldThrowException_When_InputIsInvalid()
{
using (var taskRunner = TaskRunner.Create(passThroughConfig))
{
using (var packetMap = new PacketMap())
{
var exception = Assert.Throws<BadStatusException>(() => taskRunner.Process(packetMap));
Assert.AreEqual(StatusCode.InvalidArgument, exception.statusCode);
}
var packetMap = new PacketMap();
var exception = Assert.Throws<BadStatusException>(() => taskRunner.Process(packetMap));
Assert.AreEqual(StatusCode.InvalidArgument, exception.statusCode);
Assert.True(packetMap.isDisposed);
}
}

Expand All @@ -86,12 +85,12 @@ public void Process_ShouldReturnOutput_When_InputIsValid()
{
using (var taskRunner = TaskRunner.Create(passThroughConfig))
{
using (var packetMap = new PacketMap())
{
packetMap.Emplace("in", new IntPacket(1));
var outputMap = taskRunner.Process(packetMap);
Assert.AreEqual(1, outputMap.At<IntPacket, int>("out").Get());
}
var packetMap = new PacketMap();
packetMap.Emplace("in", new IntPacket(1));

var outputMap = taskRunner.Process(packetMap);
Assert.AreEqual(1, outputMap.At<IntPacket, int>("out").Get());
Assert.True(packetMap.isDisposed);
}
}
#endregion
Expand All @@ -102,38 +101,36 @@ public void Send_ShouldThrowException_When_CallbackIsNotSet()
{
using (var taskRunner = TaskRunner.Create(passThroughConfig))
{
using (var packetMap = new PacketMap())
{
packetMap.Emplace("in", new IntPacket(1, new Timestamp(1)));
var exception = Assert.Throws<BadStatusException>(() => taskRunner.Send(packetMap));
Assert.AreEqual(StatusCode.InvalidArgument, exception.statusCode);
}
var packetMap = new PacketMap();
packetMap.Emplace("in", new IntPacket(1, new Timestamp(1)));

var exception = Assert.Throws<BadStatusException>(() => taskRunner.Send(packetMap));
Assert.AreEqual(StatusCode.InvalidArgument, exception.statusCode);
Assert.True(packetMap.isDisposed);
}
}

[Test]
public void Send_ShouldThrowException_When_InputIsInvalid()
{
using (var taskRunner = TaskRunner.Create(passThroughConfig, HandlePassThroughResult))
using (var taskRunner = TaskRunner.Create(passThroughConfig, 0, HandlePassThroughResult))
{
using (var packetMap = new PacketMap())
{
var exception = Assert.Throws<BadStatusException>(() => taskRunner.Send(packetMap));
Assert.AreEqual(StatusCode.InvalidArgument, exception.statusCode);
}
var packetMap = new PacketMap();
var exception = Assert.Throws<BadStatusException>(() => taskRunner.Send(packetMap));
Assert.AreEqual(StatusCode.InvalidArgument, exception.statusCode);
Assert.True(packetMap.isDisposed);
}
}

[Test]
public void Send_ShouldNotThrowException_When_InputIsValid()
{
using (var taskRunner = TaskRunner.Create(passThroughConfig, HandlePassThroughResult))
using (var taskRunner = TaskRunner.Create(passThroughConfig, 0, HandlePassThroughResult))
{
using (var packetMap = new PacketMap())
{
packetMap.Emplace("in", new IntPacket(1, new Timestamp(1)));
Assert.DoesNotThrow(() => taskRunner.Send(packetMap));
}
var packetMap = new PacketMap();
packetMap.Emplace("in", new IntPacket(1, new Timestamp(1)));
Assert.DoesNotThrow(() => taskRunner.Send(packetMap));
Assert.True(packetMap.isDisposed);
}
}
#endregion
Expand Down Expand Up @@ -195,7 +192,7 @@ public void GetGraphConfig_ShouldReturnCanonicalizedConfig()
#endregion

[AOT.MonoPInvokeCallback(typeof(TaskRunner.NativePacketsCallback))]
private static void HandlePassThroughResult(IntPtr statusPtr, IntPtr packetMapPtr)
private static void HandlePassThroughResult(int callbackId, IntPtr statusPtr, IntPtr packetMapPtr)
{
// Do nothing
}
Expand Down
13 changes: 7 additions & 6 deletions mediapipe_api/tasks/cc/core/task_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@

#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"

MpReturnCode mp_tasks_core_TaskRunner_Create__PKc_i_PF(const char* serialized_config, int size, NativePacketsCallback* packets_callback,
MpReturnCode mp_tasks_core_TaskRunner_Create__PKc_i_PF(const char* serialized_config, int size,
int callback_id, NativePacketsCallback* packets_callback,
absl::Status** status_out, TaskRunner** task_runner_out) {
TRY
auto config = ParseFromStringAsProto<mediapipe::CalculatorGraphConfig>(serialized_config, size);
mediapipe::tasks::core::PacketsCallback callback = nullptr;
if (packets_callback) {
callback = [packets_callback](absl::StatusOr<PacketMap> status_or_packet_map) -> void {
callback = [callback_id, packets_callback](absl::StatusOr<PacketMap> status_or_packet_map) -> void {
auto status = status_or_packet_map.status();
if (!status.ok()) {
packets_callback(&status, nullptr);
packets_callback(callback_id, &status, nullptr);
return;
}
auto value = status_or_packet_map.value();
packets_callback(&status, &value);
packets_callback(callback_id, &status, &value);
};
}

Expand All @@ -41,7 +42,7 @@ void mp_tasks_core_TaskRunner__delete(TaskRunner* task_runner) {

MpReturnCode mp_tasks_core_TaskRunner__Process__Ppm(TaskRunner* task_runner, PacketMap* inputs, absl::Status** status_out, PacketMap** value_out) {
TRY
auto status_or_packet_map = task_runner->Process(*inputs);
auto status_or_packet_map = task_runner->Process(std::move(*inputs));
*status_out = new absl::Status{status_or_packet_map.status()};
if (status_or_packet_map.ok()) {
*value_out = new PacketMap{status_or_packet_map.value()};
Expand All @@ -54,7 +55,7 @@ MpReturnCode mp_tasks_core_TaskRunner__Process__Ppm(TaskRunner* task_runner, Pac

MpReturnCode mp_tasks_core_TaskRunner__Send__Ppm(TaskRunner* task_runner, PacketMap* inputs, absl::Status** status_out) {
TRY
*status_out = new absl::Status{task_runner->Send(*inputs)};
*status_out = new absl::Status{task_runner->Send(std::move(*inputs))};
RETURN_CODE(MpReturnCode::Success);
CATCH_EXCEPTION
}
Expand Down
5 changes: 3 additions & 2 deletions mediapipe_api/tasks/cc/core/task_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ using TaskRunner = mediapipe::tasks::core::TaskRunner;
extern "C" {

typedef std::map<std::string, mediapipe::Packet> PacketMap;
typedef void NativePacketsCallback(absl::Status*, PacketMap*);
typedef void NativePacketsCallback(int, absl::Status*, PacketMap*);

MP_CAPI(MpReturnCode) mp_tasks_core_TaskRunner_Create__PKc_i_PF(const char* serialized_config, int size, NativePacketsCallback* packets_callback,
MP_CAPI(MpReturnCode) mp_tasks_core_TaskRunner_Create__PKc_i_PF(const char* serialized_config, int size,
int callback_id, NativePacketsCallback* packets_callback,
absl::Status** status_out, TaskRunner** task_runner_out);
MP_CAPI(void) mp_tasks_core_TaskRunner__delete(TaskRunner* task_runner);

Expand Down

0 comments on commit eeeec33

Please sign in to comment.