Skip to content

Commit

Permalink
Add data filtering task with kNN (#240)
Browse files Browse the repository at this point in the history
* Add interface for data cleaning

* Add kNN and tests

* Add cleaning task
  • Loading branch information
anthoak13 committed Jun 21, 2024
1 parent c926e25 commit 2aa70c6
Show file tree
Hide file tree
Showing 12 changed files with 303 additions and 3 deletions.
65 changes: 65 additions & 0 deletions AtReconstruction/AtDataCleaningTask.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include "AtDataCleaningTask.h"
// IWYU pragma: no_include <ext/alloc_traits.h>
#include "AtEvent.h"
#include "AtHit.h"

#include <FairLogger.h>
#include <FairRootManager.h>
#include <FairRun.h>
#include <FairRuntimeDb.h>
#include <FairTask.h>

#include <Math/Point3D.h>
#include <Math/Point3Dfwd.h>
#include <TClonesArray.h>
#include <TObject.h>

#include <utility>
#include <vector>

class AtDigiPar;
using XYZPoint = ROOT::Math::XYZPoint;

ClassImp(AtDataCleaningTask);

AtDataCleaningTask::AtDataCleaningTask(DataCleaner &&cleaner)
: fCleaner(std::move(cleaner)), fOutputEventArray(TClonesArray("AtEvent", 1)), fInputEventArray(nullptr)
{
}

InitStatus AtDataCleaningTask::Init()
{
if (FairRootManager::Instance() == nullptr) {
LOG(fatal) << "Cannot find RootManager!";
return kFATAL;
}

fInputEventArray = dynamic_cast<TClonesArray *>(FairRootManager::Instance()->GetObject(fInputBranchName.data()));
if (fInputEventArray == nullptr) {
LOG(fatal) << "SpaceChargeTask: Cannot find AtEvent array!";
return kFATAL;
}

FairRootManager::Instance()->Register(fOuputBranchName.data(), "AtTpc", &fOutputEventArray, fIsPersistent);

return kSUCCESS;
}

void AtDataCleaningTask::SetParContainers() {}

void AtDataCleaningTask::Exec(Option_t *opt)
{
fOutputEventArray.Clear("C");

if (fInputEventArray->GetEntriesFast() == 0)
return;

auto inputEvent = dynamic_cast<AtEvent *>(fInputEventArray->At(0));
auto outputEvent = dynamic_cast<AtEvent *>(fOutputEventArray.ConstructedAt(0));
*outputEvent = *inputEvent;
outputEvent->ClearHits();

auto hits = fCleaner->CleanData(inputEvent->GetHits());
for (auto &hit : hits)
outputEvent->AddHit(std::move(hit));
}
52 changes: 52 additions & 0 deletions AtReconstruction/AtDataCleaningTask.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/* Task for correcting SpaceCharge effects by
* applying some SpaceChargeModel and
* modifying the position values of hits in AtEvent
*
*/
#ifndef ATDATACLEANINGTASK_H
#define ATDATACLEANINGTASK_H

#include "AtDataCleaner.h" // IWYU pragma: keep

#include <FairTask.h>

#include <Rtypes.h>
#include <TClonesArray.h>

#include <memory>
#include <string>

class TBuffer;
class TClass;
class TMemberInspector;

using DataCleaner = std::unique_ptr<AtTools::DataCleaning::AtDataCleaner>;

class AtDataCleaningTask : public FairTask {

private:
std::string fInputBranchName = "AtEventH";
std::string fOuputBranchName = "AtEventCleaned";
Bool_t fIsPersistent = true;

TClonesArray *fInputEventArray;
TClonesArray fOutputEventArray;
DataCleaner fCleaner;

public:
AtDataCleaningTask(DataCleaner &&cleaner);
virtual ~AtDataCleaningTask() = default;

void SetInputBranch(std::string branchName) { fInputBranchName = branchName; }
void SetOutputBranch(std::string branchName) { fOuputBranchName = branchName; }
void SetPersistence(Bool_t value) { fIsPersistent = value; }

virtual InitStatus Init() override;
virtual void SetParContainers() override;
virtual void Exec(Option_t *opt) override;
// virtual void FinishEvent() override;

ClassDefOverride(AtDataCleaningTask, 1);
};

#endif //_ATSPACECHARGETASK_H_
1 change: 1 addition & 0 deletions AtReconstruction/AtReconstructionLinkDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
#pragma link C++ class AtSampleConsensusTask + ;
#pragma link C++ class AtDataReductionTask + ;
#pragma link C++ class AtSpaceChargeCorrectionTask + ;
#pragma link C++ class AtDataCleaningTask + ;
#pragma link C++ class AtFilterTask + ;
#pragma link C++ class AtHDF5WriteTask + ;
#pragma link C++ class AtHDF5ReadTask + ;
Expand Down
1 change: 1 addition & 0 deletions AtReconstruction/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ set(SRCS
AtAuxFilterTask.cxx
AtDataReductionTask.cxx
AtSpaceChargeCorrectionTask.cxx
AtDataCleaningTask.cxx
AtHDF5ReadTask.cxx
AtHDF5WriteTask.cxx
AtMacroTask.cxx
Expand Down
4 changes: 4 additions & 0 deletions AtTools/AtToolsLinkDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#pragma link C++ namespace ElectronicResponse;
#pragma link C++ namespace tk;
#pragma link C++ namespace AtTools::Kinematics;
#pragma link C++ namespace AtTools::DataCleaning;

#pragma link C++ class AtTools::AtELossManager + ;
#pragma link C++ class AtTools::AtParsers + ;
Expand Down Expand Up @@ -68,4 +69,7 @@
#pragma link C++ function AtTools::Kinematics::AtoE;
#pragma link C++ function AtTools::Kinematics::EtoA;

#pragma link C++ class AtTools::DataCleaning::AtkNN + ;
#pragma link C++ class AtTools::DataCleaning::AtDataCleaner + ;

#endif
13 changes: 13 additions & 0 deletions AtTools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ set(SRCS
ElectronicResponse/AtElectronicResponse.cxx
ElectronicResponse/AtROOTresponse.cxx
AtHitSampling/AtWeightedGaussianTrunc.cxx

DataCleaning/AtDataCleaner.cxx
DataCleaning/AtkNN.cxx
)

Set(DEPENDENCIES
Expand All @@ -59,8 +62,18 @@ endif()
Set(INCLUDE_DIR
${CMAKE_SOURCE_DIR}/AtTools/AtHitSampling
${CMAKE_SOURCE_DIR}/AtTools/ElectronicResponse
${CMAKE_SOURCE_DIR}/AtTools/DataCleaning
)

set(TEST_SRCS
DataCleaning/AtkNNTest.cxx
)

attpcroot_generate_tests(${LIBRARY_NAME}Tests
SRCS ${TEST_SRCS}
DEPS ${LIBRARY_NAME}
)

generate_target_and_root_library(${LIBRARY_NAME}
LINKDEF ${LINKDEF}
SRCS ${SRCS}
Expand Down
1 change: 1 addition & 0 deletions AtTools/DataCleaning/AtDataCleaner.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include "AtDataCleaner.h"
35 changes: 35 additions & 0 deletions AtTools/DataCleaning/AtDataCleaner.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef ATKNN_H
#define ATKNN_H

#include "AtHit.h"

#include <memory>
#include <vector>
class AtHit;

namespace AtTools {

namespace DataCleaning {

using HitCloud = std::vector<std::unique_ptr<AtHit>>;

/**
* @brief Interface for data cleaning algorithms.
* They take in a hit cloud and output a hit cloud.
*/
class AtDataCleaner {
public:
virtual ~AtDataCleaner() = default;
/**
* @brief Clean the data.
* @param hits The input hit cloud.
* @return The cleaned hit cloud.
*/
virtual HitCloud CleanData(const HitCloud &hits) = 0;
};

} // namespace DataCleaning

} // namespace AtTools

#endif // ATKNN_H
63 changes: 63 additions & 0 deletions AtTools/DataCleaning/AtkNN.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include "AtkNN.h"

namespace AtTools {
namespace DataCleaning {
HitCloud AtkNN::CleanData(const HitCloud &hits)
{
HitCloud ret;
for (const auto &hit : hits) {
if (kNN(hits, *hit))
ret.push_back(std::make_unique<AtHit>(*hit));
}
return ret;
}

/**
* @brief kNN algorithm to clean data.
*
* Returns true of k-nearest neighbors to hitRef are within a threshold distance.
*/
bool AtkNN::kNN(const std::vector<std::unique_ptr<AtHit>> &hits, AtHit &hitRef)
{
int k = fkNN;
std::vector<Double_t> distances;
distances.reserve(hits.size());

std::for_each(hits.begin(), hits.end(), [&distances, &hitRef](const std::unique_ptr<AtHit> &hit) {
auto dist = (hitRef.GetPosition() - hit->GetPosition()).Mag2();
if (dist > 0.01) // Ignore self
distances.push_back(std::sqrt(dist));
});

std::sort(distances.begin(), distances.end(), [](Double_t a, Double_t b) { return a < b; });

/*
for (auto i = 0; i < distances.size(); ++i)
std::cout << distances.at(i) << " ";
std::cout << "\n";
*/

Double_t mean = 0.0;
Double_t stdDev = 0.0;

if (k > hits.size())
k = hits.size();

// Compute mean distance of kNN
for (auto i = 0; i < k; ++i)
mean += distances.at(i);

mean /= k;

// Compute threshold
Double_t T = mean;
/*
std::cout << "For hit at " << hitRef.GetPosition() << " T: " << T << " fkNNDist: " << fkNNDist << " k: " << k
<< "\n";
*/

return T < fkNNDist;
}

} // namespace DataCleaning
} // namespace AtTools
25 changes: 25 additions & 0 deletions AtTools/DataCleaning/AtkNN.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "AtDataCleaner.h"

namespace AtTools {

namespace DataCleaning {

/**
* kNN algorithm as implemented in AtPRA class.
*
* Will reject any point whose average distance to its k nearest neighbors is greater than a threshold.
*/
class AtkNN : public AtDataCleaner {
protected:
int fkNN; //<! Number of nearest neighbors to consider
double fkNNDist; //<! Distance threshold for outlier rejection in kNN
public:
AtkNN(int kNN, double kNNDist) : fkNN(kNN), fkNNDist(kNNDist) {}
HitCloud CleanData(const HitCloud &hits) override;

bool kNN(const HitCloud &hits, AtHit &hitRef);
};

} // namespace DataCleaning

} // namespace AtTools
33 changes: 33 additions & 0 deletions AtTools/DataCleaning/AtkNNTest.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "AtkNN.h"

#include "AtHit.h"

#include <gtest/gtest.h>
using XYZPoint = ROOT::Math::XYZPoint;

TEST(AtkNNTest, AllOneAway)
{

// Create a sample HitCloud
AtTools::DataCleaning::HitCloud hits;
hits.push_back(std::make_unique<AtHit>(-1, XYZPoint(0, 0, 0), 0));
hits.push_back(std::make_unique<AtHit>(-1, XYZPoint(1, 0, 0), 0));
hits.push_back(std::make_unique<AtHit>(-1, XYZPoint(0, 1, 0), 0));

// Create an instance of AtkNN
AtTools::DataCleaning::AtkNN atkNN(1, 0.5);
AtTools::DataCleaning::HitCloud cleanedHits = atkNN.CleanData(hits);
EXPECT_EQ(0, cleanedHits.size());

atkNN = AtTools::DataCleaning::AtkNN(1, 1.1);
cleanedHits = atkNN.CleanData(hits);
EXPECT_EQ(3, cleanedHits.size());

atkNN = AtTools::DataCleaning::AtkNN(2, 1.1);
cleanedHits = atkNN.CleanData(hits);
EXPECT_EQ(1, cleanedHits.size());

atkNN = AtTools::DataCleaning::AtkNN(2, 1.5);
cleanedHits = atkNN.CleanData(hits);
EXPECT_EQ(3, cleanedHits.size());
}
13 changes: 10 additions & 3 deletions macro/tests/AT-TPC/run_unpack_attpc.C
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,22 @@ void run_unpack_attpc(int runNumber = 174)
psa->SetThreshold(threshold);

AtPSAtask *psaTask = new AtPSAtask(std::move(psa));
psaTask->SetInputBranch("AtRawEventFiltered");
psaTask->SetOutputBranch("AtEventFiltered");
psaTask->SetInputBranch("AtRawEvent");
psaTask->SetOutputBranch("AtEventH");
psaTask->SetPersistence(kTRUE);

auto cleaner = std::make_unique<AtTools::DataCleaning::AtkNN>(3, 20);
AtDataCleaningTask *cleaningTask = new AtDataCleaningTask(std::move(cleaner));
cleaningTask->SetInputBranch("AtEventH");
cleaningTask->SetOutputBranch("AtEventCleaned");
cleaningTask->SetPersistence(kTRUE);

// Add unpacker to the run
run->AddTask(unpackTask);
run->AddTask(reduceTask);
run->AddTask(filterTask);
// run->AddTask(filterTask);
run->AddTask(psaTask);
run->AddTask(cleaningTask);

std::cout << "***** Starting Init ******" << std::endl;
run->Init();
Expand Down

0 comments on commit 2aa70c6

Please sign in to comment.