Skip to content

Commit

Permalink
fix: don't reset context.scratch between files (#1151)
Browse files Browse the repository at this point in the history
#453 fixed scratch leaking between files by setting it to empty, but that drops all the scratch space that was set up before the codemod runs (e.g. in the transformer's constructor)

This PR improves the fix by preserving the initial scratch.
  • Loading branch information
zsol committed May 21, 2024
1 parent 71b0a12 commit db696e6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
5 changes: 4 additions & 1 deletion libcst/codemod/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import sys
import time
import traceback
from copy import deepcopy
from dataclasses import dataclass, replace
from multiprocessing import cpu_count, Pool
from pathlib import Path
Expand Down Expand Up @@ -214,6 +215,7 @@ def _execute_transform( # noqa: C901
transformer: Codemod,
filename: str,
config: ExecutionConfig,
scratch: Dict[str, object],
) -> ExecutionResult:
for pattern in config.blacklist_patterns:
if re.fullmatch(pattern, filename):
Expand Down Expand Up @@ -251,7 +253,7 @@ def _execute_transform( # noqa: C901
transformer.context = replace(
transformer.context,
filename=filename,
scratch={},
scratch=deepcopy(scratch),
)

# determine the module and package name for this file
Expand Down Expand Up @@ -634,6 +636,7 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
"transformer": transform,
"filename": filename,
"config": config,
"scratch": transform.context.scratch,
}
for filename in files
]
Expand Down
9 changes: 8 additions & 1 deletion libcst/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Generator
from typing import Dict, Generator
from unittest import TestCase

from libcst import BaseExpression, Call, matchers as m, Name
Expand All @@ -16,7 +16,14 @@


class PrintToPPrintCommand(VisitorBasedCodemodCommand):
def __init__(self, context: CodemodContext, **kwargs: Dict[str, object]) -> None:
super().__init__(context, **kwargs)
self.context.scratch["PPRINT_WAS_HERE"] = True

def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
if not self.context.scratch["PPRINT_WAS_HERE"]:
raise AssertionError("Scratch space lost")

if m.matches(updated_node, m.Call(func=m.Name("print"))):
AddImportsVisitor.add_needed_import(
self.context,
Expand Down

0 comments on commit db696e6

Please sign in to comment.