Skip to content

Commit

Permalink
feat: implement Packet.GetProtoList (#1080)
Browse files Browse the repository at this point in the history
* feat: implement Packet.GetProtoList

* perf: avoid allocation
  • Loading branch information
homuler committed Dec 23, 2023
1 parent df48262 commit 0a302b2
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ public void Dispose()
{
var protos = new List<T>(_size);

Deserialize(parser, protos);

return protos;
}

/// <summary>
/// Deserializes the data as a list of <typeparamref name="T" />.
/// </summary>
/// <param name="protos">A list of <typeparamref name="T" /> to populate</param>
public void Deserialize<T>(pb::MessageParser<T> parser, List<T> protos) where T : pb::IMessage<T>
{
protos.Clear();

unsafe
{
var protoPtr = (SerializedProto*)_data;
Expand All @@ -37,8 +50,6 @@ public void Dispose()
protos.Add(serializedProto.Deserialize(parser));
}
}

return protos;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ public ImageFrame GetImageFrame()
/// On some platforms (e.g. Windows), it will abort the process when <see cref="MediaPipeException"/> should be thrown.
/// </remarks>
/// <exception cref="MediaPipeException">
/// If the <see cref="Packet"/> doesn't contain <see langword="int"/> data.
/// If the <see cref="Packet"/> doesn't contain <see langword="int"/> data.
/// </exception>
public int GetInt()
{
Expand All @@ -499,7 +499,7 @@ public int GetInt()
/// </exception>
public T GetProto<T>(MessageParser<T> parser) where T : IMessage<T>
{
UnsafeNativeMethods.mp_Packet__GetProto(mpPtr, out var value).Assert();
UnsafeNativeMethods.mp_Packet__GetProtoMessageLite(mpPtr, out var value).Assert();

GC.KeepAlive(this);

Expand All @@ -509,6 +509,45 @@ public int GetInt()
return proto;
}

/// <summary>
/// Get the content of the <see cref="Packet"/> as a proto message list.
/// </summary>
/// <remarks>
/// On some platforms (e.g. Windows), it will abort the process when <see cref="MediaPipeException"/> should be thrown.
/// </remarks>
/// <exception cref="MediaPipeException">
/// If the <see cref="Packet"/> doesn't contain a proto message list.
/// </exception>
public List<T> GetProtoList<T>(MessageParser<T> parser) where T : IMessage<T>
{
var value = new List<T>();
GetProtoList(parser, value);

return value;
}

/// <summary>
/// Get the content of the <see cref="Packet"/> as a proto message list.
/// </summary>
/// <remarks>
/// On some platforms (e.g. Windows), it will abort the process when <see cref="MediaPipeException"/> should be thrown.
/// </remarks>
/// <param name="value">
/// The <see cref="List{T}"/> to be filled with the content of the <see cref="Packet"/>.
/// </param>
/// <exception cref="MediaPipeException">
/// If the <see cref="Packet"/> doesn't contain a proto message list.
/// </exception>
public void GetProtoList<T>(MessageParser<T> parser, List<T> value) where T : IMessage<T>
{
UnsafeNativeMethods.mp_Packet__GetVectorOfProtoMessageLite(mpPtr, out var serializedProtoVector).Assert();

GC.KeepAlive(this);

serializedProtoVector.Deserialize(parser, value);
serializedProtoVector.Dispose();
}

/// <summary>
/// Validate if the content of the <see cref="Packet"/> is a boolean.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,10 @@ internal static partial class UnsafeNativeMethods
out IntPtr status, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetProto(IntPtr packet, out SerializedProto value);
public static extern MpReturnCode mp_Packet__GetProtoMessageLite(IntPtr packet, out SerializedProto value);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetVectorOfProtoMessageLite(IntPtr packet, out SerializedProtoVector value);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ValidateAsProtoMessageLite(IntPtr packet, out IntPtr status);
Expand Down
11 changes: 11 additions & 0 deletions mediapipe_api/external/protobuf.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ inline void SerializeProtoVector(const std::vector<T>& proto_vec, mp_api::Struct
serialized_proto_vector->size = static_cast<int>(vec_size);
}

inline void SerializeProtoVector(const std::vector<const google::protobuf::MessageLite*>& proto_vec, mp_api::StructArray<mp_api::SerializedProto>* serialized_proto_vector) {
auto vec_size = proto_vec.size();
auto data = new mp_api::SerializedProto[vec_size];

for (auto i = 0; i < vec_size; ++i) {
SerializeProto(*proto_vec[i], &data[i]);
}
serialized_proto_vector->data = data;
serialized_proto_vector->size = static_cast<int>(vec_size);
}

template <class T>
inline T ParseFromStringAsProto(const char* serialized_data, int size) {
T proto;
Expand Down
14 changes: 13 additions & 1 deletion mediapipe_api/framework/packet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -400,14 +400,26 @@ MpReturnCode mp__PacketFromDynamicProto_At__PKc_PKc_i_ll(const char* type_name,
CATCH_ALL
}

MpReturnCode mp_Packet__GetProto(mediapipe::Packet* packet, mp_api::SerializedProto* serialized_proto) {
MpReturnCode mp_Packet__GetProtoMessageLite(mediapipe::Packet* packet, mp_api::SerializedProto* serialized_proto) {
TRY_ALL
const auto& proto = packet->GetProtoMessageLite();
SerializeProto(proto, serialized_proto);
RETURN_CODE(MpReturnCode::Success);
CATCH_ALL
}

MpReturnCode mp_Packet__GetVectorOfProtoMessageLite(mediapipe::Packet* packet, mp_api::StructArray<mp_api::SerializedProto>* value_out) {
TRY_ALL
const auto status_or_vec = packet->GetVectorOfProtoMessageLitePtrs();
if (!status_or_vec.ok()) {
LOG(FATAL) << status_or_vec.status().message();
}

SerializeProtoVector(status_or_vec.value(), value_out);
RETURN_CODE(MpReturnCode::Success);
CATCH_ALL
}

MpReturnCode mp_Packet__ValidateAsProtoMessageLite(mediapipe::Packet* packet, absl::Status** status_out) {
TRY
*status_out = new absl::Status{packet->ValidateAsProtoMessageLite()};
Expand Down
3 changes: 2 additions & 1 deletion mediapipe_api/framework/packet.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ MP_CAPI(MpReturnCode) mp__PacketFromDynamicProto__PKc_PKc_i(const char* type_nam
MP_CAPI(MpReturnCode) mp__PacketFromDynamicProto_At__PKc_PKc_i_ll(const char* type_name, const char* serialized_proto, int size,
int64 timestampMicrosec,
absl::Status** status_out, mediapipe::Packet** packet_out);
MP_CAPI(MpReturnCode) mp_Packet__GetProto(mediapipe::Packet* packet, mp_api::SerializedProto* serialized_proto);
MP_CAPI(MpReturnCode) mp_Packet__GetProtoMessageLite(mediapipe::Packet* packet, mp_api::SerializedProto* value_out);
MP_CAPI(MpReturnCode) mp_Packet__GetVectorOfProtoMessageLite(mediapipe::Packet* packet, mp_api::StructArray<mp_api::SerializedProto>* value_out);
MP_CAPI(MpReturnCode) mp_Packet__ValidateAsProtoMessageLite(mediapipe::Packet* packet, absl::Status** status_out);

/** PacketMap API */
Expand Down

0 comments on commit 0a302b2

Please sign in to comment.