Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent large objects from being stored in the RTIF #38094

Merged
merged 16 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,14 @@ core:
type: string
example: ~
default: "Disabled"
max_templated_field_length:
description: |
The maximum length of the rendered template field. If the value to be stored in the
rendered template field exceeds this size, it's truncated.
version_added: 2.9.0
type: integer
example: ~
default: "4096"
database:
description: ~
options:
Expand Down
3 changes: 2 additions & 1 deletion airflow/models/renderedtifields.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def __init__(self, ti: TaskInstance, render_templates=True):

self.k8s_pod_yaml = render_k8s_pod_yaml(ti)
self.rendered_fields = {
field: serialize_template_field(getattr(self.task, field)) for field in self.task.template_fields
field: serialize_template_field(getattr(self.task, field), field)
for field in self.task.template_fields
}

self._redact()
Expand Down
22 changes: 21 additions & 1 deletion airflow/serialization/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

from typing import Any

from airflow.configuration import conf
from airflow.settings import json
from airflow.utils.log.secrets_masker import redact


def serialize_template_field(template_field: Any) -> str | dict | list | int | float:
def serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float:
"""Return a serializable representation of the templated field.

If ``templated_field`` contains a class or instance that requires recursive
Expand All @@ -38,7 +40,25 @@ def is_jsonable(x):
else:
return True

max_length = conf.getint("core", "max_templated_field_length")

if not is_jsonable(template_field):
serialized = str(template_field)
if len(serialized) > max_length:
rendered = redact(serialized, name)
return (
"Truncated. You can change this behaviour in [core]max_templated_field_length. "
f"{rendered[:max_length - 79]!r}... "
)
return str(template_field)
else:
if not template_field:
return template_field
serialized = str(template_field)
if len(serialized) > max_length:
rendered = redact(serialized, name)
return (
"Truncated. You can change this behaviour in [core]max_templated_field_length. "
f"{rendered[:max_length - 79]!r}... "
)
return template_field
2 changes: 1 addition & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool)
)
value = getattr(op, template_field, None)
if not cls._is_excluded(value, template_field, op):
serialize_op[template_field] = serialize_template_field(value)
serialize_op[template_field] = serialize_template_field(value, template_field)

if op.params:
serialize_op["params"] = cls._serialize_params_dict(op.params)
Expand Down
5 changes: 5 additions & 0 deletions newsfragments/38094.significant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Prevent large string objects from being stored in the Rendered Template Fields

There's now a limit to the length of data that can be stored in the Rendered Template Fields.
The limit is set to 4096 characters by default. If the data exceeds this limit, it will be truncated. You can change this limit
by setting the ``[core]max_templated_field_length`` configuration option in your airflow config.
60 changes: 60 additions & 0 deletions tests/models/test_renderedtifields.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import pytest

from airflow import settings
from airflow.configuration import conf
from airflow.decorators import task as task_decorator
from airflow.models import Variable
from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
from airflow.operators.bash import BashOperator
Expand Down Expand Up @@ -62,6 +64,17 @@ def __ne__(self, other):
return not self.__eq__(other)


class LargeStrObject:
def __init__(self):
self.a = "a" * 5000

def __str__(self):
return self.a


max_length = conf.getint("core", "max_templated_field_length")


class TestRenderedTaskInstanceFields:
"""Unit tests for RenderedTaskInstanceFields."""

Expand Down Expand Up @@ -111,6 +124,14 @@ def teardown_method(self):
"{'att3': '{{ task.task_id }}', 'att4': '{{ task.task_id }}', 'template_fields': ['att3']}), "
"'template_fields': ['nested1']})",
),
(
"a" * 5000,
f"Truncated. You can change this behaviour in [core]max_templated_field_length. {('a'*5000)[:max_length-79]!r}... ",
),
(
LargeStrObject(),
f"Truncated. You can change this behaviour in [core]max_templated_field_length. {str(LargeStrObject())[:max_length-79]!r}... ",
),
],
)
def test_get_templated_fields(self, templated_field, expected_rendered_field, dag_maker):
Expand Down Expand Up @@ -148,6 +169,45 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field, da
# Fetching them will return None
assert RTIF.get_templated_fields(ti=ti2) is None

def test_secrets_are_masked_when_large_string(self, dag_maker):
"""
Test that secrets are masked when the templated field is a large string
"""
Variable.set(
key="api_key",
value="test api key are still masked" * 5000,
)
with dag_maker("test_serialized_rendered_fields"):
task = BashOperator(task_id="test", bash_command="echo {{ var.value.api_key }}")
dr = dag_maker.create_dagrun()
ti = dr.task_instances[0]
ti.task = task
rtif = RTIF(ti=ti)
assert "***" in rtif.rendered_fields.get("bash_command")

@mock.patch("airflow.models.BaseOperator.render_template")
def test_pandas_dataframes_works_with_the_string_compare(self, render_mock, dag_maker):
"""Test that rendered dataframe gets passed through the serialized template fields."""
import pandas

render_mock.return_value = pandas.DataFrame({"a": [1, 2, 3]})
with dag_maker("test_serialized_rendered_fields"):

@task_decorator
def generate_pd():
return pandas.DataFrame({"a": [1, 2, 3]})

@task_decorator
def consume_pd(data):
return data

consume_pd(generate_pd())

dr = dag_maker.create_dagrun()
ti, ti2 = dr.task_instances
rtif = RTIF(ti=ti2)
rtif.write()

@pytest.mark.parametrize(
"rtif_num, num_to_keep, remaining_rtifs, expected_query_count",
[
Expand Down
Loading