Skip to content

Commit

Permalink
feat: OutputStream API (#516)
Browse files Browse the repository at this point in the history
* feat!(plugin): NativePacketCallback receives stream_id

* feat: OutputStream supports callbacks

* refactor: place OutputStream.cs under Packages/

* refactor: use InterLocked instead of lock

* fix: swap poseLandmarks and faceLandmarks

* fix: AddListener should not be called in synchronous mode
  • Loading branch information
homuler committed Apr 7, 2022
1 parent 9f5d91d commit 2679000
Show file tree
Hide file tree
Showing 36 changed files with 842 additions and 951 deletions.
36 changes: 0 additions & 36 deletions Assets/Mediapipe/Samples/Common/Scripts/GraphRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -206,42 +206,6 @@ protected void AddTextureFrameToInputStream(string streamName, TextureFrame text
return result || allowBlock || stream.ResetTimestampIfTimedOut(currentTimestampMicrosec, timeoutMicrosec);
}

protected static bool TryGetGraphRunner(IntPtr graphPtr, out GraphRunner graphRunner)
{
var isInstanceIdFound = _NameTable.TryGetValue(graphPtr, out var instanceId);

if (isInstanceIdFound)
{
return _InstanceTable.TryGetValue(instanceId, out graphRunner);
}
graphRunner = null;
return false;
}

protected static Status InvokeIfGraphRunnerFound<T>(IntPtr graphPtr, IntPtr packetPtr, Action<T, IntPtr> action) where T : GraphRunner
{
try
{
var isFound = TryGetGraphRunner(graphPtr, out var graphRunner);
if (!isFound)
{
return Status.FailedPrecondition("Graph runner is not found");
}
var graph = (T)graphRunner;
action(graph, packetPtr);
return Status.Ok();
}
catch (Exception e)
{
return Status.FailedPrecondition(e.ToString());
}
}

protected static Status InvokeIfGraphRunnerFound<T>(IntPtr graphPtr, Action<T> action) where T : GraphRunner
{
return InvokeIfGraphRunnerFound<T>(graphPtr, IntPtr.Zero, (graph, ptr) => { action(graph); });
}

protected long GetCurrentTimestampMicrosec()
{
return _stopwatch == null || !_stopwatch.IsRunning ? -1 : _stopwatch.ElapsedTicks / (TimeSpan.TicksPerMillisecond / 1000);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ private IEnumerator Run()
yield break;
}

graphRunner.StartRun(imageSource);
OnStartRun();
graphRunner.StartRun(imageSource);

var waitWhilePausing = new WaitWhile(() => isPaused);

Expand Down
44 changes: 11 additions & 33 deletions Assets/Mediapipe/Samples/Scenes/Box Tracking/BoxTrackingGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@

using System;
using System.Collections.Generic;
using UnityEngine.Events;

namespace Mediapipe.Unity.BoxTracking
{
public class BoxTrackingGraph : GraphRunner
{
#pragma warning disable IDE1006 // UnityEvent is PascalCase
public UnityEvent<List<Detection>> OnTrackedDetectionsOutput = new UnityEvent<List<Detection>>();
#pragma warning restore IDE1006
public event EventHandler<OutputEventArgs<List<Detection>>> OnTrackedDetectionsOutput
{
add => _trackedDetectionsStream.AddListener(value);
remove => _trackedDetectionsStream.RemoveListener(value);
}

private const string _InputStreamName = "input_video";
private const string _TrackedDetectionsStreamName = "tracked_detections";
Expand All @@ -26,18 +27,14 @@ public override void StartRun(ImageSource imageSource)
{
_trackedDetectionsStream.StartPolling().AssertOk();
}
else
{
_trackedDetectionsStream.AddListener(TrackedDetectionsCallback).AssertOk();
}
StartRun(BuildSidePacket(imageSource));
}

public override void Stop()
{
base.Stop();
OnTrackedDetectionsOutput.RemoveAllListeners();
_trackedDetectionsStream.RemoveAllListeners();
_trackedDetectionsStream = null;
base.Stop();
}

public void AddTextureFrameToInputStream(TextureFrame textureFrame)
Expand All @@ -47,27 +44,7 @@ public void AddTextureFrameToInputStream(TextureFrame textureFrame)

public bool TryGetNext(out List<Detection> trackedDetections, bool allowBlock = true)
{
if (TryGetNext(_trackedDetectionsStream, out trackedDetections, allowBlock, GetCurrentTimestampMicrosec()))
{
OnTrackedDetectionsOutput.Invoke(trackedDetections);
return true;
}
return false;
}

[AOT.MonoPInvokeCallback(typeof(CalculatorGraph.NativePacketCallback))]
private static IntPtr TrackedDetectionsCallback(IntPtr graphPtr, IntPtr packetPtr)
{
return InvokeIfGraphRunnerFound<BoxTrackingGraph>(graphPtr, packetPtr, (boxTrackingGraph, ptr) =>
{
using (var packet = new DetectionVectorPacket(ptr, false))
{
if (boxTrackingGraph._trackedDetectionsStream.TryGetPacketValue(packet, out var value, boxTrackingGraph.timeoutMicrosec))
{
boxTrackingGraph.OnTrackedDetectionsOutput.Invoke(value);
}
}
}).mpPtr;
return TryGetNext(_trackedDetectionsStream, out trackedDetections, allowBlock, GetCurrentTimestampMicrosec());
}

protected override IList<WaitForResult> RequestDependentAssets()
Expand All @@ -82,11 +59,12 @@ protected override Status ConfigureCalculatorGraph(CalculatorGraphConfig config)
{
if (runningMode == RunningMode.NonBlockingSync)
{
_trackedDetectionsStream = new OutputStream<DetectionVectorPacket, List<Detection>>(calculatorGraph, _TrackedDetectionsStreamName, config.AddPacketPresenceCalculator(_TrackedDetectionsStreamName));
_trackedDetectionsStream = new OutputStream<DetectionVectorPacket, List<Detection>>(
calculatorGraph, _TrackedDetectionsStreamName, config.AddPacketPresenceCalculator(_TrackedDetectionsStreamName), timeoutMicrosec);
}
else
{
_trackedDetectionsStream = new OutputStream<DetectionVectorPacket, List<Detection>>(calculatorGraph, _TrackedDetectionsStreamName, true);
_trackedDetectionsStream = new OutputStream<DetectionVectorPacket, List<Detection>>(calculatorGraph, _TrackedDetectionsStreamName, true, timeoutMicrosec);
}
return calculatorGraph.Initialize(config);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// https://opensource.org/licenses/MIT.

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

namespace Mediapipe.Unity.BoxTracking
Expand All @@ -15,7 +16,11 @@ public class BoxTrackingSolution : ImageSourceSolution<BoxTrackingGraph>

protected override void OnStartRun()
{
graphRunner.OnTrackedDetectionsOutput.AddListener(_trackedDetectionsAnnotationController.DrawLater);
if (!runningMode.IsSynchronous())
{
graphRunner.OnTrackedDetectionsOutput += OnTrackedDetectionsOutput;
}

SetupAnnotationController(_trackedDetectionsAnnotationController, ImageSourceProvider.ImageSource);
}

Expand All @@ -26,14 +31,23 @@ protected override void AddTextureFrameToInputStream(TextureFrame textureFrame)

protected override IEnumerator WaitForNextValue()
{
List<Detection> trackedDetections = null;

if (runningMode == RunningMode.Sync)
{
var _ = graphRunner.TryGetNext(out var _, true);
var _ = graphRunner.TryGetNext(out trackedDetections, true);
}
else if (runningMode == RunningMode.NonBlockingSync)
{
yield return new WaitUntil(() => graphRunner.TryGetNext(out var _, false));
yield return new WaitUntil(() => graphRunner.TryGetNext(out trackedDetections, false));
}

_trackedDetectionsAnnotationController.DrawNow(trackedDetections);
}

private void OnTrackedDetectionsOutput(object stream, OutputEventArgs<List<Detection>> eventArgs)
{
_trackedDetectionsAnnotationController.DrawLater(eventArgs.value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using UnityEngine.Events;

using Google.Protobuf;

Expand All @@ -30,9 +29,11 @@ public float minDetectionConfidence
set => _minDetectionConfidence = Mathf.Clamp01(value);
}

#pragma warning disable IDE1006
public UnityEvent<List<Detection>> OnFaceDetectionsOutput = new UnityEvent<List<Detection>>();
#pragma warning restore IDE1006
public event EventHandler<OutputEventArgs<List<Detection>>> OnFaceDetectionsOutput
{
add => _faceDetectionsStream.AddListener(value);
remove => _faceDetectionsStream.RemoveListener(value);
}

private const string _InputStreamName = "input_video";
private const string _FaceDetectionsStreamName = "face_detections";
Expand All @@ -44,18 +45,14 @@ public override void StartRun(ImageSource imageSource)
{
_faceDetectionsStream.StartPolling().AssertOk();
}
else
{
_faceDetectionsStream.AddListener(FaceDetectionsCallback).AssertOk();
}
StartRun(BuildSidePacket(imageSource));
}

public override void Stop()
{
base.Stop();
OnFaceDetectionsOutput.RemoveAllListeners();
_faceDetectionsStream.RemoveAllListeners();
_faceDetectionsStream = null;
base.Stop();
}

public void AddTextureFrameToInputStream(TextureFrame textureFrame)
Expand All @@ -65,27 +62,7 @@ public void AddTextureFrameToInputStream(TextureFrame textureFrame)

public bool TryGetNext(out List<Detection> faceDetections, bool allowBlock = true)
{
if (TryGetNext(_faceDetectionsStream, out faceDetections, allowBlock, GetCurrentTimestampMicrosec()))
{
OnFaceDetectionsOutput.Invoke(faceDetections);
return true;
}
return false;
}

[AOT.MonoPInvokeCallback(typeof(CalculatorGraph.NativePacketCallback))]
private static IntPtr FaceDetectionsCallback(IntPtr graphPtr, IntPtr packetPtr)
{
return InvokeIfGraphRunnerFound<FaceDetectionGraph>(graphPtr, packetPtr, (faceDetectionGraph, ptr) =>
{
using (var packet = new DetectionVectorPacket(ptr, false))
{
if (faceDetectionGraph._faceDetectionsStream.TryGetPacketValue(packet, out var value, faceDetectionGraph.timeoutMicrosec))
{
faceDetectionGraph.OnFaceDetectionsOutput.Invoke(value);
}
}
}).mpPtr;
return TryGetNext(_faceDetectionsStream, out faceDetections, allowBlock, GetCurrentTimestampMicrosec());
}

protected override IList<WaitForResult> RequestDependentAssets()
Expand All @@ -100,11 +77,12 @@ protected override Status ConfigureCalculatorGraph(CalculatorGraphConfig config)
{
if (runningMode == RunningMode.NonBlockingSync)
{
_faceDetectionsStream = new OutputStream<DetectionVectorPacket, List<Detection>>(calculatorGraph, _FaceDetectionsStreamName, config.AddPacketPresenceCalculator(_FaceDetectionsStreamName));
_faceDetectionsStream = new OutputStream<DetectionVectorPacket, List<Detection>>(
calculatorGraph, _FaceDetectionsStreamName, config.AddPacketPresenceCalculator(_FaceDetectionsStreamName), timeoutMicrosec);
}
else
{
_faceDetectionsStream = new OutputStream<DetectionVectorPacket, List<Detection>>(calculatorGraph, _FaceDetectionsStreamName, true);
_faceDetectionsStream = new OutputStream<DetectionVectorPacket, List<Detection>>(calculatorGraph, _FaceDetectionsStreamName, true, timeoutMicrosec);
}

using (var validatedGraphConfig = new ValidatedGraphConfig())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// https://opensource.org/licenses/MIT.

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

namespace Mediapipe.Unity.FaceDetection
Expand All @@ -21,7 +22,11 @@ public FaceDetectionGraph.ModelType modelType

protected override void OnStartRun()
{
graphRunner.OnFaceDetectionsOutput.AddListener(_faceDetectionsAnnotationController.DrawLater);
if (!runningMode.IsSynchronous())
{
graphRunner.OnFaceDetectionsOutput += OnFaceDetectionsOutput;
}

SetupAnnotationController(_faceDetectionsAnnotationController, ImageSourceProvider.ImageSource);
}

Expand All @@ -32,14 +37,23 @@ protected override void AddTextureFrameToInputStream(TextureFrame textureFrame)

protected override IEnumerator WaitForNextValue()
{
List<Detection> faceDetections = null;

if (runningMode == RunningMode.Sync)
{
var _ = graphRunner.TryGetNext(out var _, true);
var _ = graphRunner.TryGetNext(out faceDetections, true);
}
else if (runningMode == RunningMode.NonBlockingSync)
{
yield return new WaitUntil(() => graphRunner.TryGetNext(out var _, false));
yield return new WaitUntil(() => graphRunner.TryGetNext(out faceDetections, false));
}

_faceDetectionsAnnotationController.DrawNow(faceDetections);
}

private void OnFaceDetectionsOutput(object stream, OutputEventArgs<List<Detection>> eventArgs)
{
_faceDetectionsAnnotationController.DrawLater(eventArgs.value);
}
}
}
Loading

0 comments on commit 2679000

Please sign in to comment.