diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Framework/Packet.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Framework/Packet.cs index 4261ffba0..776e70442 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Framework/Packet.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/Framework/Packet.cs @@ -7,6 +7,7 @@ using System; using System.Collections.Generic; using System.Runtime.InteropServices; +using Google.Protobuf; namespace Mediapipe { @@ -256,6 +257,33 @@ public static Packet CreateIntAt(int value, long timestampMicrosec) return new Packet(ptr, true); } + /// + /// Create a MediaPipe protobuf message Packet. + /// + public static Packet CreateProto(T value) where T : IMessage + { + var arr = value.ToByteArray(); + UnsafeNativeMethods.mp__PacketFromDynamicProto__PKc_PKc_i(value.Descriptor.FullName, arr, arr.Length, out var statusPtr, out var ptr).Assert(); + + AssertStatusOk(statusPtr); + return new Packet(ptr, true); + } + + /// + /// Create a MediaPipe protobuf message Packet. + /// + /// + /// The timestamp of the packet. + /// + public static Packet CreateProtoAt(T value, long timestampMicrosec) where T : IMessage + { + var arr = value.ToByteArray(); + UnsafeNativeMethods.mp__PacketFromDynamicProto_At__PKc_PKc_i_ll(value.Descriptor.FullName, arr, arr.Length, timestampMicrosec, out var statusPtr, out var ptr).Assert(); + + AssertStatusOk(statusPtr); + return new Packet(ptr, true); + } + /// /// Get the content of the as a boolean. /// @@ -460,6 +488,27 @@ public int GetInt() return value; } + /// + /// Get the content of the as a proto message. + /// + /// + /// On some platforms (e.g. Windows), it will abort the process when should be thrown. + /// + /// + /// If the doesn't contain proto messages. + /// + public T GetProto(MessageParser parser) where T : IMessage + { + UnsafeNativeMethods.mp_Packet__GetProto(mpPtr, out var value).Assert(); + + GC.KeepAlive(this); + + var proto = value.Deserialize(parser); + value.Dispose(); + + return proto; + } + /// /// Validate if the content of the is a boolean. /// @@ -585,5 +634,19 @@ public void ValidateAsInt() GC.KeepAlive(this); AssertStatusOk(statusPtr); } + + /// + /// Validate if the content of the is a proto message. + /// + /// + /// If the doesn't contain proto messages. + /// + public void ValidateAsProtoMessageLite() + { + UnsafeNativeMethods.mp_Packet__ValidateAsProtoMessageLite(mpPtr, out var statusPtr).Assert(); + + GC.KeepAlive(this); + AssertStatusOk(statusPtr); + } } } diff --git a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Framework/Packet_Unsafe.cs b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Framework/Packet_Unsafe.cs index 350eaddda..584d0088a 100644 --- a/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Framework/Packet_Unsafe.cs +++ b/Packages/com.github.homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Framework/Packet_Unsafe.cs @@ -21,9 +21,6 @@ internal static partial class UnsafeNativeMethods [DllImport(MediaPipeLibrary, ExactSpelling = true)] public static extern MpReturnCode mp_Packet__At__Rt(IntPtr packet, IntPtr timestamp, out IntPtr newPacket); - [DllImport(MediaPipeLibrary, ExactSpelling = true)] - public static extern MpReturnCode mp_Packet__ValidateAsProtoMessageLite(IntPtr packet, out IntPtr status); - [DllImport(MediaPipeLibrary, ExactSpelling = true)] public static extern MpReturnCode mp_Packet__Timestamp(IntPtr packet, out IntPtr timestamp); @@ -179,6 +176,22 @@ internal static partial class UnsafeNativeMethods public static extern MpReturnCode mp_Packet__ValidateAsString(IntPtr packet, out IntPtr status); #endregion + #region Proto + [DllImport(MediaPipeLibrary, ExactSpelling = true)] + public static extern MpReturnCode mp__PacketFromDynamicProto__PKc_PKc_i(string typeName, byte[] proto, int size, + out IntPtr status, out IntPtr packet); + + [DllImport(MediaPipeLibrary, ExactSpelling = true)] + public static extern MpReturnCode mp__PacketFromDynamicProto_At__PKc_PKc_i_ll(string typeName, byte[] proto, int size, long timestampMicrosec, + out IntPtr status, out IntPtr packet); + + [DllImport(MediaPipeLibrary, ExactSpelling = true)] + public static extern MpReturnCode mp_Packet__GetProto(IntPtr packet, out SerializedProto value); + + [DllImport(MediaPipeLibrary, ExactSpelling = true)] + public static extern MpReturnCode mp_Packet__ValidateAsProtoMessageLite(IntPtr packet, out IntPtr status); + #endregion + #region PacketMap [DllImport(MediaPipeLibrary, ExactSpelling = true)] public static extern MpReturnCode mp_PacketMap__(out IntPtr packetMap); diff --git a/Packages/com.github.homuler.mediapipe/Tests/EditMode/Framework/PacketTest.cs b/Packages/com.github.homuler.mediapipe/Tests/EditMode/Framework/PacketTest.cs index d12306d90..771f4d21b 100644 --- a/Packages/com.github.homuler.mediapipe/Tests/EditMode/Framework/PacketTest.cs +++ b/Packages/com.github.homuler.mediapipe/Tests/EditMode/Framework/PacketTest.cs @@ -320,6 +320,47 @@ public void CreateIntAt_ShouldReturnNewIntPacket(int value) } #endregion + #region Proto + [Test] + public void CreateProto_ShouldReturnNewProtoPacket() + { + var value = new NormalizedRect() + { + Rotation = 0, + XCenter = 0.5f, + YCenter = 0.5f, + Width = 1, + Height = 1, + }; + using var packet = Packet.CreateProto(value); + + Assert.DoesNotThrow(packet.ValidateAsProtoMessageLite); + Assert.AreEqual(value, packet.GetProto(NormalizedRect.Parser)); + + using var unsetTimestamp = Timestamp.Unset(); + Assert.AreEqual(unsetTimestamp.Microseconds(), packet.TimestampMicroseconds()); + } + + [Test] + public void CreateProtoAt_ShouldReturnNewProtoPacket() + { + var timestamp = 1; + var value = new NormalizedRect() + { + Rotation = 0, + XCenter = 0.5f, + YCenter = 0.5f, + Width = 1, + Height = 1, + }; + using var packet = Packet.CreateProtoAt(value, timestamp); + + Assert.DoesNotThrow(packet.ValidateAsProtoMessageLite); + Assert.AreEqual(value, packet.GetProto(NormalizedRect.Parser)); + Assert.AreEqual(timestamp, packet.TimestampMicroseconds()); + } + #endregion + #region #Validate [Test] public void ValidateAsBool_ShouldThrow_When_ValueIsNotSet() diff --git a/mediapipe_api/framework/packet.cc b/mediapipe_api/framework/packet.cc index e283b6fe5..d812af399 100644 --- a/mediapipe_api/framework/packet.cc +++ b/mediapipe_api/framework/packet.cc @@ -28,13 +28,6 @@ MpReturnCode mp_Packet__At__Rt(mediapipe::Packet* packet, mediapipe::Timestamp* bool mp_Packet__IsEmpty(mediapipe::Packet* packet) { return packet->IsEmpty(); } -MpReturnCode mp_Packet__ValidateAsProtoMessageLite(mediapipe::Packet* packet, absl::Status** status_out) { - TRY - *status_out = new absl::Status{packet->ValidateAsProtoMessageLite()}; - RETURN_CODE(MpReturnCode::Success); - CATCH_EXCEPTION -} - MpReturnCode mp_Packet__Timestamp(mediapipe::Packet* packet, mediapipe::Timestamp** timestamp_out) { TRY *timestamp_out = new mediapipe::Timestamp{packet->Timestamp()}; @@ -378,6 +371,50 @@ MpReturnCode mp_Packet__ValidateAsString(mediapipe::Packet* packet, absl::Status CATCH_EXCEPTION } +MpReturnCode mp__PacketFromDynamicProto__PKc_PKc_i(const char* type_name, const char* serialized_proto, int size, + absl::Status** status_out, mediapipe::Packet** packet_out) { + TRY_ALL + auto status_or_packet = mediapipe::packet_internal::PacketFromDynamicProto(type_name, std::string(serialized_proto, size)); + *status_out = new absl::Status{status_or_packet.status()}; + if (!status_or_packet.ok()) { + *packet_out = nullptr; + } else { + *packet_out = new mediapipe::Packet{status_or_packet.value()}; + } + RETURN_CODE(MpReturnCode::Success); + CATCH_ALL +} + +MpReturnCode mp__PacketFromDynamicProto_At__PKc_PKc_i_ll(const char* type_name, const char* serialized_proto, int size, + int64 timestampMicrosec, + absl::Status** status_out, mediapipe::Packet** packet_out) { + TRY_ALL + auto status_or_packet = mediapipe::packet_internal::PacketFromDynamicProto(type_name, std::string(serialized_proto, size)); + *status_out = new absl::Status{status_or_packet.status()}; + if (!status_or_packet.ok()) { + *packet_out = nullptr; + } else { + *packet_out = new mediapipe::Packet{status_or_packet.value().At(mediapipe::Timestamp(timestampMicrosec))}; + } + RETURN_CODE(MpReturnCode::Success); + CATCH_ALL +} + +MpReturnCode mp_Packet__GetProto(mediapipe::Packet* packet, mp_api::SerializedProto* serialized_proto) { + TRY_ALL + const auto& proto = packet->GetProtoMessageLite(); + SerializeProto(proto, serialized_proto); + RETURN_CODE(MpReturnCode::Success); + CATCH_ALL +} + +MpReturnCode mp_Packet__ValidateAsProtoMessageLite(mediapipe::Packet* packet, absl::Status** status_out) { + TRY + *status_out = new absl::Status{packet->ValidateAsProtoMessageLite()}; + RETURN_CODE(MpReturnCode::Success); + CATCH_EXCEPTION +} + /** PacketMap */ MpReturnCode mp_PacketMap__(PacketMap** packet_map_out) { TRY diff --git a/mediapipe_api/framework/packet.h b/mediapipe_api/framework/packet.h index 8e83c2608..7f4bbc877 100644 --- a/mediapipe_api/framework/packet.h +++ b/mediapipe_api/framework/packet.h @@ -38,7 +38,6 @@ MP_CAPI(MpReturnCode) mp_Packet__(mediapipe::Packet** packet_out); MP_CAPI(void) mp_Packet__delete(mediapipe::Packet* packet); MP_CAPI(MpReturnCode) mp_Packet__At__Rt(mediapipe::Packet* packet, mediapipe::Timestamp* timestamp, mediapipe::Packet** packet_out); MP_CAPI(bool) mp_Packet__IsEmpty(mediapipe::Packet* packet); -MP_CAPI(MpReturnCode) mp_Packet__ValidateAsProtoMessageLite(mediapipe::Packet* packet, absl::Status** status_out); MP_CAPI(MpReturnCode) mp_Packet__Timestamp(mediapipe::Packet* packet, mediapipe::Timestamp** timestamp_out); MP_CAPI(int64) mp_Packet__TimestampMicroseconds(mediapipe::Packet* packet); MP_CAPI(MpReturnCode) mp_Packet__DebugString(mediapipe::Packet* packet, const char** str_out); @@ -103,6 +102,15 @@ MP_CAPI(MpReturnCode) mp_Packet__ConsumeString(mediapipe::Packet* packet, absl:: MP_CAPI(MpReturnCode) mp_Packet__ConsumeByteString(mediapipe::Packet* packet, absl::Status** status_out, const char** value_out, int* size_out); MP_CAPI(MpReturnCode) mp_Packet__ValidateAsString(mediapipe::Packet* packet, absl::Status** status_out); +// proto +MP_CAPI(MpReturnCode) mp__PacketFromDynamicProto__PKc_PKc_i(const char* type_name, const char* serialized_proto, int size, + absl::Status** status_out, mediapipe::Packet** packet_out); +MP_CAPI(MpReturnCode) mp__PacketFromDynamicProto_At__PKc_PKc_i_ll(const char* type_name, const char* serialized_proto, int size, + int64 timestampMicrosec, + absl::Status** status_out, mediapipe::Packet** packet_out); +MP_CAPI(MpReturnCode) mp_Packet__GetProto(mediapipe::Packet* packet, mp_api::SerializedProto* serialized_proto); +MP_CAPI(MpReturnCode) mp_Packet__ValidateAsProtoMessageLite(mediapipe::Packet* packet, absl::Status** status_out); + /** PacketMap API */ MP_CAPI(MpReturnCode) mp_PacketMap__(PacketMap** packet_map_out); MP_CAPI(void) mp_PacketMap__delete(PacketMap* packet_map);