Skip to content

Commit

Permalink
feat(sample): implement MIN_(DETECTION|TRACKING)_CONFIDENCE (#483)
Browse files Browse the repository at this point in the history
* feat: port TensorsToDetectionsCalculatorOptions

* feat: Parse CalculatorGraphConfig with Extensions

* feat(sample): FaceDetectionConfig#MIN_DETECTION_CONFIDENCE

* fix: uninstall proto.cs

* feat: port ThresholdingCalculatorOptions

* remove import_prefix

* fix compilation errors

* feat(sample): MIN_DETECTION_CONFIDENCE /MIN_TRACKING_CONFIDENCE
  • Loading branch information
homuler committed Mar 9, 2022
1 parent c7a33ba commit 13c2a38
Show file tree
Hide file tree
Showing 50 changed files with 612 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

namespace Mediapipe.Unity
{
#pragma warning disable IDE0065
using Color = UnityEngine.Color;
#pragma warning restore IDE0065

public class TextureFrame
{
public class ReleaseEvent : UnityEvent<TextureFrame> { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@

using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using UnityEngine.Events;

using Google.Protobuf;

namespace Mediapipe.Unity.FaceDetection
{
public class FaceDetectionGraph : GraphRunner
Expand All @@ -18,6 +22,14 @@ public enum ModelType
FullRangeSparse = 1,
}
public ModelType modelType = ModelType.ShortRange;

private float _minDetectionConfidence = 0.5f;
public float minDetectionConfidence
{
get => _minDetectionConfidence;
set => _minDetectionConfidence = Mathf.Clamp01(value);
}

#pragma warning disable IDE1006
public UnityEvent<List<Detection>> OnFaceDetectionsOutput = new UnityEvent<List<Detection>>();
#pragma warning restore IDE1006
Expand Down Expand Up @@ -94,7 +106,27 @@ protected override Status ConfigureCalculatorGraph(CalculatorGraphConfig config)
{
_faceDetectionsStream = new OutputStream<DetectionVectorPacket, List<Detection>>(calculatorGraph, _FaceDetectionsStreamName, true);
}
return calculatorGraph.Initialize(config);

using (var validatedGraphConfig = new ValidatedGraphConfig())
{
var status = validatedGraphConfig.Initialize(config);

if (!status.Ok()) { return status; }

var extensionRegistry = new ExtensionRegistry() { TensorsToDetectionsCalculatorOptions.Extensions.Ext };
var cannonicalizedConfig = validatedGraphConfig.Config(extensionRegistry);
var tensorsToDetectionsCalculators = cannonicalizedConfig.Node.Where((node) => node.Calculator == "TensorsToDetectionsCalculator").ToList();

foreach (var calculator in tensorsToDetectionsCalculators)
{
if (calculator.Options.HasExtension(TensorsToDetectionsCalculatorOptions.Extensions.Ext))
{
var options = calculator.Options.GetExtension(TensorsToDetectionsCalculatorOptions.Extensions.Ext);
options.MinScoreThresh = minDetectionConfidence;
}
}
return calculatorGraph.Initialize(cannonicalizedConfig);
}
}

private SidePacket BuildSidePacket(ImageSource imageSource)
Expand Down
51 changes: 50 additions & 1 deletion Assets/Mediapipe/Samples/Scenes/Face Mesh/FaceMeshGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,33 @@

using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using UnityEngine.Events;

using Google.Protobuf;

namespace Mediapipe.Unity.FaceMesh
{
public class FaceMeshGraph : GraphRunner
{
public int maxNumFaces = 1;
public bool refineLandmarks = true;

private float _minDetectionConfidence = 0.5f;
public float minDetectionConfidence
{
get => _minDetectionConfidence;
set => _minDetectionConfidence = Mathf.Clamp01(value);
}

private float _minTrackingConfidence = 0.5f;
public float minTrackingConfidence
{
get => _minTrackingConfidence;
set => _minTrackingConfidence = Mathf.Clamp01(value);
}

#pragma warning disable IDE1006 // UnityEvent is PascalCase
public UnityEvent<List<Detection>> OnFaceDetectionsOutput = new UnityEvent<List<Detection>>();
public UnityEvent<List<NormalizedLandmarkList>> OnMultiFaceLandmarksOutput = new UnityEvent<List<NormalizedLandmarkList>>();
Expand Down Expand Up @@ -163,7 +182,37 @@ protected override Status ConfigureCalculatorGraph(CalculatorGraphConfig config)
_faceRectsFromLandmarksStream = new OutputStream<NormalizedRectVectorPacket, List<NormalizedRect>>(calculatorGraph, _FaceRectsFromLandmarksStreamName, true);
_faceRectsFromDetectionsStream = new OutputStream<NormalizedRectVectorPacket, List<NormalizedRect>>(calculatorGraph, _FaceRectsFromDetectionsStreamName, true);
}
return calculatorGraph.Initialize(config);

using (var validatedGraphConfig = new ValidatedGraphConfig())
{
var status = validatedGraphConfig.Initialize(config);

if (!status.Ok()) { return status; }

var extensionRegistry = new ExtensionRegistry() { TensorsToDetectionsCalculatorOptions.Extensions.Ext, ThresholdingCalculatorOptions.Extensions.Ext };
var cannonicalizedConfig = validatedGraphConfig.Config(extensionRegistry);
var tensorsToDetectionsCalculators = cannonicalizedConfig.Node.Where((node) => node.Calculator == "TensorsToDetectionsCalculator").ToList();
var thresholdingCalculators = cannonicalizedConfig.Node.Where((node) => node.Calculator == "ThresholdingCalculator").ToList();

foreach (var calculator in tensorsToDetectionsCalculators)
{
if (calculator.Options.HasExtension(TensorsToDetectionsCalculatorOptions.Extensions.Ext))
{
var options = calculator.Options.GetExtension(TensorsToDetectionsCalculatorOptions.Extensions.Ext);
options.MinScoreThresh = minDetectionConfidence;
}
}

foreach (var calculator in thresholdingCalculators)
{
if (calculator.Options.HasExtension(ThresholdingCalculatorOptions.Extensions.Ext))
{
var options = calculator.Options.GetExtension(ThresholdingCalculatorOptions.Extensions.Ext);
options.Threshold = minTrackingConfidence;
}
}
return calculatorGraph.Initialize(cannonicalizedConfig);
}
}

protected override IList<WaitForResult> RequestDependentAssets()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@

using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using UnityEngine.Events;

using Google.Protobuf;

namespace Mediapipe.Unity.HandTracking
{
public class HandTrackingGraph : GraphRunner
Expand All @@ -21,6 +25,20 @@ public enum ModelComplexity
public ModelComplexity modelComplexity = ModelComplexity.Full;
public int maxNumHands = 2;

private float _minDetectionConfidence = 0.5f;
public float minDetectionConfidence
{
get => _minDetectionConfidence;
set => _minDetectionConfidence = Mathf.Clamp01(value);
}

private float _minTrackingConfidence = 0.5f;
public float minTrackingConfidence
{
get => _minTrackingConfidence;
set => _minTrackingConfidence = Mathf.Clamp01(value);
}

#pragma warning disable IDE1006 // UnityEvent is PascalCase
public UnityEvent<List<Detection>> OnPalmDetectectionsOutput = new UnityEvent<List<Detection>>();
public UnityEvent<List<NormalizedRect>> OnHandRectsFromPalmDetectionsOutput = new UnityEvent<List<NormalizedRect>>();
Expand Down Expand Up @@ -231,7 +249,37 @@ protected override Status ConfigureCalculatorGraph(CalculatorGraphConfig config)
_handRectsFromLandmarksStream = new OutputStream<NormalizedRectVectorPacket, List<NormalizedRect>>(calculatorGraph, _HandRectsFromLandmarksStreamName, true);
_handednessStream = new OutputStream<ClassificationListVectorPacket, List<ClassificationList>>(calculatorGraph, _HandednessStreamName, true);
}
return calculatorGraph.Initialize(config);

using (var validatedGraphConfig = new ValidatedGraphConfig())
{
var status = validatedGraphConfig.Initialize(config);

if (!status.Ok()) { return status; }

var extensionRegistry = new ExtensionRegistry() { TensorsToDetectionsCalculatorOptions.Extensions.Ext, ThresholdingCalculatorOptions.Extensions.Ext };
var cannonicalizedConfig = validatedGraphConfig.Config(extensionRegistry);
var tensorsToDetectionsCalculators = cannonicalizedConfig.Node.Where((node) => node.Calculator == "TensorsToDetectionsCalculator").ToList();
var thresholdingCalculators = cannonicalizedConfig.Node.Where((node) => node.Calculator == "ThresholdingCalculator").ToList();

foreach (var calculator in tensorsToDetectionsCalculators)
{
if (calculator.Options.HasExtension(TensorsToDetectionsCalculatorOptions.Extensions.Ext))
{
var options = calculator.Options.GetExtension(TensorsToDetectionsCalculatorOptions.Extensions.Ext);
options.MinScoreThresh = minDetectionConfidence;
}
}

foreach (var calculator in thresholdingCalculators)
{
if (calculator.Options.HasExtension(ThresholdingCalculatorOptions.Extensions.Ext))
{
var options = calculator.Options.GetExtension(ThresholdingCalculatorOptions.Extensions.Ext);
options.Threshold = minTrackingConfidence;
}
}
return calculatorGraph.Initialize(cannonicalizedConfig);
}
}

private WaitForResult WaitForHandLandmarkModel()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using UnityEngine;
using UnityEngine.Events;

using Google.Protobuf;

namespace Mediapipe.Unity.Holistic
{
public class HolisticTrackingGraph : GraphRunner
Expand All @@ -23,6 +28,20 @@ public enum ModelComplexity
public ModelComplexity modelComplexity = ModelComplexity.Lite;
public bool smoothLandmarks = true;

private float _minDetectionConfidence = 0.5f;
public float minDetectionConfidence
{
get => _minDetectionConfidence;
set => _minDetectionConfidence = Mathf.Clamp01(value);
}

private float _minTrackingConfidence = 0.5f;
public float minTrackingConfidence
{
get => _minTrackingConfidence;
set => _minTrackingConfidence = Mathf.Clamp01(value);
}

#pragma warning disable IDE1006 // UnityEvent is PascalCase
public UnityEvent<Detection> OnPoseDetectionOutput = new UnityEvent<Detection>();
public UnityEvent<NormalizedLandmarkList> OnPoseLandmarksOutput = new UnityEvent<NormalizedLandmarkList>();
Expand Down Expand Up @@ -275,7 +294,41 @@ protected override Status ConfigureCalculatorGraph(CalculatorGraphConfig config)
_poseWorldLandmarksStream = new OutputStream<LandmarkListPacket, LandmarkList>(calculatorGraph, _PoseWorldLandmarksStreamName, true);
_poseRoiStream = new OutputStream<NormalizedRectPacket, NormalizedRect>(calculatorGraph, _PoseRoiStreamName, true);
}
return calculatorGraph.Initialize(config);

using (var validatedGraphConfig = new ValidatedGraphConfig())
{
var status = validatedGraphConfig.Initialize(config);

if (!status.Ok()) { return status; }

var extensionRegistry = new ExtensionRegistry() { TensorsToDetectionsCalculatorOptions.Extensions.Ext, ThresholdingCalculatorOptions.Extensions.Ext };
var cannonicalizedConfig = validatedGraphConfig.Config(extensionRegistry);

var poseDetectionCalculatorPattern = new Regex("__posedetection[a-z]+__TensorsToDetectionsCalculator$");
var tensorsToDetectionsCalculators = cannonicalizedConfig.Node.Where((node) => poseDetectionCalculatorPattern.Match(node.Name).Success).ToList();

var poseTrackingCalculatorPattern = new Regex("tensorstoposelandmarksandsegmentation__ThresholdingCalculator$");
var thresholdingCalculators = cannonicalizedConfig.Node.Where((node) => poseTrackingCalculatorPattern.Match(node.Name).Success).ToList();

foreach (var calculator in tensorsToDetectionsCalculators)
{
if (calculator.Options.HasExtension(TensorsToDetectionsCalculatorOptions.Extensions.Ext))
{
var options = calculator.Options.GetExtension(TensorsToDetectionsCalculatorOptions.Extensions.Ext);
options.MinScoreThresh = minDetectionConfidence;
}
}

foreach (var calculator in thresholdingCalculators)
{
if (calculator.Options.HasExtension(ThresholdingCalculatorOptions.Extensions.Ext))
{
var options = calculator.Options.GetExtension(ThresholdingCalculatorOptions.Extensions.Ext);
options.Threshold = minTrackingConfidence;
}
}
return calculatorGraph.Initialize(cannonicalizedConfig);
}
}

private SidePacket BuildSidePacket(ImageSource imageSource)
Expand Down
51 changes: 50 additions & 1 deletion Assets/Mediapipe/Samples/Scenes/Objectron/ObjectronGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@

using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using UnityEngine.Events;

using Google.Protobuf;

namespace Mediapipe.Unity.Objectron
{
public class ObjectronGraph : GraphRunner
Expand All @@ -25,6 +28,20 @@ public enum Category
public Category category;
public int maxNumObjects = 5;

private float _minDetectionConfidence = 0.5f;
public float minDetectionConfidence
{
get => _minDetectionConfidence;
set => _minDetectionConfidence = Mathf.Clamp01(value);
}

private float _minTrackingConfidence = 0.99f;
public float minTrackingConfidence
{
get => _minTrackingConfidence;
set => _minTrackingConfidence = Mathf.Clamp01(value);
}

public Vector2 focalLength
{
get
Expand Down Expand Up @@ -174,7 +191,39 @@ protected override Status ConfigureCalculatorGraph(CalculatorGraphConfig config)
_multiBoxRectsStream = new OutputStream<NormalizedRectVectorPacket, List<NormalizedRect>>(calculatorGraph, _MultiBoxRectsStreamName, true);
_multiBoxLandmarksStream = new OutputStream<NormalizedLandmarkListVectorPacket, List<NormalizedLandmarkList>>(calculatorGraph, _MultiBoxLandmarksStreamName, true);
}
return calculatorGraph.Initialize(config);

using (var validatedGraphConfig = new ValidatedGraphConfig())
{
var status = validatedGraphConfig.Initialize(config);

if (!status.Ok()) { return status; }

var extensionRegistry = new ExtensionRegistry() { TensorsToDetectionsCalculatorOptions.Extensions.Ext, ThresholdingCalculatorOptions.Extensions.Ext };
var cannonicalizedConfig = validatedGraphConfig.Config(extensionRegistry);
var tensorsToDetectionsCalculators = cannonicalizedConfig.Node.Where((node) => node.Calculator == "TensorsToDetectionsCalculator").ToList();
var thresholdingCalculators = cannonicalizedConfig.Node.Where((node) => node.Calculator == "ThresholdingCalculator").ToList();

Debug.Log(tensorsToDetectionsCalculators.Count);
Debug.Log(thresholdingCalculators.Count);
foreach (var calculator in tensorsToDetectionsCalculators)
{
if (calculator.Options.HasExtension(TensorsToDetectionsCalculatorOptions.Extensions.Ext))
{
var options = calculator.Options.GetExtension(TensorsToDetectionsCalculatorOptions.Extensions.Ext);
options.MinScoreThresh = minDetectionConfidence;
}
}

foreach (var calculator in thresholdingCalculators)
{
if (calculator.Options.HasExtension(ThresholdingCalculatorOptions.Extensions.Ext))
{
var options = calculator.Options.GetExtension(ThresholdingCalculatorOptions.Extensions.Ext);
options.Threshold = minTrackingConfidence;
}
}
return calculatorGraph.Initialize(cannonicalizedConfig);
}
}

private SidePacket BuildSidePacket(ImageSource imageSource)
Expand Down
Loading

0 comments on commit 13c2a38

Please sign in to comment.