Skip to content

Commit

Permalink
feat: implement Packet.CreateImage, Packet.GetImage (#1072)
Browse files Browse the repository at this point in the history
  • Loading branch information
homuler committed Dec 17, 2023
1 parent 5e0449d commit 6e14fda
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,31 @@ public static Packet CreateFloatVectorAt(float[] value, long timestampMicrosec)
return new Packet(ptr, true);
}

/// <summary>
/// Create an Image Packet.
/// </summary>
public static Packet CreateImage(Image value)
{
UnsafeNativeMethods.mp__MakeImagePacket__PI(value.mpPtr, out var ptr).Assert();
value.Dispose(); // respect move semantics

return new Packet(ptr, true);
}

/// <summary>
/// Create an Image Packet.
/// </summary>
/// <param name="timestampMicrosec">
/// The timestamp of the packet.
/// </param>
public static Packet CreateImageAt(Image value, long timestampMicrosec)
{
UnsafeNativeMethods.mp__MakeImagePacket_At__PI_ll(value.mpPtr, timestampMicrosec, out var ptr).Assert();
value.Dispose(); // respect move semantics

return new Packet(ptr, true);
}

/// <summary>
/// Get the content of the <see cref="Packet"/> as a boolean.
/// </summary>
Expand Down Expand Up @@ -301,6 +326,12 @@ public void GetFloatArray(float[] value)
/// <summary>
/// Get the content of a float vector Packet as a <see cref="List{float}"/>.
/// </summary>
/// <remarks>
/// On some platforms (e.g. Windows), it will abort the process when <see cref="MediaPipeException"/> should be thrown.
/// </remarks>
/// <exception cref="MediaPipeException">
/// If the <see cref="Packet"/> doesn't contain std::vector&lt;float&gt; data.
/// </exception>
public List<float> GetFloatList()
{
var value = new List<float>();
Expand Down Expand Up @@ -330,6 +361,23 @@ public void GetFloatList(List<float> value)
structArray.Dispose();
}

/// <summary>
/// Get the content of the <see cref="Packet"/> as an <see cref="Image"/>.
/// </summary>
/// <remarks>
/// On some platforms (e.g. Windows), it will abort the process when <see cref="MediaPipeException"/> should be thrown.
/// </remarks>
/// <exception cref="MediaPipeException">
/// If the <see cref="Packet"/> doesn't contain <see cref="Image"/>.
/// </exception>
public Image GetImage()
{
UnsafeNativeMethods.mp_Packet__GetImage(mpPtr, out var ptr).Assert();

GC.KeepAlive(this);
return new Image(ptr, false);
}

/// <summary>
/// Validate if the content of the <see cref="Packet"/> is a boolean.
/// </summary>
Expand Down Expand Up @@ -413,5 +461,19 @@ public void ValidateAsFloatVector()
GC.KeepAlive(this);
AssertStatusOk(statusPtr);
}

/// <summary>
/// Validate if the content of the <see cref="Packet"/> is an <see cref="Image"/> .
/// </summary>
/// <exception cref="BadStatusException">
/// If the <see cref="Packet"/> doesn't contain <see cref="Image"/> .
/// </exception>
public void ValidateAsImage()
{
UnsafeNativeMethods.mp_Packet__ValidateAsImage(mpPtr, out var statusPtr).Assert();

GC.KeepAlive(this);
AssertStatusOk(statusPtr);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ public class ImagePacket : Packet<Image>

public ImagePacket(Image image) : base()
{
UnsafeNativeMethods.mp__MakeImagePacket__Pif(image.mpPtr, out var ptr).Assert();
UnsafeNativeMethods.mp__MakeImagePacket__PI(image.mpPtr, out var ptr).Assert();
image.Dispose(); // respect move semantics

this.ptr = ptr;
}

public ImagePacket(Image image, Timestamp timestamp) : base()
{
UnsafeNativeMethods.mp__MakeImagePacket_At__Pif_Rt(image.mpPtr, timestamp.mpPtr, out var ptr).Assert();
UnsafeNativeMethods.mp__MakeImagePacket_At__PI_Rt(image.mpPtr, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
image.Dispose(); // respect move semantics

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,13 @@ internal static partial class UnsafeNativeMethods

#region Packet
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeImagePacket__Pif(IntPtr image, out IntPtr packet);
public static extern MpReturnCode mp__MakeImagePacket__PI(IntPtr image, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeImagePacket_At__Pif_Rt(IntPtr image, IntPtr timestamp, out IntPtr packet);
public static extern MpReturnCode mp__MakeImagePacket_At__PI_Rt(IntPtr image, IntPtr timestamp, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeImagePacket_At__PI_ll(IntPtr image, long timestampMicrosec, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ConsumeImage(IntPtr packet, out IntPtr status, out IntPtr image);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System.Linq;
using System.Runtime.InteropServices;
using NUnit.Framework;
using Unity.Collections;

namespace Mediapipe.Tests
{
Expand Down Expand Up @@ -212,6 +215,44 @@ public void CreateFloatVectorAt_ShouldReturnNewFloatListPacket()
}
#endregion

#region Image
[Test]
public void CreateImage_ShouldReturnNewImagePacket()
{
var bytes = Enumerable.Range(0, 32).Select(x => (byte)x).ToArray();
var image = BuildSRGBAImage(bytes, 4, 2);
using var packet = Packet.CreateImage(image);

Assert.DoesNotThrow(packet.ValidateAsImage);

using (var result = packet.GetImage())
{
AssertImage(result, 4, 2, ImageFormat.Types.Format.Srgba, bytes);
}

using var unsetTimestamp = Timestamp.Unset();
Assert.AreEqual(unsetTimestamp.Microseconds(), packet.TimestampMicroseconds());
}

[Test]
public void CreateImageAt_ShouldReturnNewImagePacket()
{
var bytes = Enumerable.Range(0, 32).Select(x => (byte)x).ToArray();
var timestamp = 1;
var image = BuildSRGBAImage(bytes, 4, 2);
using var packet = Packet.CreateImageAt(image, timestamp);

Assert.DoesNotThrow(packet.ValidateAsImage);

using (var result = packet.GetImage())
{
AssertImage(result, 4, 2, ImageFormat.Types.Format.Srgba, bytes);
}

Assert.AreEqual(timestamp, packet.TimestampMicroseconds());
}
#endregion

#region #Validate
[Test]
public void ValidateAsBool_ShouldThrow_When_ValueIsNotSet()
Expand Down Expand Up @@ -247,6 +288,45 @@ public void ValidateAsFloatArray_ShouldThrow_When_ValueIsNotSet()
using var packet = Packet.CreateEmpty();
_ = Assert.Throws<BadStatusException>(packet.ValidateAsFloatArray);
}

[Test]
public void ValidateAsFloatVector_ShouldThrow_When_ValueIsNotSet()
{
using var packet = Packet.CreateEmpty();
_ = Assert.Throws<BadStatusException>(packet.ValidateAsFloatVector);
}

[Test]
public void ValidateAsImage_ShouldThrow_When_ValueIsNotSet()
{
using var packet = Packet.CreateEmpty();
_ = Assert.Throws<BadStatusException>(packet.ValidateAsImage);
}
#endregion

private Image BuildSRGBAImage(byte[] bytes, int width, int height)
{
Assert.AreEqual(bytes.Length, width * height * 4);

var pixelData = new NativeArray<byte>(bytes.Length, Allocator.Temp, NativeArrayOptions.UninitializedMemory);
pixelData.CopyFrom(bytes);

return new Image(ImageFormat.Types.Format.Srgba, width, height, width * 4, pixelData);
}

private void AssertImage(Image image, int width, int height, ImageFormat.Types.Format format, byte[] expectedBytes)
{
Assert.AreEqual(width, image.Width());
Assert.AreEqual(height, image.Height());
Assert.AreEqual(format, image.ImageFormat());

using (var pixelLock = new PixelWriteLock(image))
{
var pixelData = new byte[width * height * ImageFrame.NumberOfChannelsForFormat(format)];
Marshal.Copy(pixelLock.Pixels(), pixelData, 0, pixelData.Length);

Assert.AreEqual(expectedBytes, pixelData);
}
}
}
}
11 changes: 9 additions & 2 deletions mediapipe_api/framework/formats/image.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,27 @@ uint8* mp_PixelWriteLock__Pixels(mediapipe::PixelWriteLock* pixel_read_lock) {
}

// Packet API
MpReturnCode mp__MakeImagePacket__Pif(mediapipe::Image* image, mediapipe::Packet** packet_out) {
MpReturnCode mp__MakeImagePacket__PI(mediapipe::Image* image, mediapipe::Packet** packet_out) {
TRY_ALL
*packet_out = new mediapipe::Packet{mediapipe::MakePacket<mediapipe::Image>(std::move(*image))};
RETURN_CODE(MpReturnCode::Success);
CATCH_ALL
}

MpReturnCode mp__MakeImagePacket_At__Pif_Rt(mediapipe::Image* image, mediapipe::Timestamp* timestamp, mediapipe::Packet** packet_out) {
MpReturnCode mp__MakeImagePacket_At__PI_Rt(mediapipe::Image* image, mediapipe::Timestamp* timestamp, mediapipe::Packet** packet_out) {
TRY_ALL
*packet_out = new mediapipe::Packet{mediapipe::MakePacket<mediapipe::Image>(std::move(*image)).At(*timestamp)};
RETURN_CODE(MpReturnCode::Success);
CATCH_ALL
}

MpReturnCode mp__MakeImagePacket_At__PI_ll(mediapipe::Image* image, int64 timestampMicrosec, mediapipe::Packet** packet_out) {
TRY_ALL
*packet_out = new mediapipe::Packet{mediapipe::MakePacket<mediapipe::Image>(std::move(*image)).At(mediapipe::Timestamp(timestampMicrosec))};
RETURN_CODE(MpReturnCode::Success);
CATCH_ALL
}

MpReturnCode mp_Packet__ConsumeImage(mediapipe::Packet* packet, absl::Status **status_out, mediapipe::Image** value_out) {
return mp_Packet__Consume(packet, status_out, value_out);
}
Expand Down
5 changes: 3 additions & 2 deletions mediapipe_api/framework/formats/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ MP_CAPI(void) mp_PixelWriteLock__delete(mediapipe::PixelWriteLock* pixel_Write_l
MP_CAPI(uint8*) mp_PixelWriteLock__Pixels(mediapipe::PixelWriteLock* pixel_read_lock);

// Packet API
MP_CAPI(MpReturnCode) mp__MakeImagePacket__Pif(mediapipe::Image* image, mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp__MakeImagePacket_At__Pif_Rt(mediapipe::Image* image, mediapipe::Timestamp* timestamp, mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp__MakeImagePacket__PI(mediapipe::Image* image, mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp__MakeImagePacket_At__PI_Rt(mediapipe::Image* image, mediapipe::Timestamp* timestamp, mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp__MakeImagePacket_At__PI_ll(mediapipe::Image* image, int64 timestampMicrosec, mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp_Packet__ConsumeImage(mediapipe::Packet* packet, absl::Status **status_out, mediapipe::Image** value_out);
MP_CAPI(MpReturnCode) mp_Packet__GetImage(mediapipe::Packet* packet, const mediapipe::Image** value_out);
MP_CAPI(MpReturnCode) mp_Packet__GetImageVector(mediapipe::Packet* packet, mp_api::StructArray<mediapipe::Image*>* value_out);
Expand Down

0 comments on commit 6e14fda

Please sign in to comment.