Skip to content

Commit

Permalink
feat: implement Packet.CreateColMajorMatrix (#1133)
Browse files Browse the repository at this point in the history
* refactor: matrix_data -> matrix

* feat: port Matrix

* feat: implement Packet.CreateColMajorMatrix

* refactor: remove stale Matrix-related functions
  • Loading branch information
homuler committed Jan 27, 2024
1 parent 6eb8026 commit af74bb2
Show file tree
Hide file tree
Showing 15 changed files with 359 additions and 83 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) 2023 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System;

namespace Mediapipe
{
public readonly struct Matrix
{
public enum Layout
{
ColMajor = 0,
RowMajor = 1,
}

internal readonly float[] data;
public readonly int rows;
public readonly int cols;
internal readonly Layout layout;

public Matrix(float[] data, int rows, int cols, Layout layout = Layout.ColMajor)
{
if (rows * cols != data.Length)
{
throw new ArgumentException($"Matrix size mismatch ({rows}x{cols} != {data.Length})");
}

this.data = data;
this.rows = rows;
this.cols = cols;
this.layout = layout;
}

internal Matrix(NativeMatrix nativeMatrix) : this(nativeMatrix.AsReadOnlySpan().ToArray(), nativeMatrix.rows, nativeMatrix.cols, nativeMatrix.layout == 0 ? Layout.ColMajor : Layout.RowMajor)
{ }

internal static void Copy(NativeMatrix source, ref Matrix destination)
{
if (destination.rows != source.rows || destination.cols != source.cols)
{
throw new ArgumentException($"Matrix size mismatch ({source.rows}x{source.cols} != {destination.rows}x{destination.cols})");
}

source.AsReadOnlySpan().CopyTo(destination.data);
var layout = source.layout == 0 ? Layout.ColMajor : Layout.RowMajor;

destination = new Matrix(destination.data, source.rows, source.cols, layout);
}

public readonly bool isColMajor => layout == Layout.ColMajor;
public readonly bool isRowMajor => layout == Layout.RowMajor;
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,56 @@ public static Packet<int> CreateIntAt(int value, long timestampMicrosec)
return new Packet<int>(ptr, true);
}

/// <summary>
/// Create a Matrix Packet.
/// </summary>
public static Packet<Matrix> CreateColMajorMatrix(float[] data, int row, int col)
{
UnsafeNativeMethods.mp__MakeColMajorMatrixPacket__Pf_i_i(data, row, col, out var ptr).Assert();

return new Packet<Matrix>(ptr, true);
}

/// <summary>
/// Create a Matrix Packet.
/// </summary>
public static Packet<Matrix> CreateColMajorMatrix(Matrix value)
{
if (!value.isColMajor)
{
throw new ArgumentException("Matrix must be col-major");
}
return CreateColMajorMatrix(value.data, value.rows, value.cols);
}

/// <summary>
/// Create a Matrix Packet.
/// </summary>
/// <param name="timestampMicrosec">
/// The timestamp of the packet.
/// </param>
public static Packet<Matrix> CreateColMajorMatrixAt(float[] data, int row, int col, long timestampMicrosec)
{
UnsafeNativeMethods.mp__MakeColMajorMatrixPacket_At__Pf_i_i_ll(data, row, col, timestampMicrosec, out var ptr).Assert();

return new Packet<Matrix>(ptr, true);
}

/// <summary>
/// Create a Matrix Packet.
/// </summary>
/// <param name="timestampMicrosec">
/// The timestamp of the packet.
/// </param>
public static Packet<Matrix> CreateColMajorMatrixAt(Matrix value, long timestampMicrosec)
{
if (!value.isColMajor)
{
throw new ArgumentException("Matrix must be col-major");
}
return CreateColMajorMatrixAt(value.data, value.rows, value.cols, timestampMicrosec);
}

/// <summary>
/// Create a MediaPipe protobuf message Packet.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,43 @@ public static int Get(this Packet<int> packet)
[Obsolete("Use Get instead")]
public static int GetInt(this Packet<int> packet) => Get(packet);

/// <summary>
/// Get the content of the <see cref="Packet"/> as a <see cref="Matrix"/> .
/// </summary>
/// <remarks>
/// On some platforms (e.g. Windows), it will abort the process when <see cref="MediaPipeException"/> should be thrown.
/// </remarks>
/// <param name="value">
/// The <see cref="Matrix"/> to be filled with the content of the <see cref="Packet"/>.
/// </param>
/// <exception cref="MediaPipeException">
/// If the <see cref="Packet"/> doesn't contain a mediapipe::Matrix data.
/// </exception>
public static void Get(this Packet<Matrix> packet, ref Matrix value)
{
UnsafeNativeMethods.mp_Packet__GetMpMatrix(packet.mpPtr, out var nativeMatrix).Assert();
GC.KeepAlive(packet);

Matrix.Copy(nativeMatrix, ref value);
}

/// <summary>
/// Get the content of the <see cref="Packet"/> as a <see cref="Matrix"/> .
/// </summary>
/// <remarks>
/// On some platforms (e.g. Windows), it will abort the process when <see cref="MediaPipeException"/> should be thrown.
/// </remarks>
/// <exception cref="MediaPipeException">
/// If the <see cref="Packet"/> doesn't contain a mediapipe::Matrix data.
/// </exception>
public static Matrix Get(this Packet<Matrix> packet)
{
UnsafeNativeMethods.mp_Packet__GetMpMatrix(packet.mpPtr, out var nativeMatrix).Assert();
GC.KeepAlive(packet);

return new Matrix(nativeMatrix);
}

/// <summary>
/// Get the content of the <see cref="Packet"/> as a proto message.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,20 @@ public static void Validate(this Packet<int> packet)
[Obsolete("Use Validate instead")]
public static void ValidateAsInt(this Packet<int> packet) => Validate(packet);

/// <summary>
/// Validate if the content of the <see cref="Packet"/> is mediapipe::Matrix.
/// </summary>
/// <exception cref="BadStatusException">
/// If the <see cref="Packet"/> doesn't contain mediapipe::Matrix.
/// </exception>
public static void Validate(this Packet<Matrix> packet)
{
UnsafeNativeMethods.mp_Packet__ValidateAsMatrix(packet.mpPtr, out var statusPtr).Assert();
GC.KeepAlive(packet);

Status.UnsafeAssertOk(statusPtr);
}

/// <summary>
/// Validate if the content of the <see cref="Packet"/> is a proto message.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2023 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System;
using System.Runtime.InteropServices;

namespace Mediapipe
{
[StructLayout(LayoutKind.Sequential)]
internal readonly struct NativeMatrix
{
private readonly IntPtr _data;
public readonly int rows;
public readonly int cols;
public readonly int layout;

public void Dispose()
{
UnsafeNativeMethods.mp_api_Matrix__delete(this);
}

public ReadOnlySpan<float> AsReadOnlySpan()
{
unsafe
{
return new ReadOnlySpan<float>((float*)_data, rows * cols);
}
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,19 @@ internal static partial class UnsafeNativeMethods
{
#region Packet
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeMatrixPacket__PKc_i(byte[] serializedMatrixData, int size, out IntPtr packet_out);
public static extern MpReturnCode mp_Packet__ValidateAsMatrix(IntPtr packet, out IntPtr status);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeMatrixPacket_At__PKc_i_Rt(byte[] serializedMatrixData, int size, IntPtr timestamp, out IntPtr packet_out);
public static extern MpReturnCode mp_Packet__GetMpMatrix(IntPtr packet, out NativeMatrix matrix);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ValidateAsMatrix(IntPtr packet, out IntPtr status);
public static extern MpReturnCode mp__MakeColMajorMatrixPacket__Pf_i_i(float[] data, int rows, int cols, out IntPtr packet_out);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetMatrix(IntPtr packet, out SerializedProto serializedProto);
public static extern MpReturnCode mp__MakeColMajorMatrixPacket_At__Pf_i_i_ll(float[] data, int rows, int cols, long timestampMicrosec, out IntPtr packet_out);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern void mp_api_Matrix__delete(NativeMatrix matrix);
#endregion
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,44 @@ public void CreateIntAt_ShouldReturnNewIntPacket(int value)
}
#endregion

#region Matrix
[Test]
public void CreateColMajorMatrix_ShouldReturnNewMatrixPacket()
{
var value = new Matrix(new float[] { 1, 2, 3, 4, 5, 6 }, 2, 3);
using var packet = Packet.CreateColMajorMatrix(value);

Assert.DoesNotThrow(packet.Validate);

var result = packet.Get();
Assert.AreEqual(value.data, result.data);
Assert.AreEqual(value.rows, result.rows);
Assert.AreEqual(value.cols, result.cols);
Assert.AreEqual(value.layout, result.layout);

using var unsetTimestamp = Timestamp.Unset();
Assert.AreEqual(unsetTimestamp.Microseconds(), packet.TimestampMicroseconds());
}

[Test]
public void CreateColMajorMatrixAt_ShouldReturnNewMatrixPacket()
{
var value = new Matrix(new float[] { 1, 2, 3, 4, 5, 6 }, 2, 3);
var timestamp = 1;
using var packet = Packet.CreateColMajorMatrixAt(value, timestamp);

Assert.DoesNotThrow(packet.Validate);

var result = packet.Get();
Assert.AreEqual(value.data, result.data);
Assert.AreEqual(value.rows, result.rows);
Assert.AreEqual(value.cols, result.cols);
Assert.AreEqual(value.layout, result.layout);

Assert.AreEqual(timestamp, packet.TimestampMicroseconds());
}
#endregion

#region Proto
[Test]
public void CreateProto_ShouldReturnNewProtoPacket()
Expand Down
2 changes: 1 addition & 1 deletion mediapipe_api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ cc_library(
"//mediapipe_api/framework/formats:image",
"//mediapipe_api/framework/formats:image_frame",
"//mediapipe_api/framework/formats:landmark",
"//mediapipe_api/framework/formats:matrix_data",
"//mediapipe_api/framework/formats:matrix",
"//mediapipe_api/framework/formats:rect",
"//mediapipe_api/framework/port:logging",
"//mediapipe_api/tasks/c/components/containers:classification_result",
Expand Down
6 changes: 3 additions & 3 deletions mediapipe_api/framework/formats/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ cc_library(
)

cc_library(
name = "matrix_data",
srcs = ["matrix_data.cc"],
hdrs = ["matrix_data.h"],
name = "matrix",
srcs = ["matrix.cc"],
hdrs = ["matrix.h"],
deps = [
"//mediapipe_api:common",
"//mediapipe_api/external/absl:status",
Expand Down
60 changes: 60 additions & 0 deletions mediapipe_api/framework/formats/matrix.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) 2023 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#include "mediapipe_api/framework/formats/matrix.h"


MpReturnCode mp__MakeColMajorMatrixPacket__Pf_i_i(float* pcm_data, int rows, int cols, mediapipe::Packet** packet_out) {
TRY
Eigen::Map<Eigen::MatrixXf> m(pcm_data, rows, cols);

*packet_out = new mediapipe::Packet{mediapipe::MakePacket<mediapipe::Matrix>(m)};
RETURN_CODE(MpReturnCode::Success);
CATCH_EXCEPTION
}

MpReturnCode mp__MakeColMajorMatrixPacket_At__Pf_i_i_ll(float* pcm_data, int rows, int cols, int64 timestamp_microsec, mediapipe::Packet** packet_out) {
TRY
Eigen::Map<Eigen::MatrixXf> m(pcm_data, rows, cols);

*packet_out = new mediapipe::Packet{mediapipe::MakePacket<mediapipe::Matrix>(m).At(mediapipe::Timestamp(timestamp_microsec))};
RETURN_CODE(MpReturnCode::Success);
CATCH_EXCEPTION
}

MpReturnCode mp_Packet__GetMpMatrix(mediapipe::Packet* packet, mp_api::Matrix* value_out) {
TRY
auto matrix = packet->Get<mediapipe::Matrix>();
auto rows = matrix.rows();
auto cols = matrix.cols();
auto data = matrix.data();
auto len = rows * cols;

value_out->rows = rows;
value_out->cols = cols;
if (matrix.IsRowMajor) {
value_out->layout = mp_api::rowMajor;
} else {
value_out->layout = mp_api::colMajor;
}
value_out->data = new float[len];
memcpy(value_out->data, data, len * sizeof(float));

RETURN_CODE(MpReturnCode::Success);
CATCH_EXCEPTION
}

MpReturnCode mp_Packet__ValidateAsMatrix(mediapipe::Packet* packet, absl::Status** status_out) {
TRY
*status_out = new absl::Status{packet->ValidateAsType<mediapipe::Matrix>()};
RETURN_CODE(MpReturnCode::Success);
CATCH_EXCEPTION
}


void mp_api_Matrix__delete(mp_api::Matrix matrix) {
delete[] matrix.data;
}
Loading

0 comments on commit af74bb2

Please sign in to comment.