Skip to content

Commit

Permalink
feat: add FloatVectorPacket and MatrixPacket (#767)
Browse files Browse the repository at this point in the history
* matrix classification via tflite model (#656)

* cc: matrix_frame as input to graph

- "matrix_frame" as name in order to avoid confusion with matrix.cc
- matrix_frame is a 2D input data modality that gets converted to Eigen::MatrixXf internally
- suited for non-image input to tflite models

* cc: float_vector_frame as input to graph

- "float_vector_frame" as name in order to avoid confusion with FloatArrayPacket
- float_vector_frame is a 1D output data modality that accepts std::vector<float> outputs from a mediapipe / tflite graph
- suited for tflite classification results packaged as vector of floats

* MatrixFramePacket - c# helper functions

- send MatrixData to C++ as byte array

* FloatVectorFramePacket - c# helper functions

- return std::vector<float> from C++ to Unity as List<float>

* Unity: Matrix Classification - Example scene

* Matrix Classification.cs

- driver code for the newly added MatrixFramePacket and FloatVectorFramePacket
- feeds an example matrix of size [ 2 x 3 ] into a mediapipe graph
- the graph runs a simple tflite model (adds +1 to every input)
- then the graph returns the result back to Unity as List<float>

- only tested on Unity-Editor-Mode on Windows 10 Pro

* refactor: rename FloatVectorFrame -> FloatVector

- rename variables
- rename cs files
- rename cc files

* refactor: rename MatrixFrame -> Matrix

- rename variables
- rename cs files
- rename cc files matrix_frame -> matrix_data
-> avoid matrix.cc as it is already used as a name in mediapipe

* move MatrixClassification example scene to Tutorials

- does not represent an official solution
- not sure where else to place this
- MatrixClassification.cs is an important example for showcasing the usage of a tflite model with a matrix data input

* GetArrayPtr() - change access to private

* MatrixPacket: accept MatrixData as input

- before it was byte[]

* add license

* move native functions to Packet_Unsafe

- delete FloatVector_Unsafe.cs

* float_vector.cc -> faster vector allocation

* float_vector.cc remove unused function - delete(...)

* float_vector.h - remove unused headers

* refactor: float_vector.cc

TODO:
- implement GetFloatVector with vector size as argument

* removed unused headers

* refactor: float_vector.h

* refactor: apply autoformatter on cc files

- using format file ".clang-format" in project root

* refactor: mp__MakeMatrixFramePacket_At__PA_i_Rt -> mp__MakeMatrixPacket_At__PKc_i_Rt

* FloatVectorPacketTest added

- build similar to FloatArrayPacketTest
- not yet tested

* fix: float_vector.cc

* fix: MatrixPacket.cs

* fix: Test: FloatVectorPacketTest - Consume_ShouldThrowNotSupportedException

* MatrixPacketTest - add

- all tests involving packet.Get() do not work
- function is not yet implemented

* fix: Make MatrixClassification.cs run on Android

- adding StreamingAssets to ResourceManager

[skip actions]

* Update mediapipe_api/framework/formats/matrix_data.h

Co-authored-by: Junrou Nishida <[email protected]>

* Apply suggestions from code review

Co-authored-by: Junrou Nishida <[email protected]>

* float_vector - return vector size (+2 squashed commit)

Squashed commit:

[e409b05] refactor: vector_float.cc

- naming aligns with files like packet.cc

[bad3cd6] float_vector - return vector size

* fix: matrix_data.cc - wrong func name (+3 squashed commit)

Squashed commit:

[9245a37] fix: Revert "Apply suggestions from code review"

- the below mentioned commit is not working
- return value of inline function is invalid
-> probably due to inline function

This reverts commit c374d61.

[f597e83] fix: remove duplicate cpp func

[6def10c] fix: semicolon omitted

* FloatVectorPacket - replace list by array

fix: FloatVectorPacket

[skip actions] (+1 squashed commits)

Squashed commits:

[69302b1] FloatVectorPacket - replace list by array

- list is slow

* Add license headers

[skip actions]

* Remove Tutorial Scene: MatrixClassification

as per request:
- deleted demo / tutorial scene that showcasts a simple tflite graph
[skip actions]

* fix: MatrixPacket tests

- new GetMatrix function

Caveat:
- MatrixPacket: Consume throws NotSupportedException()
-> not sure if this is a useful test, but such tests exists in similar classes as well

(cherry picked from commit 707af5f454e87312a86b60deaebec18463e47ded)

Co-authored-by: Junrou Nishida <[email protected]>

* refactor: move float vector packet API

* FloatVectorPacket#Get returns List<float>

* copy float arrays when getting the value

* refactor MatrixPacket

* refactor: use ParseFromStringAsProto

* test: replace DebugTypeName tests with ValidateAsType tests

Co-authored-by: Martin Garbade <[email protected]>
  • Loading branch information
homuler and mgarbade committed Oct 15, 2022
1 parent e0a68a8 commit 391d7d9
Show file tree
Hide file tree
Showing 29 changed files with 744 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,27 +64,23 @@ public override float[] Get()
}

var result = new float[length];
UnsafeNativeMethods.mp_Packet__GetFloatArray_i(mpPtr, length, out var arrayPtr).Assert();
GC.KeepAlive(this);

unsafe
{
var src = (float*)GetArrayPtr();
var src = (float*)arrayPtr;

for (var i = 0; i < result.Length; i++)
{
result[i] = *src++;
}
}

UnsafeNativeMethods.delete_array__Pf(arrayPtr);
return result;
}

public IntPtr GetArrayPtr()
{
UnsafeNativeMethods.mp_Packet__GetFloatArray(mpPtr, out var value).Assert();
GC.KeepAlive(this);
return value;
}

public override StatusOr<float[]> Consume()
{
throw new NotSupportedException();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (c) 2021 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;

namespace Mediapipe
{
[StructLayout(LayoutKind.Sequential)]
internal readonly struct FloatVector
{
private readonly IntPtr _data;
private readonly int _size;

public void Dispose()
{
UnsafeNativeMethods.delete_array__Pf(_data);
}

public List<float> Copy()
{
var data = new List<float>(_size);

unsafe
{
var floatPtr = (float*)_data;

for (var i = 0; i < _size; i++)
{
data.Add(*floatPtr++);
}
}
return data;
}
}

public class FloatVectorPacket : Packet<List<float>>
{
/// <summary>
/// Creates an empty <see cref="FloatVectorPacket" /> instance.
/// </summary>
///
public FloatVectorPacket() : base(true) { }

[UnityEngine.Scripting.Preserve]
public FloatVectorPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }


public FloatVectorPacket(float[] value) : base()
{
UnsafeNativeMethods.mp__MakeFloatVectorPacket__Pf_i(value, value.Length, out var ptr).Assert();
this.ptr = ptr;
}

public FloatVectorPacket(float[] value, Timestamp timestamp) : base()
{
UnsafeNativeMethods.mp__MakeFloatVectorPacket_At__Pf_i_Rt(value, value.Length, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
this.ptr = ptr;
}

public FloatVectorPacket(List<float> value) : base()
{
UnsafeNativeMethods.mp__MakeFloatVectorPacket__Pf_i(value.ToArray(), value.Count, out var ptr).Assert();
this.ptr = ptr;
}

public FloatVectorPacket(List<float> value, Timestamp timestamp) : base()
{
UnsafeNativeMethods.mp__MakeFloatVectorPacket_At__Pf_i_Rt(value.ToArray(), value.Count, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
this.ptr = ptr;
}

public FloatVectorPacket At(Timestamp timestamp)
{
return At<FloatVectorPacket>(timestamp);
}

public override List<float> Get()
{
UnsafeNativeMethods.mp_Packet__GetFloatVector(mpPtr, out var floatVector).Assert();
GC.KeepAlive(this);

var result = floatVector.Copy();
floatVector.Dispose();
return result;
}

public override StatusOr<List<float>> Consume()
{
throw new NotSupportedException();
}

public override Status ValidateAsType()
{
UnsafeNativeMethods.mp_Packet__ValidateAsFloatVector(mpPtr, out var statusPtr).Assert();

GC.KeepAlive(this);
return new Status(statusPtr);
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) 2021 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using Google.Protobuf;
using System;

namespace Mediapipe
{
public class MatrixPacket : Packet<MatrixData>
{
/// <summary>
/// Creates an empty <see cref="MatrixPacket" /> instance.
/// </summary>
public MatrixPacket() : base(true) { }

[UnityEngine.Scripting.Preserve]
public MatrixPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public MatrixPacket(MatrixData matrixData) : base()
{
var value = matrixData.ToByteArray();
UnsafeNativeMethods.mp__MakeMatrixPacket__PKc_i(value, value.Length, out var ptr).Assert();
this.ptr = ptr;
}

public MatrixPacket(MatrixData matrixData, Timestamp timestamp) : base()
{
var value = matrixData.ToByteArray();
UnsafeNativeMethods.mp__MakeMatrixPacket_At__PKc_i_Rt(value, value.Length, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
this.ptr = ptr;
}

public MatrixPacket At(Timestamp timestamp)
{
return At<MatrixPacket>(timestamp);
}

public override MatrixData Get()
{
UnsafeNativeMethods.mp_Packet__GetMatrix(mpPtr, out var serializedMatrixData).Assert();
GC.KeepAlive(this);

var matrixData = serializedMatrixData.Deserialize(MatrixData.Parser);
serializedMatrixData.Dispose();

return matrixData;
}

public override StatusOr<MatrixData> Consume()
{
throw new NotSupportedException();
}

public override Status ValidateAsType()
{
UnsafeNativeMethods.mp_Packet__ValidateAsMatrix(mpPtr, out var statusPtr).Assert();

GC.KeepAlive(this);
return new Status(statusPtr);
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ internal static partial class UnsafeNativeMethods
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern void delete_array__PKc(IntPtr str);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern void delete_array__Pf(IntPtr str);

#region String
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern void std_string__delete(IntPtr str);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2021 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System;
using System.Runtime.InteropServices;

namespace Mediapipe
{
internal static partial class UnsafeNativeMethods
{
#region Packet
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeMatrixPacket__PKc_i(byte[] serializedMatrixData, int size, out IntPtr packet_out);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeMatrixPacket_At__PKc_i_Rt(byte[] serializedMatrixData, int size, IntPtr timestamp, out IntPtr packet_out);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ValidateAsMatrix(IntPtr packet, out IntPtr status);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetMatrix(IntPtr packet, out SerializedProto serializedProto);

#endregion
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,26 @@ internal static partial class UnsafeNativeMethods
public static extern MpReturnCode mp__MakeFloatArrayPacket_At__Pf_i_Rt(float[] value, int size, IntPtr timestamp, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetFloatArray(IntPtr packet, out IntPtr value);
public static extern MpReturnCode mp_Packet__GetFloatArray_i(IntPtr packet, int size, out IntPtr value);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ValidateAsFloatArray(IntPtr packet, out IntPtr status);
#endregion

#region FloatVector
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeFloatVectorPacket__Pf_i(float[] value, int size, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeFloatVectorPacket_At__Pf_i_Rt(float[] value, int size, IntPtr timestamp, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetFloatVector(IntPtr packet, out FloatVector value);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ValidateAsFloatVector(IntPtr packet, out IntPtr status);
#endregion

#region String
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeStringPacket__PKc(string value, out IntPtr packet);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ public void Consume_ShouldThrowNotSupportedException()
}
#endregion

#region #DebugTypeName
#region #ValidateAsType
[Test]
public void DebugTypeName_ShouldReturnBool_When_ValueIsSet()
public void ValidateAsType_ShouldReturnOk_When_ValueIsSet()
{
using (var packet = new BoolPacket(true))
{
Assert.AreEqual("bool", packet.DebugTypeName());
Assert.True(packet.ValidateAsType().Ok());
}
}
#endregion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,14 @@ public void Consume_ShouldThrowNotSupportedException()
}
#endregion

#region #DebugTypeName
#region #ValidateAsType
[Test]
public void DebugTypeName_ShouldReturnFloat_When_ValueIsSet()
public void ValidateAsType_ShouldReturnOk_When_ValueIsSet()
{
float[] array = { 0.01f };
using (var packet = new FloatArrayPacket(array))
{
#if UNITY_EDITOR_WIN
Assert.AreEqual("float [0]", packet.DebugTypeName());
#else
Assert.AreEqual("float []", packet.DebugTypeName());
#endif
Assert.True(packet.ValidateAsType().Ok());
}
}
#endregion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,13 @@ public void Consume_ShouldThrowNotSupportedException()
}
#endregion

#region #DebugTypeName
#region #ValidateAsType
[Test]
public void DebugTypeName_ShouldReturnFloat_When_ValueIsSet()
public void ValidateAsType_ShouldReturnOk_When_ValueIsSet()
{
using (var packet = new FloatPacket(0.01f))
{
Assert.AreEqual("float", packet.DebugTypeName());
Assert.True(packet.ValidateAsType().Ok());
}
}
#endregion
Expand Down
Loading

0 comments on commit 391d7d9

Please sign in to comment.