Skip to content

Commit

Permalink
feat: marshal ClassificationResult (#1090)
Browse files Browse the repository at this point in the history
  • Loading branch information
homuler committed Jan 2, 2024
1 parent ff0bd1c commit 36875b2
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) 2023 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System;
using System.Runtime.InteropServices;

namespace Mediapipe
{
[StructLayout(LayoutKind.Sequential)]
internal readonly struct NativeClassifications
{
private readonly IntPtr _categories;
public readonly uint categoriesCount;
public readonly int headIndex;
private readonly IntPtr _headName;

public ReadOnlySpan<NativeCategory> categories
{
get
{
unsafe
{
return new ReadOnlySpan<NativeCategory>((NativeCategory*)_categories, (int)categoriesCount);
}
}
}

public string headName => Marshal.PtrToStringAnsi(_headName);
}

[StructLayout(LayoutKind.Sequential)]
internal readonly struct NativeClassificationResult
{
private readonly IntPtr _classifications;
public readonly uint classificationsCount;
public readonly long timestampMs;
public readonly bool hasTimestampMs;

public ReadOnlySpan<NativeClassifications> classifications
{
get
{
unsafe
{
return new ReadOnlySpan<NativeClassifications>((NativeClassifications*)_classifications, (int)classificationsCount);
}
}
}

public void Dispose()
{
UnsafeNativeMethods.mp_tasks_c_components_containers_CppCloseClassificationResult(this);
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ namespace Mediapipe
{
internal static partial class UnsafeNativeMethods
{
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetClassificationsVector(IntPtr packet, out NativeClassificationResult value);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern void mp_tasks_c_components_containers_CppCloseClassificationResult(NativeClassificationResult data);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetDetectionResult(IntPtr packet, out NativeDetectionResult value);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Mediapipe.Tasks.Components.Containers
/// The array of predicted categories, usually sorted by descending scores,
/// e.g. from high to low probability.
/// </summary>
public readonly IReadOnlyList<Category> categories;
public readonly List<Category> categories;
/// <summary>
/// The index of the classifier head (i.e. output tensor) these categories
/// refer to. This is useful for multi-head models.
Expand All @@ -31,7 +31,7 @@ namespace Mediapipe.Tasks.Components.Containers
/// </summary>
public readonly string headName;

internal Classifications(IReadOnlyList<Category> categories, int headIndex, string headName)
internal Classifications(List<Category> categories, int headIndex, string headName)
{
this.categories = categories;
this.headIndex = headIndex;
Expand All @@ -58,6 +58,17 @@ public static Classifications CreateFrom(ClassificationList proto, int headIndex
return new Classifications(categories, headIndex, headName);
}

internal static void Copy(NativeClassifications source, ref Classifications destination)
{
var categories = destination.categories ?? new List<Category>((int)source.categoriesCount);
categories.Clear();
foreach (var nativeCategory in source.categories)
{
categories.Add(new Category(nativeCategory));
}
destination = new Classifications(categories, source.headIndex, source.headName);
}

public override string ToString()
=> $"{{ \"categories\": {Util.Format(categories)}, \"headIndex\": {headIndex}, \"headName\": {Util.Format(headName)} }}";
}
Expand All @@ -70,7 +81,7 @@ public override string ToString()
/// <summary>
/// The classification results for each head of the model.
/// </summary>
public readonly IReadOnlyList<Classifications> classifications;
public readonly List<Classifications> classifications;

/// <summary>
/// The optional timestamp (in milliseconds) of the start of the chunk of data
Expand All @@ -83,7 +94,7 @@ public override string ToString()
/// </summary>
public readonly long? timestampMs;

internal ClassificationResult(IReadOnlyList<Classifications> classifications, long? timestampMs)
internal ClassificationResult(List<Classifications> classifications, long? timestampMs)
{
this.classifications = classifications;
this.timestampMs = timestampMs;
Expand All @@ -101,6 +112,22 @@ public static ClassificationResult CreateFrom(Proto.ClassificationResult proto)
#pragma warning restore IDE0004 // for Unity 2020.3.x
}

internal static void Copy(NativeClassificationResult source, ref ClassificationResult destination)
{
var classificationsList = destination.classifications ?? new List<Classifications>((int)source.classificationsCount);
classificationsList.ResizeTo((int)source.classificationsCount);

var i = 0;
foreach (var nativeClassifications in source.classifications)
{
var classifications = classificationsList[i];
Classifications.Copy(nativeClassifications, ref classifications);
classificationsList[i++] = classifications;
}

destination = new ClassificationResult(classificationsList, source.hasTimestampMs ? (long?)source.timestampMs : null);
}

public override string ToString() => $"{{ \"classifications\": {Util.Format(classifications)}, \"timestampMs\": {Util.Format(timestampMs)} }}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ namespace Mediapipe.Tasks.Components.Containers
{
public static class PacketExtension
{
public static void GetClassificationsVector(this Packet packet, List<Classifications> outs)
{
UnsafeNativeMethods.mp_Packet__GetClassificationsVector(packet.mpPtr, out var classificationResult).Assert();
var tmp = new ClassificationResult(outs, null);
ClassificationResult.Copy(classificationResult, ref tmp);
classificationResult.Dispose();
}

public static void GetDetectionResult(this Packet packet, ref DetectionResult value)
{
UnsafeNativeMethods.mp_Packet__GetDetectionResult(packet.mpPtr, out var detectionResult).Assert();
Expand Down
1 change: 1 addition & 0 deletions mediapipe_api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ cc_library(
"//mediapipe_api/framework/formats:matrix_data",
"//mediapipe_api/framework/formats:rect",
"//mediapipe_api/framework/port:logging",
"//mediapipe_api/tasks/c/components/containers:classification_result",
"//mediapipe_api/tasks/c/components/containers:detection_result",
"//mediapipe_api/tasks/c/components/containers:landmark",
"//mediapipe_api/tasks/cc/vision/face_geometry/proto:face_geometry",
Expand Down
15 changes: 15 additions & 0 deletions mediapipe_api/tasks/c/components/containers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@ package(
default_visibility = ["//visibility:public"],
)

cc_library(
name = "classification_result",
srcs = ["classification_result.cc"],
hdrs = ["classification_result.h"],
deps = [
"//mediapipe_api:common",
"@com_google_mediapipe//mediapipe/framework:packet",
"@com_google_mediapipe//mediapipe/framework/formats:classification_cc_proto",
"@com_google_mediapipe//mediapipe/tasks/c/components/containers:classification_result",
"@com_google_mediapipe//mediapipe/tasks/c/components/containers:classification_result_converter",
"@com_google_mediapipe//mediapipe/tasks/cc/components/containers:classification_result",
],
alwayslink = True,
)

cc_library(
name = "detection_result",
srcs = ["detection_result.cc"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "mediapipe_api/tasks/c/components/containers/classification_result.h"

MpReturnCode mp_Packet__GetClassificationsVector(mediapipe::Packet* packet, ClassificationResult* value_out) {
TRY_ALL
// get std::vector<ClassificationList> and convert it to ClassificationResult
auto proto_vec = packet->Get<std::vector<mediapipe::ClassificationList>>();
auto vec_size = proto_vec.size();

auto classifications_vec = std::vector<mediapipe::tasks::components::containers::Classifications>(vec_size);
for (auto i = 0; i < vec_size; ++i) {
auto classifications = mediapipe::tasks::components::containers::ConvertToClassifications(proto_vec[i]);
classifications_vec.push_back(classifications);
}

mediapipe::tasks::components::containers::ClassificationResult result;
result.classifications = std::move(classifications_vec);

mediapipe::tasks::c::components::containers::CppConvertToClassificationResult(result, value_out);
RETURN_CODE(MpReturnCode::Success);
CATCH_ALL
}

void mp_tasks_c_components_containers_CppCloseClassificationResult(ClassificationResult data) {
mediapipe::tasks::c::components::containers::CppCloseClassificationResult(&data);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) 2023 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#ifndef MEDIAPIPE_API_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#define MEDIAPIPE_API_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_

#include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/c/components/containers/classification_result.h"
#include "mediapipe/tasks/c/components/containers/classification_result_converter.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe_api/common.h"

extern "C" {

MP_CAPI(MpReturnCode) mp_Packet__GetClassificationsVector(mediapipe::Packet* packet, ClassificationResult* value_out);
MP_CAPI(void) mp_tasks_c_components_containers_CppCloseClassificationResult(ClassificationResult data);

} // extern "C"

#endif // MEDIAPIPE_API_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_

0 comments on commit 36875b2

Please sign in to comment.