Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: strongly-typed Packet, once again #1120

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 152 additions & 10 deletions Assets/MediaPipeUnity/Samples/Common/Scripts/GraphRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using UnityEngine;

using Stopwatch = System.Diagnostics.Stopwatch;
Expand Down Expand Up @@ -179,7 +180,7 @@ public virtual void Stop()
}
}

protected void AddPacketToInputStream(string streamName, Packet packet)
protected void AddPacketToInputStream<T>(string streamName, Packet<T> packet)
{
calculatorGraph.AddPacketToInputStream(streamName, packet);
}
Expand All @@ -201,18 +202,12 @@ protected void AddTextureFrameToInputStream(string streamName, TextureFrame text
AddPacketToInputStream(streamName, Packet.CreateImageFrameAt(imageFrame, latestTimestamp));
}

protected void AssertResult(params OutputStream.NextResult[] results)
protected bool TryGetValue<T>(Packet<T> packet, out T value)
{
foreach (var result in results)
{
if (!result.ok)
{
throw new Exception("Failed to get the next packet");
}
}
return TryGetValue(packet, out value, (packet) => packet.Get());
}

protected bool TryGetValue<T>(Packet packet, out T value, Func<Packet, T> getter)
protected bool TryGetValue<T>(Packet<T> packet, out T value, Func<Packet<T>, T> getter)
{
if (packet == null)
{
Expand All @@ -223,6 +218,153 @@ protected bool TryGetValue<T>(Packet packet, out T value, Func<Packet, T> getter
return true;
}

protected void AssertResult<T>(OutputStream<T>.NextResult result)
{
if (!result.ok)
{
throw new Exception("Failed to get the next packet");
}
}

protected void AssertResult<T1, T2>((OutputStream<T1>.NextResult, OutputStream<T2>.NextResult) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
}

protected void AssertResult<T1, T2, T3>((OutputStream<T1>.NextResult, OutputStream<T2>.NextResult, OutputStream<T3>.NextResult) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
AssertResult(result.Item3);
}

protected void AssertResult<T1, T2, T3, T4>((OutputStream<T1>.NextResult, OutputStream<T2>.NextResult, OutputStream<T3>.NextResult, OutputStream<T4>.NextResult) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
AssertResult(result.Item3);
AssertResult(result.Item4);
}

protected void AssertResult<T1, T2, T3, T4, T5>(
(
OutputStream<T1>.NextResult,
OutputStream<T2>.NextResult,
OutputStream<T3>.NextResult,
OutputStream<T4>.NextResult,
OutputStream<T5>.NextResult
) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
AssertResult(result.Item3);
AssertResult(result.Item4);
AssertResult(result.Item5);
}

protected void AssertResult<T1, T2, T3, T4, T5, T6>(
(
OutputStream<T1>.NextResult,
OutputStream<T2>.NextResult,
OutputStream<T3>.NextResult,
OutputStream<T4>.NextResult,
OutputStream<T5>.NextResult,
OutputStream<T6>.NextResult
) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
AssertResult(result.Item3);
AssertResult(result.Item4);
AssertResult(result.Item5);
AssertResult(result.Item6);
}

protected void AssertResult<T1, T2, T3, T4, T5, T6, T7>(
(
OutputStream<T1>.NextResult,
OutputStream<T2>.NextResult,
OutputStream<T3>.NextResult,
OutputStream<T4>.NextResult,
OutputStream<T5>.NextResult,
OutputStream<T6>.NextResult,
OutputStream<T7>.NextResult
) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
AssertResult(result.Item3);
AssertResult(result.Item4);
AssertResult(result.Item5);
AssertResult(result.Item6);
AssertResult(result.Item7);
}

protected void AssertResult<T1, T2, T3, T4, T5, T6, T7, T8>(
(
OutputStream<T1>.NextResult,
OutputStream<T2>.NextResult,
OutputStream<T3>.NextResult,
OutputStream<T4>.NextResult,
OutputStream<T5>.NextResult,
OutputStream<T6>.NextResult,
OutputStream<T7>.NextResult,
OutputStream<T8>.NextResult
) result)
{
AssertResult(result.Item1);
AssertResult(result.Item2);
AssertResult(result.Item3);
AssertResult(result.Item4);
AssertResult(result.Item5);
AssertResult(result.Item6);
AssertResult(result.Item7);
AssertResult(result.Item8);
}

protected async Task<(T1, T2)> WhenAll<T1, T2>(Task<T1> task1, Task<T2> task2)
{
await Task.WhenAll(task1, task2);
return (task1.Result, task2.Result);
}

protected async Task<(T1, T2, T3)> WhenAll<T1, T2, T3>(Task<T1> task1, Task<T2> task2, Task<T3> task3)
{
await Task.WhenAll(task1, task2, task3);
return (task1.Result, task2.Result, task3.Result);
}

protected async Task<(T1, T2, T3, T4)> WhenAll<T1, T2, T3, T4>(Task<T1> task1, Task<T2> task2, Task<T3> task3, Task<T4> task4)
{
await Task.WhenAll(task1, task2, task3, task4);
return (task1.Result, task2.Result, task3.Result, task4.Result);
}

protected async Task<(T1, T2, T3, T4, T5)> WhenAll<T1, T2, T3, T4, T5>(Task<T1> task1, Task<T2> task2, Task<T3> task3, Task<T4> task4, Task<T5> task5)
{
await Task.WhenAll(task1, task2, task3, task4, task5);
return (task1.Result, task2.Result, task3.Result, task4.Result, task5.Result);
}

protected async Task<(T1, T2, T3, T4, T5, T6)> WhenAll<T1, T2, T3, T4, T5, T6>(Task<T1> task1, Task<T2> task2, Task<T3> task3, Task<T4> task4, Task<T5> task5, Task<T6> task6)
{
await Task.WhenAll(task1, task2, task3, task4, task5, task6);
return (task1.Result, task2.Result, task3.Result, task4.Result, task5.Result, task6.Result);
}

protected async Task<(T1, T2, T3, T4, T5, T6, T7)> WhenAll<T1, T2, T3, T4, T5, T6, T7>(Task<T1> task1, Task<T2> task2, Task<T3> task3, Task<T4> task4, Task<T5> task5, Task<T6> task6, Task<T7> task7)
{
await Task.WhenAll(task1, task2, task3, task4, task5, task6, task7);
return (task1.Result, task2.Result, task3.Result, task4.Result, task5.Result, task6.Result, task7.Result);
}

protected async Task<(T1, T2, T3, T4, T5, T6, T7, T8)> WhenAll<T1, T2, T3, T4, T5, T6, T7, T8>(Task<T1> task1, Task<T2> task2, Task<T3> task3, Task<T4> task4, Task<T5> task5, Task<T6> task6, Task<T7> task7, Task<T8> task8)
{
await Task.WhenAll(task1, task2, task3, task4, task5, task6, task7, task8);
return (task1.Result, task2.Result, task3.Result, task4.Result, task5.Result, task6.Result, task7.Result, task8.Result);
}

protected long GetCurrentTimestampMicrosec()
{
return _stopwatch == null || !_stopwatch.IsRunning ? -1 : _stopwatch.ElapsedTicks / (TimeSpan.TicksPerMillisecond / 1000);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ public float minDetectionConfidence
set => _minDetectionConfidence = Mathf.Clamp01(value);
}

public event EventHandler<OutputStream.OutputEventArgs> OnFaceDetectionsOutput
public event EventHandler<OutputStream<List<Detection>>.OutputEventArgs> OnFaceDetectionsOutput
{
add => _faceDetectionsStream.AddListener(value, timeoutMicrosec);
remove => _faceDetectionsStream.RemoveListener(value);
}

private const string _InputStreamName = "input_video";
private const string _FaceDetectionsStreamName = "face_detections";
private OutputStream _faceDetectionsStream;
private OutputStream<List<Detection>> _faceDetectionsStream;

public override void StartRun(ImageSource imageSource)
{
Expand Down Expand Up @@ -82,7 +82,7 @@ protected override IList<WaitForResult> RequestDependentAssets()

protected override void ConfigureCalculatorGraph(CalculatorGraphConfig config)
{
_faceDetectionsStream = new OutputStream(calculatorGraph, _FaceDetectionsStreamName, true);
_faceDetectionsStream = new OutputStream<List<Detection>>(calculatorGraph, _FaceDetectionsStreamName, true);
Debug.Log(timeoutMicrosec);

var faceDetectionCalculators = config.Node.Where((node) => node.Calculator.StartsWith("FaceDetection")).ToList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// https://opensource.org/licenses/MIT.

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

namespace Mediapipe.Unity.Sample.FaceDetection
Expand Down Expand Up @@ -48,7 +49,7 @@ protected override IEnumerator WaitForNextValue()
_faceDetectionsAnnotationController.DrawNow(task.Result);
}

private void OnFaceDetectionsOutput(object stream, OutputStream.OutputEventArgs eventArgs)
private void OnFaceDetectionsOutput(object stream, OutputStream<List<Detection>>.OutputEventArgs eventArgs)
{
var packet = eventArgs.packet;
var value = packet == null ? default : packet.GetProtoList(Detection.Parser);
Expand Down
34 changes: 17 additions & 17 deletions Assets/MediaPipeUnity/Samples/Scenes/Face Mesh/FaceMeshGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,25 @@ public float minTrackingConfidence
set => _minTrackingConfidence = Mathf.Clamp01(value);
}

public event EventHandler<OutputStream.OutputEventArgs> OnFaceDetectionsOutput
public event EventHandler<OutputStream<List<Detection>>.OutputEventArgs> OnFaceDetectionsOutput
{
add => _faceDetectionsStream.AddListener(value, timeoutMicrosec);
remove => _faceDetectionsStream.RemoveListener(value);
}

public event EventHandler<OutputStream.OutputEventArgs> OnMultiFaceLandmarksOutput
public event EventHandler<OutputStream<List<NormalizedLandmarkList>>.OutputEventArgs> OnMultiFaceLandmarksOutput
{
add => _multiFaceLandmarksStream.AddListener(value, timeoutMicrosec);
remove => _multiFaceLandmarksStream.RemoveListener(value);
}

public event EventHandler<OutputStream.OutputEventArgs> OnFaceRectsFromLandmarksOutput
public event EventHandler<OutputStream<List<NormalizedRect>>.OutputEventArgs> OnFaceRectsFromLandmarksOutput
{
add => _faceRectsFromLandmarksStream.AddListener(value, timeoutMicrosec);
remove => _faceRectsFromLandmarksStream.RemoveListener(value);
}

public event EventHandler<OutputStream.OutputEventArgs> OnFaceRectsFromDetectionsOutput
public event EventHandler<OutputStream<List<NormalizedRect>>.OutputEventArgs> OnFaceRectsFromDetectionsOutput
{
add => _faceRectsFromDetectionsStream.AddListener(value, timeoutMicrosec);
remove => _faceRectsFromDetectionsStream.RemoveListener(value);
Expand All @@ -80,10 +80,10 @@ public event EventHandler<OutputStream.OutputEventArgs> OnFaceRectsFromDetection
private const string _FaceRectsFromLandmarksStreamName = "face_rects_from_landmarks";
private const string _FaceRectsFromDetectionsStreamName = "face_rects_from_detections";

private OutputStream _faceDetectionsStream;
private OutputStream _multiFaceLandmarksStream;
private OutputStream _faceRectsFromLandmarksStream;
private OutputStream _faceRectsFromDetectionsStream;
private OutputStream<List<Detection>> _faceDetectionsStream;
private OutputStream<List<NormalizedLandmarkList>> _multiFaceLandmarksStream;
private OutputStream<List<NormalizedRect>> _faceRectsFromLandmarksStream;
private OutputStream<List<NormalizedRect>> _faceRectsFromDetectionsStream;

public override void StartRun(ImageSource imageSource)
{
Expand Down Expand Up @@ -117,27 +117,27 @@ public void AddTextureFrameToInputStream(TextureFrame textureFrame)

public async Task<FaceMeshResult> WaitNext()
{
var results = await Task.WhenAll(
var results = await WhenAll(
_faceDetectionsStream.WaitNextAsync(),
_multiFaceLandmarksStream.WaitNextAsync(),
_faceRectsFromLandmarksStream.WaitNextAsync(),
_faceRectsFromDetectionsStream.WaitNextAsync()
);
AssertResult(results);

_ = TryGetValue(results[0].packet, out var faceDetections, (packet) =>
_ = TryGetValue(results.Item1.packet, out var faceDetections, (packet) =>
{
return packet.GetProtoList(Detection.Parser);
});
_ = TryGetValue(results[1].packet, out var multiFaceLandmarks, (packet) =>
_ = TryGetValue(results.Item2.packet, out var multiFaceLandmarks, (packet) =>
{
return packet.GetProtoList(NormalizedLandmarkList.Parser);
});
_ = TryGetValue(results[2].packet, out var faceRectsFromLandmarks, (packet) =>
_ = TryGetValue(results.Item3.packet, out var faceRectsFromLandmarks, (packet) =>
{
return packet.GetProtoList(NormalizedRect.Parser);
});
_ = TryGetValue(results[3].packet, out var faceRectsFromDetections, (packet) =>
_ = TryGetValue(results.Item4.packet, out var faceRectsFromDetections, (packet) =>
{
return packet.GetProtoList(NormalizedRect.Parser);
});
Expand All @@ -147,10 +147,10 @@ public async Task<FaceMeshResult> WaitNext()

protected override void ConfigureCalculatorGraph(CalculatorGraphConfig config)
{
_faceDetectionsStream = new OutputStream(calculatorGraph, _FaceDetectionsStreamName, true);
_multiFaceLandmarksStream = new OutputStream(calculatorGraph, _MultiFaceLandmarksStreamName, true);
_faceRectsFromLandmarksStream = new OutputStream(calculatorGraph, _FaceRectsFromLandmarksStreamName, true);
_faceRectsFromDetectionsStream = new OutputStream(calculatorGraph, _FaceRectsFromDetectionsStreamName, true);
_faceDetectionsStream = new OutputStream<List<Detection>>(calculatorGraph, _FaceDetectionsStreamName, true);
_multiFaceLandmarksStream = new OutputStream<List<NormalizedLandmarkList>>(calculatorGraph, _MultiFaceLandmarksStreamName, true);
_faceRectsFromLandmarksStream = new OutputStream<List<NormalizedRect>>(calculatorGraph, _FaceRectsFromLandmarksStreamName, true);
_faceRectsFromDetectionsStream = new OutputStream<List<NormalizedRect>>(calculatorGraph, _FaceRectsFromDetectionsStreamName, true);

using (var validatedGraphConfig = new ValidatedGraphConfig())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// https://opensource.org/licenses/MIT.

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

namespace Mediapipe.Unity.Sample.FaceMesh
Expand Down Expand Up @@ -74,28 +75,28 @@ protected override IEnumerator WaitForNextValue()
_faceRectsFromDetectionsAnnotationController.DrawNow(result.faceRectsFromDetections);
}

private void OnFaceDetectionsOutput(object stream, OutputStream.OutputEventArgs eventArgs)
private void OnFaceDetectionsOutput(object stream, OutputStream<List<Detection>>.OutputEventArgs eventArgs)
{
var packet = eventArgs.packet;
var value = packet == null ? default : packet.GetProtoList(Detection.Parser);
_faceDetectionsAnnotationController.DrawLater(value);
}

private void OnMultiFaceLandmarksOutput(object stream, OutputStream.OutputEventArgs eventArgs)
private void OnMultiFaceLandmarksOutput(object stream, OutputStream<List<NormalizedLandmarkList>>.OutputEventArgs eventArgs)
{
var packet = eventArgs.packet;
var value = packet == null ? default : packet.GetProtoList(NormalizedLandmarkList.Parser);
_multiFaceLandmarksAnnotationController.DrawLater(value);
}

private void OnFaceRectsFromLandmarksOutput(object stream, OutputStream.OutputEventArgs eventArgs)
private void OnFaceRectsFromLandmarksOutput(object stream, OutputStream<List<NormalizedRect>>.OutputEventArgs eventArgs)
{
var packet = eventArgs.packet;
var value = packet == null ? default : packet.GetProtoList(NormalizedRect.Parser);
_faceRectsFromLandmarksAnnotationController.DrawLater(value);
}

private void OnFaceRectsFromDetectionsOutput(object stream, OutputStream.OutputEventArgs eventArgs)
private void OnFaceRectsFromDetectionsOutput(object stream, OutputStream<List<NormalizedRect>>.OutputEventArgs eventArgs)
{
var packet = eventArgs.packet;
var value = packet == null ? default : packet.GetProtoList(NormalizedRect.Parser);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Mediapipe.Unity.Sample.HairSegmentation
{
public class HairSegmentationGraph : GraphRunner
{
public event EventHandler<OutputStream.OutputEventArgs> OnHairMaskOutput
public event EventHandler<OutputStream<ImageFrame>.OutputEventArgs> OnHairMaskOutput
{
add => _hairMaskStream.AddListener(value, timeoutMicrosec);
remove => _hairMaskStream.RemoveListener(value);
Expand All @@ -25,7 +25,7 @@ public event EventHandler<OutputStream.OutputEventArgs> OnHairMaskOutput

private const string _InputStreamName = "input_video";
private const string _HairMaskStreamName = "hair_mask";
private OutputStream _hairMaskStream;
private OutputStream<ImageFrame> _hairMaskStream;

public override void StartRun(ImageSource imageSource)
{
Expand Down Expand Up @@ -71,7 +71,7 @@ protected override IList<WaitForResult> RequestDependentAssets()

protected override void ConfigureCalculatorGraph(CalculatorGraphConfig config)
{
_hairMaskStream = new OutputStream(calculatorGraph, _HairMaskStreamName, true);
_hairMaskStream = new OutputStream<ImageFrame>(calculatorGraph, _HairMaskStreamName, true);
calculatorGraph.Initialize(config);
}

Expand Down
Loading