Skip to content

Commit

Permalink
fix: check Packet types at compile-time more properly (#509)
Browse files Browse the repository at this point in the history
* fix!(plugin): default constructors of Packet classes cause memory leak

* test Packet#At
  • Loading branch information
homuler committed Apr 4, 2022
1 parent 2497f0c commit d03895d
Show file tree
Hide file tree
Showing 35 changed files with 431 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,19 @@ namespace Mediapipe
{
public abstract class MpResourceHandle : DisposableObject, IMpResourceHandle
{
protected IntPtr ptr;
private IntPtr _ptr = IntPtr.Zero;
protected IntPtr ptr
{
get => _ptr;
set
{
if (value != IntPtr.Zero && OwnsResource())
{
throw new InvalidOperationException($"This object owns another resource");
}
_ptr = value;
}
}

protected MpResourceHandle(bool isOwner = true) : this(IntPtr.Zero, isOwner) { }

Expand Down Expand Up @@ -40,7 +52,7 @@ public void ReleaseMpResource()

public bool OwnsResource()
{
return isOwner && ptr != IntPtr.Zero;
return isOwner && IsResourcePresent();
}
#endregion

Expand Down Expand Up @@ -80,5 +92,10 @@ protected string MarshalStringFromNative(StringOutFunc f)

return str;
}

protected bool IsResourcePresent()
{
return ptr != IntPtr.Zero;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ public Status ObserveOutputStream(string streamName, NativePacketCallback native
return new Status(statusPtr);
}

public Status ObserveOutputStream<TPacket, TValue>(string streamName, PacketCallback<TPacket, TValue> packetCallback, bool observeTimestampBounds, out GCHandle callbackHandle) where TPacket : Packet<TValue>
public Status ObserveOutputStream<TPacket, TValue>(string streamName, PacketCallback<TPacket, TValue> packetCallback, bool observeTimestampBounds, out GCHandle callbackHandle) where TPacket : Packet<TValue>, new()
{
NativePacketCallback nativePacketCallback = (IntPtr _, IntPtr packetPtr) =>
{
Status status = null;
try
{
var packet = (TPacket)Activator.CreateInstance(typeof(TPacket), packetPtr, false);
var packet = Packet<TValue>.Create<TPacket>(packetPtr, false);
status = packetCallback(packet);
}
catch (Exception e)
Expand All @@ -100,7 +100,7 @@ public Status ObserveOutputStream(string streamName, NativePacketCallback native
return ObserveOutputStream(streamName, nativePacketCallback, observeTimestampBounds);
}

public Status ObserveOutputStream<TPacket, TValue>(string streamName, PacketCallback<TPacket, TValue> packetCallback, out GCHandle callbackHandle) where TPacket : Packet<TValue>
public Status ObserveOutputStream<TPacket, TValue>(string streamName, PacketCallback<TPacket, TValue> packetCallback, out GCHandle callbackHandle) where TPacket : Packet<TValue>, new()
{
return ObserveOutputStream(streamName, packetCallback, false, out callbackHandle);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ namespace Mediapipe
{
public class Anchor3dVectorPacket : Packet<List<Anchor3d>>
{
public Anchor3dVectorPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="Anchor3dVectorPacket" /> instance.
/// </summary>
public Anchor3dVectorPacket() : base(true) { }

public Anchor3dVectorPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public Anchor3dVectorPacket(Anchor3d[] value) : base()
Expand All @@ -27,6 +31,11 @@ public Anchor3dVectorPacket(Anchor3d[] value, Timestamp timestamp) : base()
this.ptr = ptr;
}

public Anchor3dVectorPacket At(Timestamp timestamp)
{
return At<Anchor3dVectorPacket>(timestamp);
}

public override List<Anchor3d> Get()
{
UnsafeNativeMethods.mp_Packet__GetAnchor3dVector(mpPtr, out var anchorVector).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ namespace Mediapipe
{
public class BoolPacket : Packet<bool>
{
public BoolPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="BoolPacket" /> instance.
/// </summary>
public BoolPacket() : base(true) { }

public BoolPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

Expand All @@ -27,6 +30,11 @@ public BoolPacket(bool value, Timestamp timestamp) : base()
this.ptr = ptr;
}

public BoolPacket At(Timestamp timestamp)
{
return At<BoolPacket>(timestamp);
}

public override bool Get()
{
UnsafeNativeMethods.mp_Packet__GetBool(mpPtr, out var value).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,18 @@ namespace Mediapipe
{
public class ClassificationListPacket : Packet<ClassificationList>
{
public ClassificationListPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="ClassificationListPacket" /> instance.
/// </summary>
public ClassificationListPacket() : base(true) { }

public ClassificationListPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public ClassificationListPacket At(Timestamp timestamp)
{
return At<ClassificationListPacket>(timestamp);
}

public override ClassificationList Get()
{
UnsafeNativeMethods.mp_Packet__GetClassificationList(mpPtr, out var serializedProto).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,18 @@ namespace Mediapipe
{
public class ClassificationListVectorPacket : Packet<List<ClassificationList>>
{
public ClassificationListVectorPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="ClassificationListVectorPacket" /> instance.
/// </summary>
public ClassificationListVectorPacket() : base(true) { }

public ClassificationListVectorPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public ClassificationListVectorPacket At(Timestamp timestamp)
{
return At<ClassificationListVectorPacket>(timestamp);
}

public override List<ClassificationList> Get()
{
UnsafeNativeMethods.mp_Packet__GetClassificationListVector(mpPtr, out var serializedProtoVector).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,18 @@ namespace Mediapipe
{
public class DetectionPacket : Packet<Detection>
{
public DetectionPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="DetectionPacket" /> instance.
/// </summary>
public DetectionPacket() : base(true) { }

public DetectionPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public DetectionPacket At(Timestamp timestamp)
{
return At<DetectionPacket>(timestamp);
}

public override Detection Get()
{
UnsafeNativeMethods.mp_Packet__GetDetection(mpPtr, out var serializedProto).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,18 @@ namespace Mediapipe
{
public class DetectionVectorPacket : Packet<List<Detection>>
{
public DetectionVectorPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="DetectionVectorPacket" /> instance.
/// </summary>
public DetectionVectorPacket() : base(true) { }

public DetectionVectorPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public DetectionVectorPacket At(Timestamp timestamp)
{
return At<DetectionVectorPacket>(timestamp);
}

public override List<Detection> Get()
{
UnsafeNativeMethods.mp_Packet__GetDetectionVector(mpPtr, out var serializedProtoVector).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,18 @@ namespace Mediapipe
{
public class FaceGeometryPacket : Packet<FaceGeometry.FaceGeometry>
{
public FaceGeometryPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="FaceGeometryPacket" /> instance.
/// </summary>
public FaceGeometryPacket() : base(true) { }

public FaceGeometryPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public FaceGeometryPacket At(Timestamp timestamp)
{
return At<FaceGeometryPacket>(timestamp);
}

public override FaceGeometry.FaceGeometry Get()
{
UnsafeNativeMethods.mp_Packet__GetFaceGeometry(mpPtr, out var serializedProto).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,18 @@ namespace Mediapipe
{
public class FaceGeometryVectorPacket : Packet<List<FaceGeometry.FaceGeometry>>
{
public FaceGeometryVectorPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="FaceGeometryVectorPacket" /> instance.
/// </summary>
public FaceGeometryVectorPacket() : base(true) { }

public FaceGeometryVectorPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public FaceGeometryVectorPacket At(Timestamp timestamp)
{
return At<FaceGeometryVectorPacket>(timestamp);
}

public override List<FaceGeometry.FaceGeometry> Get()
{
UnsafeNativeMethods.mp_Packet__GetFaceGeometryVector(mpPtr, out var serializedProtoVector).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ public int length
}
}

public FloatArrayPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="FloatArrayPacket" /> instance.
/// </summary>
public FloatArrayPacket() : base(true) { }

public FloatArrayPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

Expand All @@ -45,6 +48,11 @@ public FloatArrayPacket(float[] value, Timestamp timestamp) : base()
length = value.Length;
}

public FloatArrayPacket At(Timestamp timestamp)
{
return At<FloatArrayPacket>(timestamp);
}

public override float[] Get()
{
if (length < 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ namespace Mediapipe
{
public class FloatPacket : Packet<float>
{
public FloatPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="FloatPacket" /> instance.
/// </summary>
public FloatPacket() : base(true) { }

public FloatPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

Expand All @@ -27,6 +30,11 @@ public FloatPacket(float value, Timestamp timestamp) : base()
this.ptr = ptr;
}

public FloatPacket At(Timestamp timestamp)
{
return At<FloatPacket>(timestamp);
}

public override float Get()
{
UnsafeNativeMethods.mp_Packet__GetFloat(mpPtr, out var value).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,18 @@ namespace Mediapipe
{
public class FrameAnnotationPacket : Packet<FrameAnnotation>
{
public FrameAnnotationPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="FrameAnnotationPacket" /> instance.
/// </summary>
public FrameAnnotationPacket() : base(true) { }

public FrameAnnotationPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public FrameAnnotationPacket At(Timestamp timestamp)
{
return At<FrameAnnotationPacket>(timestamp);
}

public override FrameAnnotation Get()
{
UnsafeNativeMethods.mp_Packet__GetFrameAnnotation(mpPtr, out var serializedProto).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ namespace Mediapipe
{
public class GpuBufferPacket : Packet<GpuBuffer>
{
public GpuBufferPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="GpuBufferPacket" /> instance.
/// </summary>
public GpuBufferPacket() : base(true) { }

public GpuBufferPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public GpuBufferPacket(GpuBuffer gpuBuffer) : base()
Expand All @@ -30,6 +34,11 @@ public GpuBufferPacket(GpuBuffer gpuBuffer, Timestamp timestamp)
this.ptr = ptr;
}

public GpuBufferPacket At(Timestamp timestamp)
{
return At<GpuBufferPacket>(timestamp);
}

public override GpuBuffer Get()
{
UnsafeNativeMethods.mp_Packet__GetGpuBuffer(mpPtr, out var gpuBufferPtr).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ namespace Mediapipe
{
public class ImageFramePacket : Packet<ImageFrame>
{
public ImageFramePacket() : base() { }
/// <summary>
/// Creates an empty <see cref="ImageFramePacket" /> instance.
/// </summary>
public ImageFramePacket() : base(true) { }

public ImageFramePacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

Expand All @@ -31,6 +34,11 @@ public ImageFramePacket(ImageFrame imageFrame, Timestamp timestamp) : base()
this.ptr = ptr;
}

public ImageFramePacket At(Timestamp timestamp)
{
return At<ImageFramePacket>(timestamp);
}

public override ImageFrame Get()
{
UnsafeNativeMethods.mp_Packet__GetImageFrame(mpPtr, out var imageFramePtr).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ namespace Mediapipe
{
public class IntPacket : Packet<int>
{
public IntPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="IntPacket" /> instance.
/// </summary>
public IntPacket() : base(true) { }

public IntPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

Expand All @@ -27,6 +30,11 @@ public IntPacket(int value, Timestamp timestamp) : base()
this.ptr = ptr;
}

public IntPacket At(Timestamp timestamp)
{
return At<IntPacket>(timestamp);
}

public override int Get()
{
UnsafeNativeMethods.mp_Packet__GetInt(mpPtr, out var value).Assert();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,18 @@ namespace Mediapipe
{
public class LandmarkListPacket : Packet<LandmarkList>
{
public LandmarkListPacket() : base() { }
/// <summary>
/// Creates an empty <see cref="LandmarkListPacket" /> instance.
/// </summary>
public LandmarkListPacket() : base(true) { }

public LandmarkListPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public LandmarkListPacket At(Timestamp timestamp)
{
return At<LandmarkListPacket>(timestamp);
}

public override LandmarkList Get()
{
UnsafeNativeMethods.mp_Packet__GetLandmarkList(mpPtr, out var serializedProto).Assert();
Expand Down
Loading

0 comments on commit d03895d

Please sign in to comment.