Skip to content

Commit

Permalink
feat!: strongly-typed Packet, once again (#1120)
Browse files Browse the repository at this point in the history
  • Loading branch information
homuler committed Jan 8, 2024
1 parent a6e67e3 commit 4f3668e
Show file tree
Hide file tree
Showing 38 changed files with 1,191 additions and 897 deletions.
162 changes: 152 additions & 10 deletions Assets/MediaPipeUnity/Samples/Common/Scripts/GraphRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using UnityEngine;

using Stopwatch = System.Diagnostics.Stopwatch;
Expand Down Expand Up @@ -179,7 +180,7 @@ public virtual void Stop()
}
}

protected void AddPacketToInputStream(string streamName, Packet packet)
protected void AddPacketToInputStream<T>(string streamName, Packet<T> packet)
{
calculatorGraph.AddPacketToInputStream(streamName, packet);
}
Expand All @@ -201,18 +202,12 @@ protected void AddTextureFrameToInputStream(string streamName, TextureFrame text
AddPacketToInputStream(streamName, Packet.CreateImageFrameAt(imageFrame, latestTimestamp));
}

protected void AssertResult(params OutputStream.NextResult[] results)
protected bool TryGetValue<T>(Packet<T> packet, out T value)
{
foreach (var result in results)
{
if (!result.ok)
{
throw new Exception("Failed to get the next packet");
}
}
return TryGetValue(packet, out value, (packet) => packet.Get());
}

protected bool TryGetValue<T>(Packet packet, out T value, Func<Packet, T> getter)
protected bool TryGetValue<T>(Packet<T> packet, out T value, Func<Packet<T>, T> getter)
{
if (packet == null)
{
Expand All @@ -223,6 +218,153 @@ protected bool TryGetValue<T>(Packet packet, out T value, Func<Packet, T> getter
return true;
}

protected void AssertResult<T>(OutputStream<T>.NextResult result)
{
if (!result.ok)
{
throw new Exception("Failed to get the next packet");
}
}

protected void AssertResult<T1, T2>((OutputStream<T1>.NextResult, OutputStream<T2>.NextResult) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
}

protected void AssertResult<T1, T2, T3>((OutputStream<T1>.NextResult, OutputStream<T2>.NextResult, OutputStream<T3>.NextResult) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
AssertResult(result.Item3);
}

protected void AssertResult<T1, T2, T3, T4>((OutputStream<T1>.NextResult, OutputStream<T2>.NextResult, OutputStream<T3>.NextResult, OutputStream<T4>.NextResult) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
AssertResult(result.Item3);
AssertResult(result.Item4);
}

protected void AssertResult<T1, T2, T3, T4, T5>(
(
OutputStream<T1>.NextResult,
OutputStream<T2>.NextResult,
OutputStream<T3>.NextResult,
OutputStream<T4>.NextResult,
OutputStream<T5>.NextResult
) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
AssertResult(result.Item3);
AssertResult(result.Item4);
AssertResult(result.Item5);
}

protected void AssertResult<T1, T2, T3, T4, T5, T6>(
(
OutputStream<T1>.NextResult,
OutputStream<T2>.NextResult,
OutputStream<T3>.NextResult,
OutputStream<T4>.NextResult,
OutputStream<T5>.NextResult,
OutputStream<T6>.NextResult
) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
AssertResult(result.Item3);
AssertResult(result.Item4);
AssertResult(result.Item5);
AssertResult(result.Item6);
}

protected void AssertResult<T1, T2, T3, T4, T5, T6, T7>(
(
OutputStream<T1>.NextResult,
OutputStream<T2>.NextResult,
OutputStream<T3>.NextResult,
OutputStream<T4>.NextResult,
OutputStream<T5>.NextResult,
OutputStream<T6>.NextResult,
OutputStream<T7>.NextResult
) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
AssertResult(result.Item3);
AssertResult(result.Item4);
AssertResult(result.Item5);
AssertResult(result.Item6);
AssertResult(result.Item7);
}

protected void AssertResult<T1, T2, T3, T4, T5, T6, T7, T8>(
(
OutputStream<T1>.NextResult,
OutputStream<T2>.NextResult,
OutputStream<T3>.NextResult,
OutputStream<T4>.NextResult,
OutputStream<T5>.NextResult,
OutputStream<T6>.NextResult,
OutputStream<T7>.NextResult,
OutputStream<T8>.NextResult
) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
AssertResult(result.Item3);
AssertResult(result.Item4);
AssertResult(result.Item5);
AssertResult(result.Item6);
AssertResult(result.Item7);
AssertResult(result.Item8);
}

protected async Task<(T1, T2)> WhenAll<T1, T2>(Task<T1> task1, Task<T2> task2)
{
await Task.WhenAll(task1, task2);
return (task1.Result, task2.Result);
}

protected async Task<(T1, T2, T3)> WhenAll<T1, T2, T3>(Task<T1> task1, Task<T2> task2, Task<T3> task3)
{
await Task.WhenAll(task1, task2, task3);
return (task1.Result, task2.Result, task3.Result);
}

protected async Task<(T1, T2, T3, T4)> WhenAll<T1, T2, T3, T4>(Task<T1> task1, Task<T2> task2, Task<T3> task3, Task<T4> task4)
{
await Task.WhenAll(task1, task2, task3, task4);
return (task1.Result, task2.Result, task3.Result, task4.Result);
}

protected async Task<(T1, T2, T3, T4, T5)> WhenAll<T1, T2, T3, T4, T5>(Task<T1> task1, Task<T2> task2, Task<T3> task3, Task<T4> task4, Task<T5> task5)
{
await Task.WhenAll(task1, task2, task3, task4, task5);
return (task1.Result, task2.Result, task3.Result, task4.Result, task5.Result);
}

protected async Task<(T1, T2, T3, T4, T5, T6)> WhenAll<T1, T2, T3, T4, T5, T6>(Task<T1> task1, Task<T2> task2, Task<T3> task3, Task<T4> task4, Task<T5> task5, Task<T6> task6)
{
await Task.WhenAll(task1, task2, task3, task4, task5, task6);
return (task1.Result, task2.Result, task3.Result, task4.Result, task5.Result, task6.Result);
}

protected async Task<(T1, T2, T3, T4, T5, T6, T7)> WhenAll<T1, T2, T3, T4, T5, T6, T7>(Task<T1> task1, Task<T2> task2, Task<T3> task3, Task<T4> task4, Task<T5> task5, Task<T6> task6, Task<T7> task7)
{
await Task.WhenAll(task1, task2, task3, task4, task5, task6, task7);
return (task1.Result, task2.Result, task3.Result, task4.Result, task5.Result, task6.Result, task7.Result);
}

protected async Task<(T1, T2, T3, T4, T5, T6, T7, T8)> WhenAll<T1, T2, T3, T4, T5, T6, T7, T8>(Task<T1> task1, Task<T2> task2, Task<T3> task3, Task<T4> task4, Task<T5> task5, Task<T6> task6, Task<T7> task7, Task<T8> task8)
{
await Task.WhenAll(task1, task2, task3, task4, task5, task6, task7, task8);
return (task1.Result, task2.Result, task3.Result, task4.Result, task5.Result, task6.Result, task7.Result, task8.Result);
}

protected long GetCurrentTimestampMicrosec()
{
return _stopwatch == null || !_stopwatch.IsRunning ? -1 : _stopwatch.ElapsedTicks / (TimeSpan.TicksPerMillisecond / 1000);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ public float minDetectionConfidence
set => _minDetectionConfidence = Mathf.Clamp01(value);
}

public event EventHandler<OutputStream.OutputEventArgs> OnFaceDetectionsOutput
public event EventHandler<OutputStream<List<Detection>>.OutputEventArgs> OnFaceDetectionsOutput
{
add => _faceDetectionsStream.AddListener(value, timeoutMicrosec);
remove => _faceDetectionsStream.RemoveListener(value);
}

private const string _InputStreamName = "input_video";
private const string _FaceDetectionsStreamName = "face_detections";
private OutputStream _faceDetectionsStream;
private OutputStream<List<Detection>> _faceDetectionsStream;

public override void StartRun(ImageSource imageSource)
{
Expand Down Expand Up @@ -82,7 +82,7 @@ protected override IList<WaitForResult> RequestDependentAssets()

protected override void ConfigureCalculatorGraph(CalculatorGraphConfig config)
{
_faceDetectionsStream = new OutputStream(calculatorGraph, _FaceDetectionsStreamName, true);
_faceDetectionsStream = new OutputStream<List<Detection>>(calculatorGraph, _FaceDetectionsStreamName, true);
Debug.Log(timeoutMicrosec);

var faceDetectionCalculators = config.Node.Where((node) => node.Calculator.StartsWith("FaceDetection")).ToList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// https://opensource.org/licenses/MIT.

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

namespace Mediapipe.Unity.Sample.FaceDetection
Expand Down Expand Up @@ -48,7 +49,7 @@ protected override IEnumerator WaitForNextValue()
_faceDetectionsAnnotationController.DrawNow(task.Result);
}

private void OnFaceDetectionsOutput(object stream, OutputStream.OutputEventArgs eventArgs)
private void OnFaceDetectionsOutput(object stream, OutputStream<List<Detection>>.OutputEventArgs eventArgs)
{
var packet = eventArgs.packet;
var value = packet == null ? default : packet.GetProtoList(Detection.Parser);
Expand Down
34 changes: 17 additions & 17 deletions Assets/MediaPipeUnity/Samples/Scenes/Face Mesh/FaceMeshGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,25 @@ public float minTrackingConfidence
set => _minTrackingConfidence = Mathf.Clamp01(value);
}

public event EventHandler<OutputStream.OutputEventArgs> OnFaceDetectionsOutput
public event EventHandler<OutputStream<List<Detection>>.OutputEventArgs> OnFaceDetectionsOutput
{
add => _faceDetectionsStream.AddListener(value, timeoutMicrosec);
remove => _faceDetectionsStream.RemoveListener(value);
}

public event EventHandler<OutputStream.OutputEventArgs> OnMultiFaceLandmarksOutput
public event EventHandler<OutputStream<List<NormalizedLandmarkList>>.OutputEventArgs> OnMultiFaceLandmarksOutput
{
add => _multiFaceLandmarksStream.AddListener(value, timeoutMicrosec);
remove => _multiFaceLandmarksStream.RemoveListener(value);
}

public event EventHandler<OutputStream.OutputEventArgs> OnFaceRectsFromLandmarksOutput
public event EventHandler<OutputStream<List<NormalizedRect>>.OutputEventArgs> OnFaceRectsFromLandmarksOutput
{
add => _faceRectsFromLandmarksStream.AddListener(value, timeoutMicrosec);
remove => _faceRectsFromLandmarksStream.RemoveListener(value);
}

public event EventHandler<OutputStream.OutputEventArgs> OnFaceRectsFromDetectionsOutput
public event EventHandler<OutputStream<List<NormalizedRect>>.OutputEventArgs> OnFaceRectsFromDetectionsOutput
{
add => _faceRectsFromDetectionsStream.AddListener(value, timeoutMicrosec);
remove => _faceRectsFromDetectionsStream.RemoveListener(value);
Expand All @@ -80,10 +80,10 @@ public float minTrackingConfidence
private const string _FaceRectsFromLandmarksStreamName = "face_rects_from_landmarks";
private const string _FaceRectsFromDetectionsStreamName = "face_rects_from_detections";

private OutputStream _faceDetectionsStream;
private OutputStream _multiFaceLandmarksStream;
private OutputStream _faceRectsFromLandmarksStream;
private OutputStream _faceRectsFromDetectionsStream;
private OutputStream<List<Detection>> _faceDetectionsStream;
private OutputStream<List<NormalizedLandmarkList>> _multiFaceLandmarksStream;
private OutputStream<List<NormalizedRect>> _faceRectsFromLandmarksStream;
private OutputStream<List<NormalizedRect>> _faceRectsFromDetectionsStream;

public override void StartRun(ImageSource imageSource)
{
Expand Down Expand Up @@ -117,27 +117,27 @@ public void AddTextureFrameToInputStream(TextureFrame textureFrame)

public async Task<FaceMeshResult> WaitNext()
{
var results = await Task.WhenAll(
var results = await WhenAll(
_faceDetectionsStream.WaitNextAsync(),
_multiFaceLandmarksStream.WaitNextAsync(),
_faceRectsFromLandmarksStream.WaitNextAsync(),
_faceRectsFromDetectionsStream.WaitNextAsync()
);
AssertResult(results);

_ = TryGetValue(results[0].packet, out var faceDetections, (packet) =>
_ = TryGetValue(results.Item1.packet, out var faceDetections, (packet) =>
{
return packet.GetProtoList(Detection.Parser);
});
_ = TryGetValue(results[1].packet, out var multiFaceLandmarks, (packet) =>
_ = TryGetValue(results.Item2.packet, out var multiFaceLandmarks, (packet) =>
{
return packet.GetProtoList(NormalizedLandmarkList.Parser);
});
_ = TryGetValue(results[2].packet, out var faceRectsFromLandmarks, (packet) =>
_ = TryGetValue(results.Item3.packet, out var faceRectsFromLandmarks, (packet) =>
{
return packet.GetProtoList(NormalizedRect.Parser);
});
_ = TryGetValue(results[3].packet, out var faceRectsFromDetections, (packet) =>
_ = TryGetValue(results.Item4.packet, out var faceRectsFromDetections, (packet) =>
{
return packet.GetProtoList(NormalizedRect.Parser);
});
Expand All @@ -147,10 +147,10 @@ public async Task<FaceMeshResult> WaitNext()

protected override void ConfigureCalculatorGraph(CalculatorGraphConfig config)
{
_faceDetectionsStream = new OutputStream(calculatorGraph, _FaceDetectionsStreamName, true);
_multiFaceLandmarksStream = new OutputStream(calculatorGraph, _MultiFaceLandmarksStreamName, true);
_faceRectsFromLandmarksStream = new OutputStream(calculatorGraph, _FaceRectsFromLandmarksStreamName, true);
_faceRectsFromDetectionsStream = new OutputStream(calculatorGraph, _FaceRectsFromDetectionsStreamName, true);
_faceDetectionsStream = new OutputStream<List<Detection>>(calculatorGraph, _FaceDetectionsStreamName, true);
_multiFaceLandmarksStream = new OutputStream<List<NormalizedLandmarkList>>(calculatorGraph, _MultiFaceLandmarksStreamName, true);
_faceRectsFromLandmarksStream = new OutputStream<List<NormalizedRect>>(calculatorGraph, _FaceRectsFromLandmarksStreamName, true);
_faceRectsFromDetectionsStream = new OutputStream<List<NormalizedRect>>(calculatorGraph, _FaceRectsFromDetectionsStreamName, true);

using (var validatedGraphConfig = new ValidatedGraphConfig())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// https://opensource.org/licenses/MIT.

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

namespace Mediapipe.Unity.Sample.FaceMesh
Expand Down Expand Up @@ -74,28 +75,28 @@ protected override IEnumerator WaitForNextValue()
_faceRectsFromDetectionsAnnotationController.DrawNow(result.faceRectsFromDetections);
}

private void OnFaceDetectionsOutput(object stream, OutputStream.OutputEventArgs eventArgs)
private void OnFaceDetectionsOutput(object stream, OutputStream<List<Detection>>.OutputEventArgs eventArgs)
{
var packet = eventArgs.packet;
var value = packet == null ? default : packet.GetProtoList(Detection.Parser);
_faceDetectionsAnnotationController.DrawLater(value);
}

private void OnMultiFaceLandmarksOutput(object stream, OutputStream.OutputEventArgs eventArgs)
private void OnMultiFaceLandmarksOutput(object stream, OutputStream<List<NormalizedLandmarkList>>.OutputEventArgs eventArgs)
{
var packet = eventArgs.packet;
var value = packet == null ? default : packet.GetProtoList(NormalizedLandmarkList.Parser);
_multiFaceLandmarksAnnotationController.DrawLater(value);
}

private void OnFaceRectsFromLandmarksOutput(object stream, OutputStream.OutputEventArgs eventArgs)
private void OnFaceRectsFromLandmarksOutput(object stream, OutputStream<List<NormalizedRect>>.OutputEventArgs eventArgs)
{
var packet = eventArgs.packet;
var value = packet == null ? default : packet.GetProtoList(NormalizedRect.Parser);
_faceRectsFromLandmarksAnnotationController.DrawLater(value);
}

private void OnFaceRectsFromDetectionsOutput(object stream, OutputStream.OutputEventArgs eventArgs)
private void OnFaceRectsFromDetectionsOutput(object stream, OutputStream<List<NormalizedRect>>.OutputEventArgs eventArgs)
{
var packet = eventArgs.packet;
var value = packet == null ? default : packet.GetProtoList(NormalizedRect.Parser);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Mediapipe.Unity.Sample.HairSegmentation
{
public class HairSegmentationGraph : GraphRunner
{
public event EventHandler<OutputStream.OutputEventArgs> OnHairMaskOutput
public event EventHandler<OutputStream<ImageFrame>.OutputEventArgs> OnHairMaskOutput
{
add => _hairMaskStream.AddListener(value, timeoutMicrosec);
remove => _hairMaskStream.RemoveListener(value);
Expand All @@ -25,7 +25,7 @@ public class HairSegmentationGraph : GraphRunner

private const string _InputStreamName = "input_video";
private const string _HairMaskStreamName = "hair_mask";
private OutputStream _hairMaskStream;
private OutputStream<ImageFrame> _hairMaskStream;

public override void StartRun(ImageSource imageSource)
{
Expand Down Expand Up @@ -71,7 +71,7 @@ protected override IList<WaitForResult> RequestDependentAssets()

protected override void ConfigureCalculatorGraph(CalculatorGraphConfig config)
{
_hairMaskStream = new OutputStream(calculatorGraph, _HairMaskStreamName, true);
_hairMaskStream = new OutputStream<ImageFrame>(calculatorGraph, _HairMaskStreamName, true);
calculatorGraph.Initialize(config);
}

Expand Down
Loading

0 comments on commit 4f3668e

Please sign in to comment.