Skip to content

Commit

Permalink
feat: StringPacket can contain null bytes
Browse files Browse the repository at this point in the history
- StringPacket constructor receives a string and its length
- *_Rtimestamp -> *_Rt, *_Rpacket -> *_Rp
  • Loading branch information
homuler committed Feb 21, 2021
1 parent 0a4720b commit 1d5240e
Show file tree
Hide file tree
Showing 17 changed files with 149 additions and 49 deletions.
3 changes: 2 additions & 1 deletion Assets/MediaPipe/SDK/Scripts/Core/MpResourceHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ public abstract class MpResourceHandle : DisposableObject, IMpResourceHandle {
protected delegate MpReturnCode StringOutFunc(IntPtr ptr, out IntPtr strPtr);
protected string MarshalStringFromNative(StringOutFunc f) {
f(mpPtr, out var strPtr).Assert();
GC.KeepAlive(this);

var str = Marshal.PtrToStringAnsi(strPtr);
UnsafeNativeMethods.delete_array__PKc(strPtr);

GC.KeepAlive(this);
return str;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Runtime.InteropServices;
using Unity.Collections;
using Unity.Collections.LowLevel.Unsafe;
using UnityEngine;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class BoolPacket : Packet<bool> {
}

public BoolPacket(bool value, Timestamp timestamp) : base() {
UnsafeNativeMethods.mp__MakeBoolPacket_At__b_Rtimestamp(value, timestamp.mpPtr, out var ptr).Assert();
UnsafeNativeMethods.mp__MakeBoolPacket_At__b_Rt(value, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
this.ptr = ptr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class FloatArrayPacket : Packet<float[]> {
}

public FloatArrayPacket(float[] value, Timestamp timestamp) : base() {
UnsafeNativeMethods.mp__MakeFloatArrayPacket_At__Pf_i_Rtimestamp(value, value.Length, timestamp.mpPtr, out var ptr).Assert();
UnsafeNativeMethods.mp__MakeFloatArrayPacket_At__Pf_i_Rt(value, value.Length, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
this.ptr = ptr;
Length = value.Length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class FloatPacket : Packet<float> {
}

public FloatPacket(float value, Timestamp timestamp) : base() {
UnsafeNativeMethods.mp__MakeFloatPacket_At__f_Rtimestamp(value, timestamp.mpPtr, out var ptr).Assert();
UnsafeNativeMethods.mp__MakeFloatPacket_At__f_Rt(value, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
this.ptr = ptr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class ImageFramePacket : Packet<ImageFrame> {
}

public ImageFramePacket(ImageFrame imageFrame, Timestamp timestamp) : base() {
UnsafeNativeMethods.mp__MakeImageFramePacket_At__Pif_Rtimestamp(imageFrame.mpPtr, timestamp.mpPtr, out var ptr).Assert();
UnsafeNativeMethods.mp__MakeImageFramePacket_At__Pif_Rt(imageFrame.mpPtr, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
imageFrame.Dispose(); // respect move semantics

Expand Down
2 changes: 1 addition & 1 deletion Assets/MediaPipe/SDK/Scripts/Framework/Packet/IntPacket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class IntPacket : Packet<int> {
}

public IntPacket(int value, Timestamp timestamp) : base() {
UnsafeNativeMethods.mp__MakeIntPacket_At__i_Rtimestamp(value, timestamp.mpPtr, out var ptr).Assert();
UnsafeNativeMethods.mp__MakeIntPacket_At__i_Rt(value, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
this.ptr = ptr;
}
Expand Down
2 changes: 1 addition & 1 deletion Assets/MediaPipe/SDK/Scripts/Framework/Packet/Packet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public abstract class Packet<T> : MpResourceHandle {
/// <remarks>To avoid copying the value, instantiate the packet with timestamp</remarks>
/// <returns>New packet with the given timestamp and the copied value</returns>
public Packet<T> At(Timestamp timestamp) {
UnsafeNativeMethods.mp_Packet__At__Rtimestamp(mpPtr, timestamp.mpPtr, out var packetPtr).Assert();
UnsafeNativeMethods.mp_Packet__At__Rt(mpPtr, timestamp.mpPtr, out var packetPtr).Assert();

GC.KeepAlive(timestamp);
return (Packet<T>)Activator.CreateInstance(this.GetType(), packetPtr, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class SidePacket : MpResourceHandle {
}

public void Emplace<T>(string key, Packet<T> packet) {
UnsafeNativeMethods.mp_SidePacket__emplace__PKc_Rpacket(mpPtr, key, packet.mpPtr).Assert();
UnsafeNativeMethods.mp_SidePacket__emplace__PKc_Rp(mpPtr, key, packet.mpPtr).Assert();
packet.Dispose(); // respect move semantics
GC.KeepAlive(this);
}
Expand Down
25 changes: 24 additions & 1 deletion Assets/MediaPipe/SDK/Scripts/Framework/Packet/StringPacket.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Runtime.InteropServices;

namespace Mediapipe {
public class StringPacket : Packet<string> {
Expand All @@ -11,8 +12,19 @@ public class StringPacket : Packet<string> {
this.ptr = ptr;
}

public StringPacket(byte[] bytes) : base() {
UnsafeNativeMethods.mp__MakeStringPacket__PKc_i(bytes, bytes.Length, out var ptr).Assert();
this.ptr = ptr;
}

public StringPacket(string value, Timestamp timestamp) : base() {
UnsafeNativeMethods.mp__MakeStringPacket_At__PKc_Rtimestamp(value, timestamp.mpPtr, out var ptr).Assert();
UnsafeNativeMethods.mp__MakeStringPacket_At__PKc_Rt(value, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
this.ptr = ptr;
}

public StringPacket(byte[] bytes, Timestamp timestamp) : base() {
UnsafeNativeMethods.mp__MakeStringPacket_At__PKc_i_Rt(bytes, bytes.Length, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
this.ptr = ptr;
}
Expand All @@ -21,6 +33,17 @@ public class StringPacket : Packet<string> {
return MarshalStringFromNative(UnsafeNativeMethods.mp_Packet__GetString);
}

public byte[] GetByteArray() {
UnsafeNativeMethods.mp_Packet__GetByteString(mpPtr, out var strPtr, out int size);
GC.KeepAlive(this);

var bytes = new byte[size];
Marshal.Copy(strPtr, bytes, 0, size);
UnsafeNativeMethods.delete_array__PKc(strPtr);

return bytes;
}

public override StatusOr<string> Consume() {
throw new NotSupportedException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ internal static partial class UnsafeNativeMethods {
public static extern MpReturnCode mp__MakeImageFramePacket__Pif(IntPtr imageFrame, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeImageFramePacket_At__Pif_Rtimestamp(IntPtr imageFrame, IntPtr timestamp, out IntPtr packet);
public static extern MpReturnCode mp__MakeImageFramePacket_At__Pif_Rt(IntPtr imageFrame, IntPtr timestamp, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ConsumeImageFrame(IntPtr packet, out IntPtr statusOrImageFrame);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ internal static partial class UnsafeNativeMethods {
public static extern void mp_Packet__delete(IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__At__Rtimestamp(IntPtr packet, IntPtr timestamp, out IntPtr newPacket);
public static extern MpReturnCode mp_Packet__At__Rt(IntPtr packet, IntPtr timestamp, out IntPtr newPacket);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ValidateAsProtoMessageLite(IntPtr packet, out IntPtr status);
Expand All @@ -34,7 +34,7 @@ internal static partial class UnsafeNativeMethods {
public static extern MpReturnCode mp__MakeBoolPacket__b([MarshalAs(UnmanagedType.I1)] bool value, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeBoolPacket_At__b_Rtimestamp([MarshalAs(UnmanagedType.I1)] bool value, IntPtr timestamp, out IntPtr packet);
public static extern MpReturnCode mp__MakeBoolPacket_At__b_Rt([MarshalAs(UnmanagedType.I1)] bool value, IntPtr timestamp, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetBool(IntPtr packet, [MarshalAs(UnmanagedType.I1)]out bool value);
Expand All @@ -48,7 +48,7 @@ internal static partial class UnsafeNativeMethods {
public static extern MpReturnCode mp__MakeFloatPacket__f(float value, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeFloatPacket_At__f_Rtimestamp(float value, IntPtr timestamp, out IntPtr packet);
public static extern MpReturnCode mp__MakeFloatPacket_At__f_Rt(float value, IntPtr timestamp, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetFloat(IntPtr packet, out float value);
Expand All @@ -62,7 +62,7 @@ internal static partial class UnsafeNativeMethods {
public static extern MpReturnCode mp__MakeIntPacket__i(int value, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeIntPacket_At__i_Rtimestamp(int value, IntPtr timestamp, out IntPtr packet);
public static extern MpReturnCode mp__MakeIntPacket_At__i_Rt(int value, IntPtr timestamp, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetInt(IntPtr packet, out int value);
Expand All @@ -76,7 +76,7 @@ internal static partial class UnsafeNativeMethods {
public static extern MpReturnCode mp__MakeFloatArrayPacket__Pf_i(float[] value, int size, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeFloatArrayPacket_At__Pf_i_Rtimestamp(float[] value, int size, IntPtr timestamp, out IntPtr packet);
public static extern MpReturnCode mp__MakeFloatArrayPacket_At__Pf_i_Rt(float[] value, int size, IntPtr timestamp, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetFloatArray(IntPtr packet, out IntPtr value);
Expand All @@ -90,11 +90,20 @@ internal static partial class UnsafeNativeMethods {
public static extern MpReturnCode mp__MakeStringPacket__PKc(string value, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeStringPacket_At__PKc_Rtimestamp(string value, IntPtr timestamp, out IntPtr packet);
public static extern MpReturnCode mp__MakeStringPacket_At__PKc_Rt(string value, IntPtr timestamp, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeStringPacket__PKc_i(byte[] bytes, int size, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeStringPacket_At__PKc_i_Rt(byte[] bytes, int size, IntPtr timestamp, out IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetString(IntPtr packet, out IntPtr value);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetByteString(IntPtr packet, out IntPtr value, out int size);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ValidateAsString(IntPtr packet, out IntPtr status);
#endregion
Expand All @@ -107,7 +116,7 @@ internal static partial class UnsafeNativeMethods {
public static extern void mp_SidePacket__delete(IntPtr sidePacket);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_SidePacket__emplace__PKc_Rpacket(IntPtr sidePacket, string key, IntPtr packet);
public static extern MpReturnCode mp_SidePacket__emplace__PKc_Rp(IntPtr sidePacket, string key, IntPtr packet);

[DllImport (MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_SidePacket__at__PKc(IntPtr sidePacket, string key, out IntPtr packet);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class StringPacketTest {
}

[Test]
public void Ctor_ShouldInstantiatePacket_When_CalledWithValue() {
public void Ctor_ShouldInstantiatePacket_When_CalledWithString() {
var packet = new StringPacket("test");

Assert.True(packet.ValidateAsType().ok);
Expand All @@ -25,14 +25,35 @@ public class StringPacketTest {
}

[Test]
public void Ctor_ShouldInstantiatePacket_When_CalledWithValueAndTimestamp() {
public void Ctor_ShouldInstantiatePacket_When_CalledWithByteArray() {
byte[] bytes = new byte[] { (byte)'t', (byte)'e', (byte)'s', (byte)'t' };
var packet = new StringPacket(bytes);

Assert.True(packet.ValidateAsType().ok);
Assert.AreEqual(packet.Get(), "test");
Assert.AreEqual(packet.Timestamp(), Timestamp.Unset());
}

[Test]
public void Ctor_ShouldInstantiatePacket_When_CalledWithStringAndTimestamp() {
var timestamp = new Timestamp(1);
var packet = new StringPacket("test", timestamp);

Assert.True(packet.ValidateAsType().ok);
Assert.AreEqual(packet.Get(), "test");
Assert.AreEqual(packet.Timestamp(), timestamp);
}

[Test]
public void Ctor_ShouldInstantiatePacket_When_CalledWithByteArrayAndTimestamp() {
var timestamp = new Timestamp(1);
byte[] bytes = new byte[] { (byte)'t', (byte)'e', (byte)'s', (byte)'t' };
var packet = new StringPacket(bytes, timestamp);

Assert.True(packet.ValidateAsType().ok);
Assert.AreEqual(packet.Get(), "test");
Assert.AreEqual(packet.Timestamp(), timestamp);
}
#endregion

#region #isDisposed
Expand All @@ -52,6 +73,17 @@ public class StringPacketTest {
}
#endregion

#region #GetByteArray
[Test]
public void GetByteArray_ShouldReturnByteArray() {
byte[] bytes = new byte[] { (byte)'a', (byte)'b', 0, (byte)'c' };
var packet = new StringPacket(bytes);

Assert.AreEqual(packet.GetByteArray(), bytes);
Assert.AreEqual(packet.Get(), "ab");
}
#endregion

#region #Consume
[Test]
public void Consume_ShouldThrowNotSupportedException() {
Expand Down
6 changes: 3 additions & 3 deletions C/mediapipe_api/framework/formats/image_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ MpReturnCode mp__MakeImageFramePacket__Pif(mediapipe::ImageFrame* image_frame, m
} CATCH_EXCEPTION
}

MpReturnCode mp__MakeImageFramePacket_At__Pif_Rtimestamp(mediapipe::ImageFrame* image_frame,
mediapipe::Timestamp* timestamp,
mediapipe::Packet** packet_out) {
MpReturnCode mp__MakeImageFramePacket_At__Pif_Rt(mediapipe::ImageFrame* image_frame,
mediapipe::Timestamp* timestamp,
mediapipe::Packet** packet_out) {
TRY {
*packet_out = new mediapipe::Packet { mediapipe::MakePacket<mediapipe::ImageFrame>(std::move(*image_frame)).At(*timestamp) };
RETURN_CODE(MpReturnCode::Success);
Expand Down
6 changes: 3 additions & 3 deletions C/mediapipe_api/framework/formats/image_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ MP_CAPI(MpReturnCode) mp_StatusOrImageFrame__ConsumeValueOrDie(StatusOrImageFram

// Packet API
MP_CAPI(MpReturnCode) mp__MakeImageFramePacket__Pif(mediapipe::ImageFrame* image_frame, mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp__MakeImageFramePacket_At__Pif_Rtimestamp(mediapipe::ImageFrame* image_frame,
mediapipe::Timestamp* timestamp,
mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp__MakeImageFramePacket_At__Pif_Rt(mediapipe::ImageFrame* image_frame,
mediapipe::Timestamp* timestamp,
mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp_Packet__ConsumeImageFrame(mediapipe::Packet* packet, StatusOrImageFrame** value_out);
MP_CAPI(MpReturnCode) mp_Packet__GetImageFrame(mediapipe::Packet* packet, const mediapipe::ImageFrame** value_out);
MP_CAPI(MpReturnCode) mp_Packet__ValidateAsImageFrame(mediapipe::Packet* packet, mediapipe::Status** status_out);
Expand Down
Loading

0 comments on commit 1d5240e

Please sign in to comment.