Skip to content

Commit

Permalink
perf: reduce allocations for FaceDetector (#1082)
Browse files Browse the repository at this point in the history
* perf: avoid/reduce allocations

* perf: optimize AssertStatusOk

* perf: use ReadOnlySpan to deserialize proto message

* perf: use stack to build a proto Packet

* perf: avoid Reflection

* refactor: DetectionResult.Empty

* refactor: reuse NormalizedRect

* perf: deserialize as List<Detection> faster

* perf: avoid allocation when building DetectionResult

* refactor: remove unused using statements

* perf: reduce allocation on IMAGE mode

* perf: reduce allocation on VIDEO mode

* remove unnecessary directives
  • Loading branch information
homuler committed Jan 1, 2024
1 parent 0a302b2 commit deeb9f0
Show file tree
Hide file tree
Showing 22 changed files with 480 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

using System.Collections;
using UnityEngine;

using UnityEngine.Rendering;
using FaceDetectionResult = Mediapipe.Tasks.Components.Containers.DetectionResult;

namespace Mediapipe.Unity.Sample.FaceDetection
Expand Down Expand Up @@ -63,6 +63,10 @@ protected override IEnumerator Run()
var flipVertically = transformationOptions.flipVertically;
var imageProcessingOptions = new Tasks.Vision.Core.ImageProcessingOptions(rotationDegrees: (int)transformationOptions.rotationAngle);

AsyncGPUReadbackRequest req = default;
var waitUntilReqDone = new WaitUntil(() => req.done);
var result = FaceDetectionResult.Alloc(options.numFaces);

while (true)
{
if (isPaused)
Expand All @@ -77,8 +81,8 @@ protected override IEnumerator Run()
}

// Copy current image to TextureFrame
var req = textureFrame.ReadTextureAsync(imageSource.GetCurrentTexture(), flipHorizontally, flipVertically);
yield return new WaitUntil(() => req.done);
req = textureFrame.ReadTextureAsync(imageSource.GetCurrentTexture(), flipHorizontally, flipVertically);
yield return waitUntilReqDone;

if (req.hasError)
{
Expand All @@ -90,12 +94,26 @@ protected override IEnumerator Run()
switch (taskApi.runningMode)
{
case Tasks.Vision.Core.RunningMode.IMAGE:
var result = taskApi.Detect(image, imageProcessingOptions);
_detectionResultAnnotationController.DrawNow(result);
if (taskApi.TryDetect(image, imageProcessingOptions, ref result))
{
_detectionResultAnnotationController.DrawNow(result);
}
else
{
// clear the annotation
_detectionResultAnnotationController.DrawNow(FaceDetectionResult.Empty);
}
break;
case Tasks.Vision.Core.RunningMode.VIDEO:
result = taskApi.DetectForVideo(image, (int)GetCurrentTimestampMillisec(), imageProcessingOptions);
_detectionResultAnnotationController.DrawNow(result);
if (taskApi.TryDetectForVideo(image, (int)GetCurrentTimestampMillisec(), imageProcessingOptions, ref result))
{
_detectionResultAnnotationController.DrawNow(result);
}
else
{
// clear the annotation
_detectionResultAnnotationController.DrawNow(FaceDetectionResult.Empty);
}
break;
case Tasks.Vision.Core.RunningMode.LIVE_STREAM:
taskApi.DetectAsync(image, (int)GetCurrentTimestampMillisec(), imageProcessingOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,22 @@ protected static string MarshalStringFromNative(IntPtr strPtr)
return str;
}

/// <summary>
/// The optimized implementation of <see cref="Status.AssertOk" />.
/// </summary>
protected static void AssertStatusOk(IntPtr statusPtr)
{
using (var status = new Status(statusPtr, true))
var ok = SafeNativeMethods.absl_Status__ok(statusPtr);
if (!ok)
{
using (var status = new Status(statusPtr, true))
{
status.AssertOk();
}
}
else
{
status.AssertOk();
UnsafeNativeMethods.absl_Status__delete(statusPtr);
}
}

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2023 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

namespace Mediapipe
{
public static class ProtoMessageClearExtension
{
public static void Clear(this Detection detection)
{
detection.Label.Clear();
detection.LabelId.Clear();
detection.Score.Clear();

detection.LocationData.ClearFormat();
if (detection.LocationData.BoundingBox != null)
{
detection.LocationData.BoundingBox.ClearXmin();
detection.LocationData.BoundingBox.ClearYmin();
detection.LocationData.BoundingBox.ClearWidth();
detection.LocationData.BoundingBox.ClearHeight();
}
if (detection.LocationData.RelativeBoundingBox != null)
{
detection.LocationData.RelativeBoundingBox.ClearXmin();
detection.LocationData.RelativeBoundingBox.ClearYmin();
detection.LocationData.RelativeBoundingBox.ClearWidth();
detection.LocationData.RelativeBoundingBox.ClearHeight();
}
if (detection.LocationData.Mask != null)
{
detection.LocationData.Mask.ClearWidth();
detection.LocationData.Mask.ClearHeight();
detection.LocationData.Mask.Rasterization.Interval.Clear();
}
detection.LocationData.RelativeKeypoints.Clear();

detection.ClearFeatureTag();
detection.ClearTrackId();
detection.ClearDetectionId();
detection.AssociatedDetections.Clear();
detection.DisplayName.Clear();
detection.ClearTimestampUsec();
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

using System;
using System.Runtime.InteropServices;

using pb = Google.Protobuf;
using Google.Protobuf;

namespace Mediapipe
{
Expand All @@ -22,11 +21,22 @@ public void Dispose()
UnsafeNativeMethods.delete_array__PKc(_str);
}

public T Deserialize<T>(pb::MessageParser<T> parser) where T : pb::IMessage<T>
public T Deserialize<T>(MessageParser<T> parser) where T : IMessage<T>
{
unsafe
{
var bytes = new ReadOnlySpan<byte>((byte*)_str, _length);
return parser.ParseFrom(bytes);
}
}

public void WriteTo<T>(T proto) where T : IMessage<T>
{
var bytes = new byte[_length];
Marshal.Copy(_str, bytes, 0, bytes.Length);
return parser.ParseFrom(bytes);
unsafe
{
var bytes = new ReadOnlySpan<byte>((byte*)_str, _length);
proto.MergeFrom(bytes);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,43 @@ public List<T> Deserialize<T>(pb::MessageParser<T> parser) where T : pb::IMessag
public void Deserialize<T>(pb::MessageParser<T> parser, List<T> protos) where T : pb::IMessage<T>
{
protos.Clear();
_ = WriteTo(parser, protos);
}

/// <summary>
/// Deserializes the data as a list of <typeparamref name="T" />.
/// </summary>
/// <remarks>
/// The deserialized data will be merged into <paramref name="protos" />.
/// You may want to clear each field of <typeparamref name="T"/> before calling this method.
/// If <see cref="_size"/> is less than <paramref name="protos" />.Count, the superfluous elements in <paramref name="protos" /> will be untouched.
/// </remarks>
/// <param name="protos">A list of <typeparamref name="T" /> to populate</param>
/// <returns>
/// The number of overwritten elements in <paramref name="protos" />.
/// </returns>
public int WriteTo<T>(pb::MessageParser<T> parser, List<T> protos) where T : pb::IMessage<T>
{
unsafe
{
var protoPtr = (SerializedProto*)_data;

for (var i = 0; i < _size; i++)
// overwrite the existing list
var len = Math.Min(_size, protos.Count);
for (var i = 0; i < len; i++)
{
var serializedProto = Marshal.PtrToStructure<SerializedProto>((IntPtr)protoPtr++);
serializedProto.WriteTo(protos[i]);
}

for (var i = protos.Count; i < _size; i++)
{
var serializedProto = Marshal.PtrToStructure<SerializedProto>((IntPtr)protoPtr++);
protos.Add(serializedProto.Deserialize(parser));
}
}

return _size;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public unsafe Image(ImageFormat.Types.Format format, int width, int height, int
/// </list>
/// </remarks>
public Image(ImageFormat.Types.Format format, int width, int height, int widthStep, NativeArray<byte> pixelData)
: this(format, width, height, widthStep, pixelData, ImageFrame.VoidDeleter)
: this(format, width, height, widthStep, pixelData, _VoidDeleter)
{ }

#if UNITY_EDITOR_LINUX || UNITY_STANDLONE_LINUX || UNITY_ANDROID
Expand All @@ -55,6 +55,8 @@ public Image(uint name, int width, int height, GpuBufferFormat format, GlTexture
{ }
#endif

private static readonly ImageFrame.Deleter _VoidDeleter = ImageFrame.VoidDeleter;

protected override void DeleteMpPtr()
{
UnsafeNativeMethods.mp_Image__delete(ptr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,16 @@ public unsafe ImageFrame(ImageFormat.Types.Format format, int width, int height,
/// </list>
/// </remarks>
public ImageFrame(ImageFormat.Types.Format format, int width, int height, int widthStep, NativeArray<byte> pixelData)
: this(format, width, height, widthStep, pixelData, VoidDeleter)
: this(format, width, height, widthStep, pixelData, _VoidDeleter)
{ }

protected override void DeleteMpPtr()
{
UnsafeNativeMethods.mp_ImageFrame__delete(ptr);
}

private static readonly Deleter _VoidDeleter = VoidDeleter;

[AOT.MonoPInvokeCallback(typeof(Deleter))]
internal static void VoidDeleter(IntPtr _) { }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Mediapipe
{
public class Packet : MpResourceHandle
{
private Packet(IntPtr ptr, bool isOwner) : base(ptr, isOwner) { }
internal Packet(IntPtr ptr, bool isOwner) : base(ptr, isOwner) { }

protected override void DeleteMpPtr()
{
Expand All @@ -28,6 +28,8 @@ public long TimestampMicroseconds()
return value;
}

public bool IsEmpty() => SafeNativeMethods.mp_Packet__IsEmpty(mpPtr);

internal static Packet CreateEmpty()
{
UnsafeNativeMethods.mp_Packet__(out var ptr).Assert();
Expand Down Expand Up @@ -262,11 +264,17 @@ public static Packet CreateIntAt(int value, long timestampMicrosec)
/// </summary>
public static Packet CreateProto<T>(T value) where T : IMessage<T>
{
var arr = value.ToByteArray();
UnsafeNativeMethods.mp__PacketFromDynamicProto__PKc_PKc_i(value.Descriptor.FullName, arr, arr.Length, out var statusPtr, out var ptr).Assert();
unsafe
{
var size = value.CalculateSize();
var arr = stackalloc byte[size];
value.WriteTo(new Span<byte>(arr, size));

AssertStatusOk(statusPtr);
return new Packet(ptr, true);
UnsafeNativeMethods.mp__PacketFromDynamicProto__PKc_PKc_i(value.Descriptor.FullName, arr, size, out var statusPtr, out var ptr).Assert();

AssertStatusOk(statusPtr);
return new Packet(ptr, true);
}
}

/// <summary>
Expand All @@ -277,11 +285,17 @@ public static Packet CreateProto<T>(T value) where T : IMessage<T>
/// </param>
public static Packet CreateProtoAt<T>(T value, long timestampMicrosec) where T : IMessage<T>
{
var arr = value.ToByteArray();
UnsafeNativeMethods.mp__PacketFromDynamicProto_At__PKc_PKc_i_ll(value.Descriptor.FullName, arr, arr.Length, timestampMicrosec, out var statusPtr, out var ptr).Assert();
unsafe
{
var size = value.CalculateSize();
var arr = stackalloc byte[size];
value.WriteTo(new Span<byte>(arr, size));

AssertStatusOk(statusPtr);
return new Packet(ptr, true);
UnsafeNativeMethods.mp__PacketFromDynamicProto_At__PKc_PKc_i_ll(value.Descriptor.FullName, arr, size, timestampMicrosec, out var statusPtr, out var ptr).Assert();
AssertStatusOk(statusPtr);

return new Packet(ptr, true);
}
}

/// <summary>
Expand Down Expand Up @@ -548,6 +562,22 @@ public void GetProtoList<T>(MessageParser<T> parser, List<T> value) where T : IM
serializedProtoVector.Dispose();
}

public void GetDetectionList(List<Detection> detections)
{
UnsafeNativeMethods.mp_Packet__GetVectorOfProtoMessageLite(mpPtr, out var serializedProtoVector).Assert();

GC.KeepAlive(this);

foreach (var detection in detections)
{
detection.Clear();
}
var size = serializedProtoVector.WriteTo(Detection.Parser, detections);
serializedProtoVector.Dispose();

detections.RemoveRange(size, detections.Count - size);
}

/// <summary>
/// Validate if the content of the <see cref="Packet"/> is a boolean.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,36 @@ protected override void DeleteMpPtr()
return Packet<TValue>.Create<TPacket>(packetPtr, true);
}

/// <remarks>
/// This method cannot verify that the packet type corresponding to the <paramref name="key" /> is indeed a <typeparamref name="TPacket" />,
/// so you must make sure by youreself that it is.
/// </remarks>
public Packet At(string key)
{
UnsafeNativeMethods.mp_PacketMap__find__PKc(mpPtr, key, out var packetPtr).Assert();

if (packetPtr == IntPtr.Zero)
{
return default; // null
}
GC.KeepAlive(this);
return new Packet(packetPtr, true);
}

public void Emplace<T>(string key, Packet<T> packet)
{
UnsafeNativeMethods.mp_PacketMap__emplace__PKc_Rp(mpPtr, key, packet.mpPtr).Assert();
packet.Dispose(); // respect move semantics
GC.KeepAlive(this);
}

public void Emplace(string key, Packet packet)
{
UnsafeNativeMethods.mp_PacketMap__emplace__PKc_Rp(mpPtr, key, packet.mpPtr).Assert();
packet.Dispose(); // respect move semantics
GC.KeepAlive(this);
}

public int Erase(string key)
{
UnsafeNativeMethods.mp_PacketMap__erase__PKc(mpPtr, key, out var count).Assert();
Expand Down
Loading

0 comments on commit deeb9f0

Please sign in to comment.