Skip to content

Commit

Permalink
feat: implement Packet.CreateProto, Packet.GetProto (#1079)
Browse files Browse the repository at this point in the history
  • Loading branch information
homuler committed Dec 23, 2023
1 parent d495193 commit df48262
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using Google.Protobuf;

namespace Mediapipe
{
Expand Down Expand Up @@ -256,6 +257,33 @@ public static Packet CreateIntAt(int value, long timestampMicrosec)
return new Packet(ptr, true);
}

/// <summary>
/// Create a MediaPipe protobuf message Packet.
/// </summary>
public static Packet CreateProto<T>(T value) where T : IMessage<T>
{
var arr = value.ToByteArray();
UnsafeNativeMethods.mp__PacketFromDynamicProto__PKc_PKc_i(value.Descriptor.FullName, arr, arr.Length, out var statusPtr, out var ptr).Assert();

AssertStatusOk(statusPtr);
return new Packet(ptr, true);
}

/// <summary>
/// Create a MediaPipe protobuf message Packet.
/// </summary>
/// <param name="timestampMicrosec">
/// The timestamp of the packet.
/// </param>
public static Packet CreateProtoAt<T>(T value, long timestampMicrosec) where T : IMessage<T>
{
var arr = value.ToByteArray();
UnsafeNativeMethods.mp__PacketFromDynamicProto_At__PKc_PKc_i_ll(value.Descriptor.FullName, arr, arr.Length, timestampMicrosec, out var statusPtr, out var ptr).Assert();

AssertStatusOk(statusPtr);
return new Packet(ptr, true);
}

/// <summary>
/// Get the content of the <see cref="Packet"/> as a boolean.
/// </summary>
Expand Down Expand Up @@ -460,6 +488,27 @@ public int GetInt()
return value;
}

/// <summary>
/// Get the content of the <see cref="Packet"/> as a proto message.
/// </summary>
/// <remarks>
/// On some platforms (e.g. Windows), it will abort the process when <see cref="MediaPipeException"/> should be thrown.
/// </remarks>
/// <exception cref="MediaPipeException">
/// If the <see cref="Packet"/> doesn't contain proto messages.
/// </exception>
public T GetProto<T>(MessageParser<T> parser) where T : IMessage<T>
{
UnsafeNativeMethods.mp_Packet__GetProto(mpPtr, out var value).Assert();

GC.KeepAlive(this);

var proto = value.Deserialize(parser);
value.Dispose();

return proto;
}

/// <summary>
/// Validate if the content of the <see cref="Packet"/> is a boolean.
/// </summary>
Expand Down Expand Up @@ -585,5 +634,19 @@ public void ValidateAsInt()
GC.KeepAlive(this);
AssertStatusOk(statusPtr);
}

/// <summary>
/// Validate if the content of the <see cref="Packet"/> is a proto message.
/// </summary>
/// <exception cref="BadStatusException">
/// If the <see cref="Packet"/> doesn't contain proto messages.
/// </exception>
public void ValidateAsProtoMessageLite()
{
UnsafeNativeMethods.mp_Packet__ValidateAsProtoMessageLite(mpPtr, out var statusPtr).Assert();

GC.KeepAlive(this);
AssertStatusOk(statusPtr);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ internal static partial class UnsafeNativeMethods
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__At__Rt(IntPtr packet, IntPtr timestamp, out IntPtr newPacket);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ValidateAsProtoMessageLite(IntPtr packet, out IntPtr status);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__Timestamp(IntPtr packet, out IntPtr timestamp);

Expand Down Expand Up @@ -179,6 +176,22 @@ internal static partial class UnsafeNativeMethods
public static extern MpReturnCode mp_Packet__ValidateAsString(IntPtr packet, out IntPtr status);
#endregion

#region Proto
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__PacketFromDynamicProto__PKc_PKc_i(string typeName, byte[] proto, int size,
out IntPtr status, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__PacketFromDynamicProto_At__PKc_PKc_i_ll(string typeName, byte[] proto, int size, long timestampMicrosec,
out IntPtr status, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetProto(IntPtr packet, out SerializedProto value);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ValidateAsProtoMessageLite(IntPtr packet, out IntPtr status);
#endregion

#region PacketMap
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_PacketMap__(out IntPtr packetMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,47 @@ public void CreateIntAt_ShouldReturnNewIntPacket(int value)
}
#endregion

#region Proto
[Test]
public void CreateProto_ShouldReturnNewProtoPacket()
{
var value = new NormalizedRect()
{
Rotation = 0,
XCenter = 0.5f,
YCenter = 0.5f,
Width = 1,
Height = 1,
};
using var packet = Packet.CreateProto(value);

Assert.DoesNotThrow(packet.ValidateAsProtoMessageLite);
Assert.AreEqual(value, packet.GetProto(NormalizedRect.Parser));

using var unsetTimestamp = Timestamp.Unset();
Assert.AreEqual(unsetTimestamp.Microseconds(), packet.TimestampMicroseconds());
}

[Test]
public void CreateProtoAt_ShouldReturnNewProtoPacket()
{
var timestamp = 1;
var value = new NormalizedRect()
{
Rotation = 0,
XCenter = 0.5f,
YCenter = 0.5f,
Width = 1,
Height = 1,
};
using var packet = Packet.CreateProtoAt(value, timestamp);

Assert.DoesNotThrow(packet.ValidateAsProtoMessageLite);
Assert.AreEqual(value, packet.GetProto(NormalizedRect.Parser));
Assert.AreEqual(timestamp, packet.TimestampMicroseconds());
}
#endregion

#region #Validate
[Test]
public void ValidateAsBool_ShouldThrow_When_ValueIsNotSet()
Expand Down
51 changes: 44 additions & 7 deletions mediapipe_api/framework/packet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,6 @@ MpReturnCode mp_Packet__At__Rt(mediapipe::Packet* packet, mediapipe::Timestamp*

bool mp_Packet__IsEmpty(mediapipe::Packet* packet) { return packet->IsEmpty(); }

MpReturnCode mp_Packet__ValidateAsProtoMessageLite(mediapipe::Packet* packet, absl::Status** status_out) {
TRY
*status_out = new absl::Status{packet->ValidateAsProtoMessageLite()};
RETURN_CODE(MpReturnCode::Success);
CATCH_EXCEPTION
}

MpReturnCode mp_Packet__Timestamp(mediapipe::Packet* packet, mediapipe::Timestamp** timestamp_out) {
TRY
*timestamp_out = new mediapipe::Timestamp{packet->Timestamp()};
Expand Down Expand Up @@ -378,6 +371,50 @@ MpReturnCode mp_Packet__ValidateAsString(mediapipe::Packet* packet, absl::Status
CATCH_EXCEPTION
}

MpReturnCode mp__PacketFromDynamicProto__PKc_PKc_i(const char* type_name, const char* serialized_proto, int size,
absl::Status** status_out, mediapipe::Packet** packet_out) {
TRY_ALL
auto status_or_packet = mediapipe::packet_internal::PacketFromDynamicProto(type_name, std::string(serialized_proto, size));
*status_out = new absl::Status{status_or_packet.status()};
if (!status_or_packet.ok()) {
*packet_out = nullptr;
} else {
*packet_out = new mediapipe::Packet{status_or_packet.value()};
}
RETURN_CODE(MpReturnCode::Success);
CATCH_ALL
}

MpReturnCode mp__PacketFromDynamicProto_At__PKc_PKc_i_ll(const char* type_name, const char* serialized_proto, int size,
int64 timestampMicrosec,
absl::Status** status_out, mediapipe::Packet** packet_out) {
TRY_ALL
auto status_or_packet = mediapipe::packet_internal::PacketFromDynamicProto(type_name, std::string(serialized_proto, size));
*status_out = new absl::Status{status_or_packet.status()};
if (!status_or_packet.ok()) {
*packet_out = nullptr;
} else {
*packet_out = new mediapipe::Packet{status_or_packet.value().At(mediapipe::Timestamp(timestampMicrosec))};
}
RETURN_CODE(MpReturnCode::Success);
CATCH_ALL
}

MpReturnCode mp_Packet__GetProto(mediapipe::Packet* packet, mp_api::SerializedProto* serialized_proto) {
TRY_ALL
const auto& proto = packet->GetProtoMessageLite();
SerializeProto(proto, serialized_proto);
RETURN_CODE(MpReturnCode::Success);
CATCH_ALL
}

MpReturnCode mp_Packet__ValidateAsProtoMessageLite(mediapipe::Packet* packet, absl::Status** status_out) {
TRY
*status_out = new absl::Status{packet->ValidateAsProtoMessageLite()};
RETURN_CODE(MpReturnCode::Success);
CATCH_EXCEPTION
}

/** PacketMap */
MpReturnCode mp_PacketMap__(PacketMap** packet_map_out) {
TRY
Expand Down
10 changes: 9 additions & 1 deletion mediapipe_api/framework/packet.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ MP_CAPI(MpReturnCode) mp_Packet__(mediapipe::Packet** packet_out);
MP_CAPI(void) mp_Packet__delete(mediapipe::Packet* packet);
MP_CAPI(MpReturnCode) mp_Packet__At__Rt(mediapipe::Packet* packet, mediapipe::Timestamp* timestamp, mediapipe::Packet** packet_out);
MP_CAPI(bool) mp_Packet__IsEmpty(mediapipe::Packet* packet);
MP_CAPI(MpReturnCode) mp_Packet__ValidateAsProtoMessageLite(mediapipe::Packet* packet, absl::Status** status_out);
MP_CAPI(MpReturnCode) mp_Packet__Timestamp(mediapipe::Packet* packet, mediapipe::Timestamp** timestamp_out);
MP_CAPI(int64) mp_Packet__TimestampMicroseconds(mediapipe::Packet* packet);
MP_CAPI(MpReturnCode) mp_Packet__DebugString(mediapipe::Packet* packet, const char** str_out);
Expand Down Expand Up @@ -103,6 +102,15 @@ MP_CAPI(MpReturnCode) mp_Packet__ConsumeString(mediapipe::Packet* packet, absl::
MP_CAPI(MpReturnCode) mp_Packet__ConsumeByteString(mediapipe::Packet* packet, absl::Status** status_out, const char** value_out, int* size_out);
MP_CAPI(MpReturnCode) mp_Packet__ValidateAsString(mediapipe::Packet* packet, absl::Status** status_out);

// proto
MP_CAPI(MpReturnCode) mp__PacketFromDynamicProto__PKc_PKc_i(const char* type_name, const char* serialized_proto, int size,
absl::Status** status_out, mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp__PacketFromDynamicProto_At__PKc_PKc_i_ll(const char* type_name, const char* serialized_proto, int size,
int64 timestampMicrosec,
absl::Status** status_out, mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp_Packet__GetProto(mediapipe::Packet* packet, mp_api::SerializedProto* serialized_proto);
MP_CAPI(MpReturnCode) mp_Packet__ValidateAsProtoMessageLite(mediapipe::Packet* packet, absl::Status** status_out);

/** PacketMap API */
MP_CAPI(MpReturnCode) mp_PacketMap__(PacketMap** packet_map_out);
MP_CAPI(void) mp_PacketMap__delete(PacketMap* packet_map);
Expand Down

0 comments on commit df48262

Please sign in to comment.