Skip to content

Commit

Permalink
Write out the pseudolabel weights and a flag that indicates whether a…
Browse files Browse the repository at this point in the history
… sample has a ground truth label (0) or a pseudolabel (1).

PiperOrigin-RevId: 652155894
  • Loading branch information
raj-sinha committed Jul 14, 2024
1 parent 0fcfeda commit 84cbfe7
Show file tree
Hide file tree
Showing 11 changed files with 315 additions and 41 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

## [0.3.1] - 2024-07-13

* Now writes out the pseudolabel weights and a flag that indicates whether a sample has a ground truth label (0) or a pseudolabel (1).

## [0.3.0] - 2024-07-10

* Add the ability to use CSV files on GCS as data input/output/test sources.
Expand All @@ -45,7 +49,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

* Initial release

[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.0...HEAD
[Unreleased]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.1...HEAD
[0.3.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.3.0...v0.3.1
[0.3.0]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.2...v0.3.0
[0.2.2]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.1...v0.2.2
[0.2.1]: https://github.com/google-research/spade_anomaly_detection/compare/v0.2.0...v0.2.1
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies = [
"pyarrow==14.0.1",
"retry==0.9.2",
"scikit-learn==1.4.2",
"tensorflow",
"tensorflow==2.12.1",
"tensorflow-datasets==4.9.6",
"parameterized==0.8.1",
"pytest==7.1.2",
Expand Down
2 changes: 1 addition & 1 deletion spade_anomaly_detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@

# A new PyPI release will be pushed every time `__version__` is increased.
# When changing this, also update the CHANGELOG.md.
__version__ = '0.3.0'
__version__ = '0.3.1'
43 changes: 39 additions & 4 deletions spade_anomaly_detection/csv_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from google.cloud import storage
import numpy as np
import pandas as pd
from spade_anomaly_detection import data_loader
from spade_anomaly_detection import parameters
import tensorflow as tf

Expand Down Expand Up @@ -489,13 +490,17 @@ def upload_dataframe_to_gcs(
batch: int,
features: np.ndarray,
labels: np.ndarray,
weights: Optional[np.ndarray] = None,
pseudolabel_flags: Optional[np.ndarray] = None,
) -> None:
"""Uploads the dataframe to BigQuery, create or replace table.
Args:
batch: The batch number of the pseudo-labeled data.
features: Numpy array of features.
labels: Numpy array of labels.
weights: Optional numpy array of weights.
pseudolabel_flags: Optional numpy array of pseudolabel flags.
Returns:
None.
Expand All @@ -515,15 +520,37 @@ def upload_dataframe_to_gcs(
'Data output GCS URI is not set in the runner parameters. Please set '
'the `data_output_gcs_uri` field in the runner parameters.'
)
combined_data = np.concatenate(
[features, labels.reshape(len(features), 1)], axis=1
)
combined_data = features

column_names = list(
self._last_read_metadata.column_names_info.column_names_dict.keys()
)

# If the weights are provided, add them to the column names and to the
# combined data.
if weights is not None:
column_names.append(data_loader.WEIGHT_COLUMN_NAME)
combined_data = np.concatenate(
[combined_data, weights.reshape(len(features), 1).astype(np.float64)],
axis=1,
)

# If the pseudolabel flags are provided, add them to the column names and
# to the combined data.
if pseudolabel_flags is not None:
column_names.append(data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME)
combined_data = np.concatenate(
[
combined_data,
pseudolabel_flags.reshape(len(features), 1).astype(np.int64),
],
axis=1,
)

# Make sure the label column is the last column.
# TODO(b/347332980): Add support for the pseudolabel flag.
combined_data = np.concatenate(
[combined_data, labels.reshape(len(features), 1)], axis=1
)
column_names.remove(self.runner_parameters.label_col_name)
column_names.append(self.runner_parameters.label_col_name)

Expand All @@ -536,6 +563,14 @@ def upload_dataframe_to_gcs(
complete_dataframe[self.runner_parameters.label_col_name].astype('bool')
)

# Adjust pseudolabel flag column type.
if pseudolabel_flags is not None:
complete_dataframe[data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME] = (
complete_dataframe[data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME].astype(
np.int64
)
)

output_path = os.path.join(
self.runner_parameters.data_output_gcs_uri,
f'pseudo_labeled_batch_{batch}.csv',
Expand Down
57 changes: 51 additions & 6 deletions spade_anomaly_detection/csv_data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,30 +385,75 @@ def test_upload_dataframe_to_gcs(self):
all_features = self.data_df[["x1", "x2"]].to_numpy()
all_labels = self.data_df["y"].to_numpy()
# Create 2 batches of features and labels.
# TODO(b/347332980): Update test when pseudolabel flag is added.
features1 = all_features[0:2]
labels1 = all_labels[0:2]
# Add weights and flags to the first batch. These are pseudolabeled samples.
weights1 = (
np.repeat([0.1], len(features1))
.reshape(len(features1), 1)
.astype(np.float64)
)
flags1 = (
np.repeat([1], len(features1))
.reshape(len(features1), 1)
.astype(np.int64)
)
# Add weights and flags to the first batch. These are ground truth samples.
features2 = all_features[2:]
labels2 = all_labels[2:]
# Upload batch 1.
weights2 = (
np.repeat([1.0], len(features2))
.reshape(len(features2), 1)
.astype(np.float64)
)
flags2 = (
np.repeat([0], len(features2))
.reshape(len(features2), 1)
.astype(np.int64)
) # Upload batch 1.
data_loader.upload_dataframe_to_gcs(
batch=1,
features=features1,
labels=labels1,
weights=weights1,
pseudolabel_flags=flags1,
)
# Upload batch 2.
data_loader.upload_dataframe_to_gcs(
batch=2,
features=features2,
labels=labels2,
weights=weights2,
pseudolabel_flags=flags2,
)
# Sorting means batch 1 file will be first.
files_list = sorted(tf.io.gfile.listdir(output_dir))
self.assertLen(files_list, 2)
expected_dfs = [
self.data_df.iloc[0:2].reset_index(drop=True),
self.data_df.iloc[2:].reset_index(drop=True),
]
col_names = ["x1", "x2", "alpha", "is_pseudolabel", "y"]
expected_df1 = pd.concat(
[
self.data_df.iloc[0:2, 0:-1].reset_index(drop=True),
pd.DataFrame(weights1, columns=["alpha"]),
pd.DataFrame(flags1, columns=["is_pseudolabel"]),
self.data_df.iloc[0:2, -1].reset_index(drop=True),
],
names=col_names,
ignore_index=True,
axis=1,
)
expected_df1.columns = col_names
expected_df2 = pd.concat(
[
self.data_df.iloc[2:, 0:-1].reset_index(drop=True),
pd.DataFrame(weights2, columns=["alpha"]),
pd.DataFrame(flags2, columns=["is_pseudolabel"]),
self.data_df.iloc[2:, -1].reset_index(drop=True),
],
ignore_index=True,
axis=1,
)
expected_df2.columns = col_names
expected_dfs = [expected_df1, expected_df2]
for i, file_name in enumerate(files_list):
with self.subTest(msg=f"file_{i}"):
file_path = os.path.join(output_dir, file_name)
Expand Down
42 changes: 39 additions & 3 deletions spade_anomaly_detection/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@

_DATA_ROOT: Final[str] = 'spade_anomaly_detection/example_data/'

WEIGHT_COLUMN_NAME: Final[str] = 'alpha'
PSEUDOLABEL_FLAG_COLUMN_NAME: Final[str] = 'is_pseudolabel'


def load_dataframe(
dataset_name: str,
Expand Down Expand Up @@ -691,12 +694,19 @@ def upload_dataframe_as_bigquery_table(
self,
features: np.ndarray,
labels: np.ndarray,
weights: Optional[np.ndarray] = None,
pseudolabel_flags: Optional[np.ndarray] = None,
) -> None:
"""Uploads the dataframe to BigQuery, create or replace table.
Args:
features: Numpy array of features.
labels: Numpy array of labels.
weights: Optional numpy array of weights.
pseudolabel_flags: Optional numpy array of pseudolabel flags.
Raises:
ValueError: If the metadata has not been read yet.
"""
if not self.input_feature_metadata:
raise ValueError(
Expand All @@ -705,11 +715,31 @@ def upload_dataframe_as_bigquery_table(
'load_tf_dataset_from_bigquery() before this method '
'is called.'
)
combined_data = np.concatenate(
[features, labels.reshape(len(features), 1)], axis=1
)
combined_data = features

# Get the list of feature and label column names.
column_names = list(self.input_feature_metadata.names)

# If the weights are provided, add them to the column names and to the
# combined data.
if weights is not None:
column_names.append(WEIGHT_COLUMN_NAME)
combined_data = np.concatenate(
[combined_data, weights.reshape(len(features), 1)], axis=1
)

# If the pseudolabel flags are provided, add them to the column names and
# to the combined data.
if pseudolabel_flags is not None:
column_names.append(PSEUDOLABEL_FLAG_COLUMN_NAME)
combined_data = np.concatenate(
[combined_data, pseudolabel_flags.reshape(len(features), 1)], axis=1
)

# Make sure the label column is the last column.
combined_data = np.concatenate(
[combined_data, labels.reshape(len(features), 1)], axis=1
)
column_names.remove(self.runner_parameters.label_col_name)
column_names.append(self.runner_parameters.label_col_name)

Expand All @@ -722,6 +752,12 @@ def upload_dataframe_as_bigquery_table(
complete_dataframe[self.runner_parameters.label_col_name].astype('bool')
)

# Adjust pseudolabel flag column type.
if pseudolabel_flags is not None:
complete_dataframe[PSEUDOLABEL_FLAG_COLUMN_NAME] = complete_dataframe[
PSEUDOLABEL_FLAG_COLUMN_NAME
].astype(np.int64)

with bigquery.Client(
project=self.table_parts.project_id
) as big_query_client:
Expand Down
111 changes: 111 additions & 0 deletions spade_anomaly_detection/data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,117 @@ def test_bigquery_table_upload_throw_error_metadata(self):
features=features, labels=labels
)

@mock.patch.object(bigquery, 'LoadJobConfig', autospec=True)
def test_upload_dataframe_with_wts_flags_as_bigquery_table_no_error(
self, mock_bqclient_loadjobconfig
):
self.runner_parameters.output_bigquery_table_path = (
'project.dataset.pseudo_labeled_data'
)
data_loader_object = data_loader.DataLoader(self.runner_parameters)
feature_column_names = [
'x1',
'x2',
data_loader.WEIGHT_COLUMN_NAME,
data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME,
self.runner_parameters.label_col_name,
]

features = np.random.rand(10, 2).astype(np.float32)
labels = np.repeat(0, 10).reshape(10, 1).astype(np.int8)
# Two possible values for weight (alpha), repeated 10/2 = 5 times each.
weights = np.repeat([0.1, 1.0], 5).reshape(10, 1).astype(np.float32)
# The corresponding peseudolabel flags are False, True, repeated 5 times.
flags = np.repeat([1, 0], 5).reshape(10, 1).astype(np.int8)

tf_dataset_instance_mock = mock.create_autospec(
tf.data.Dataset, instance=True
)

feature1_metadata = feature_metadata.FeatureMetadata('x1', 0, 'FLOAT64')
feature2_metadata = feature_metadata.FeatureMetadata('x2', 0, 'FLOAT64')
label_metadata = feature_metadata.FeatureMetadata(
self.runner_parameters.label_col_name, 1, 'INT64'
)
metadata_container = feature_metadata.FeatureMetadataContainer(
[feature1_metadata, feature2_metadata, label_metadata]
)

self.mock_bq_dataset.return_value = (
tf_dataset_instance_mock,
metadata_container,
)

# Perform this call so that FeatureMetadata is set.
data_loader_object.load_tf_dataset_from_bigquery(
input_path=self.runner_parameters.input_bigquery_table_path,
label_col_name=self.runner_parameters.label_col_name,
batch_size=self.batch_size,
)

data_loader_object.upload_dataframe_as_bigquery_table(
features=features,
labels=labels,
weights=weights,
pseudolabel_flags=flags,
)
job_config_object = mock_bqclient_loadjobconfig.return_value

load_table_mock_kwargs = (
self.mock_bq_client.return_value.__enter__.return_value.load_table_from_dataframe.call_args.kwargs
)

with self.subTest(name='LabelColumnCorrect'):
self.assertListEqual(
list(
load_table_mock_kwargs['dataframe'][
self.runner_parameters.label_col_name
]
),
list(labels),
)

with self.subTest(name='LabelColumnDataTypeBool'):
self.assertEqual(
load_table_mock_kwargs['dataframe'][
self.runner_parameters.label_col_name
].dtype,
bool,
)

with self.subTest(name='WeightsColumnCorrect'):
self.assertListEqual(
list(
load_table_mock_kwargs['dataframe'][
data_loader.WEIGHT_COLUMN_NAME
]
),
list(weights),
)

with self.subTest(name='PseudolabelFlagsColumnCorrect'):
self.assertListEqual(
list(
load_table_mock_kwargs['dataframe'][
data_loader.PSEUDOLABEL_FLAG_COLUMN_NAME
]
),
list(flags),
)

with self.subTest(name='EqualColumnNames'):
self.assertListEqual(
feature_column_names,
list(load_table_mock_kwargs['dataframe'].columns),
)
with self.subTest(name='EqualDestinationPath'):
self.assertEqual(
self.runner_parameters.output_bigquery_table_path,
load_table_mock_kwargs['destination'],
)
with self.subTest(name='EqualJobConfig'):
self.assertEqual(job_config_object, load_table_mock_kwargs['job_config'])

def test_get_label_thresholds_no_error(self):
mock_query_return_dictionary = {
self.runner_parameters.label_col_name: [
Expand Down
Loading

0 comments on commit 84cbfe7

Please sign in to comment.