Commit
Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
hawkinsp committed Jun 26, 2024
1 parent 9842bdb commit fe275e7
Showing 140 changed files with 387 additions and 379 deletions.
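For readers unfamiliar with the tool: pyupgrade --py310-plus rewrites code to use the spellings available from Python 3.10 onward. Below is a minimal sketch of the patterns that appear in the diffs that follow, with illustrative function names that are not taken from the repository (the Callable import move is likely part of the manual import cleanup mentioned above rather than pyupgrade itself):

```python
from collections.abc import Callable  # preferred over the deprecated typing.Callable
from functools import lru_cache

# PEP 604: Optional[int] becomes int | None.
def parse_port(text: str) -> int | None:
    return int(text) if text.isdigit() else None

# PEP 585: typing.Tuple[int, ...] becomes the builtin tuple[int, ...].
def bounds(xs: tuple[int, ...]) -> tuple[int, int]:
    return min(xs), max(xs)

# functools.lru_cache works as a bare decorator, so @lru_cache() becomes @lru_cache.
@lru_cache
def fib(n: int) -> int:
    return n if n < 2 else fib(n - 1) + fib(n - 2)

# str.format calls become f-strings.
def report_path(base_dir: str) -> str:
    return f"{base_dir}/final_compiled_report.html"

# Callable annotations use the collections.abc version.
def apply_twice(f: Callable[[int], int], x: int) -> int:
    return f(f(x))
```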
4 changes: 2 additions & 2 deletions build/rocm/run_single_gpu.py
@@ -32,7 +32,7 @@ def extract_filename(path):
 def generate_final_report(shell=False, env_vars={}):
 env = os.environ
 env = {**env, **env_vars}
-cmd = ["pytest_html_merger", "-i", '{}'.format(base_dir), "-o", '{}/final_compiled_report.html'.format(base_dir)]
+cmd = ["pytest_html_merger", "-i", f'{base_dir}', "-o", f'{base_dir}/final_compiled_report.html']
 result = subprocess.run(cmd,
 shell=shell,
 capture_output=True,
@@ -90,7 +90,7 @@ def run_test(testmodule, gpu_tokens):
 "XLA_PYTHON_CLIENT_ALLOCATOR": "default",
 }
 testfile = extract_filename(testmodule)
-cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-x", testmodule]
+cmd = ["python3", "-m", "pytest", f'--html={base_dir}/{testfile}_log.html', "--reruns", "3", "-x", testmodule]
 return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
 with GPU_LOCK:
 gpu_tokens.append(target_gpu)
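The two hunks above only swap str.format calls for f-strings; a quick standalone check that the new spelling builds the same command list (the base_dir value here is made up for illustration, not taken from the script):

```python
base_dir = "/tmp/gpu_reports"  # hypothetical value for illustration only

old_cmd = ["pytest_html_merger", "-i", '{}'.format(base_dir),
           "-o", '{}/final_compiled_report.html'.format(base_dir)]
new_cmd = ["pytest_html_merger", "-i", f'{base_dir}',
           "-o", f'{base_dir}/final_compiled_report.html']

# Both spellings produce identical arguments; only the formatting style changes.
assert old_cmd == new_cmd
```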
25 changes: 12 additions & 13 deletions docs/Custom_Operation_for_GPUs.py
@@ -14,7 +14,6 @@

 from functools import partial, reduce
 import math
-from typing import Tuple

 import jax
 import jax.numpy as jnp
@@ -325,9 +324,9 @@ def batcher(batched_args, batch_dims, *, eps):
 return RmsNormFwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims

 @staticmethod
-def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
-arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
-result_infos : Tuple[jax._src.core.ShapedArray]):
+def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh,
+arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
+result_infos: tuple[jax._src.core.ShapedArray, ...]):
 del eps, result_infos # Not needed for this example.
 x_info, weight_info = arg_infos
 assert len(x_info.shape) == 3
@@ -340,9 +339,9 @@ def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
 return (output_sharding, invvar_sharding)

 @staticmethod
-def partition(eps : float, mesh : jax.sharding.Mesh,
-arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
-result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
+def partition(eps: float, mesh : jax.sharding.Mesh,
+arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
+result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]):
 del result_infos # Not needed for this example.
 x_info, weight_info = arg_infos
 assert len(x_info.shape) == 3
@@ -395,9 +394,9 @@ def batcher(batched_args, batch_dims, *, eps):
 return RmsNormBwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims

 @staticmethod
-def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
-arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
-result_infos : Tuple[jax._src.core.ShapedArray]):
+def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh,
+arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
+result_infos: tuple[jax._src.core.ShapedArray, ...]):
 del eps, result_infos # Not needed for this example.
 g_info, invvar_info, x_info, weight_info = arg_infos
 assert len(g_info.shape) == 3
@@ -411,9 +410,9 @@ def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
 return (output_sharding, invvar_sharding, output_sharding, )

 @staticmethod
-def partition(eps : float, mesh : jax.sharding.Mesh,
-arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
-result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
+def partition(eps: float, mesh : jax.sharding.Mesh,
+arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...],
+result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]):
 del result_infos # Not needed for this example.
 g_info, invvar_info, x_info, weight_info = arg_infos
 assert len(g_info.shape) == 3
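Beyond lowercasing Tuple to the builtin tuple, the new annotations above also append ", ...": Tuple[X] means a tuple of exactly one X, while tuple[X, ...] is the spelling for a variable-length homogeneous tuple, which matches how these sharding helpers receive their arg_infos. A minimal sketch with a hypothetical helper, not the RmsNorm classes themselves:

```python
# tuple[str, ...] accepts any number of entries, whereas tuple[str] would
# annotate a 1-tuple; the names here are purely illustrative.
def describe_operands(arg_infos: tuple[str, ...]) -> str:
    return ", ".join(arg_infos)

print(describe_operands(("x_info", "weight_info")))  # -> x_info, weight_info
```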
26 changes: 13 additions & 13 deletions docs/autodidax.ipynb
@@ -167,15 +167,15 @@
 "source": [
 "from collections.abc import Sequence\n",
 "from contextlib import contextmanager\n",
-"from typing import Optional, Any\n",
+"from typing import Any\n",
 "\n",
 "class MainTrace(NamedTuple):\n",
 " level: int\n",
 " trace_type: type['Trace']\n",
-" global_data: Optional[Any]\n",
+" global_data: Any | None\n",
 "\n",
 "trace_stack: list[MainTrace] = []\n",
-"dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3\n",
+"dynamic_trace: MainTrace | None = None # to be employed in Part 3\n",
 "\n",
 "@contextmanager\n",
 "def new_main(trace_type: type['Trace'], global_data=None):\n",
@@ -912,7 +912,7 @@
 "source": [
 "from collections.abc import Hashable, Iterable, Iterator\n",
 "import itertools as it\n",
-"from typing import Callable\n",
+"from collections.abc import Callable\n",
 "\n",
 "class NodeType(NamedTuple):\n",
 " name: str\n",
@@ -1651,7 +1651,7 @@
 "source": [
 "from functools import lru_cache\n",
 "\n",
-"@lru_cache() # ShapedArrays are hashable\n",
+"@lru_cache # ShapedArrays are hashable\n",
 "def make_jaxpr_v1(f, *avals_in):\n",
 " avals_in, in_tree = tree_flatten(avals_in)\n",
 " f, out_tree = flatten_fun(f, in_tree)\n",
@@ -1803,7 +1803,7 @@
 " finally:\n",
 " dynamic_trace = prev_dynamic_trace\n",
 "\n",
-"@lru_cache()\n",
+"@lru_cache\n",
 "def make_jaxpr(f: Callable, *avals_in: ShapedArray,\n",
 " ) -> tuple[Jaxpr, list[Any], PyTreeDef]:\n",
 " avals_in, in_tree = tree_flatten(avals_in)\n",
@@ -1994,7 +1994,7 @@
 " return execute(*args)\n",
 "impl_rules[xla_call_p] = xla_call_impl\n",
 "\n",
-"@lru_cache()\n",
+"@lru_cache\n",
 "def xla_callable(hashable_jaxpr: IDHashable,\n",
 " hashable_consts: tuple[IDHashable, ...]):\n",
 " jaxpr: Jaxpr = hashable_jaxpr.val\n",
@@ -2227,7 +2227,7 @@
 " return primals_out, tangents_out\n",
 "jvp_rules[xla_call_p] = xla_call_jvp_rule\n",
 "\n",
-"@lru_cache()\n",
+"@lru_cache\n",
 "def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:\n",
 " def jvp_traceable(*primals_and_tangents):\n",
 " n = len(primals_and_tangents) // 2\n",
@@ -2253,7 +2253,7 @@
 " return outs, [0] * len(outs)\n",
 "vmap_rules[xla_call_p] = xla_call_vmap_rule\n",
 "\n",
-"@lru_cache()\n",
+"@lru_cache\n",
 "def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]\n",
 " ) -> tuple[Jaxpr, list[Any]]:\n",
 " vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n",
@@ -2638,7 +2638,7 @@
 "source": [
 "class PartialVal(NamedTuple):\n",
 " aval: ShapedArray\n",
-" const: Optional[Any]\n",
+" const: Any | None\n",
 "\n",
 " @classmethod\n",
 " def known(cls, val: Any):\n",
@@ -2727,7 +2727,7 @@
 "source": [
 "class PartialEvalTracer(Tracer):\n",
 " pval: PartialVal\n",
-" recipe: Optional[JaxprRecipe]\n",
+" recipe: JaxprRecipe | None\n",
 "\n",
 " def __init__(self, trace, pval, recipe):\n",
 " self._trace = trace\n",
@@ -2974,7 +2974,7 @@
 "partial_eval_rules[xla_call_p] = xla_call_partial_eval\n",
 "\n",
 "def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],\n",
-" instantiate: Optional[list[bool]] = None,\n",
+" instantiate: list[bool] | None = None,\n",
 " ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:\n",
 " env: dict[Var, bool] = {}\n",
 " residuals: set[Var] = set()\n",
@@ -3271,7 +3271,7 @@
 " return [next(outs) if undef else None for undef in undef_primals]\n",
 "transpose_rules[xla_call_p] = xla_call_transpose_rule\n",
 "\n",
-"@lru_cache()\n",
+"@lru_cache\n",
 "def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]\n",
 " ) -> tuple[Jaxpr, list[Any]]:\n",
 " avals_in, avals_out = typecheck_jaxpr(jaxpr)\n",
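The notebook hunks above repeat the same two modernizations across cells; a standalone sketch of both, using illustrative names rather than the autodidax ones:

```python
from functools import lru_cache

# Since Python 3.8, lru_cache can be applied without parentheses, so
# @lru_cache() and @lru_cache define the same cached function.
@lru_cache
def square(n: int) -> int:
    return n * n

# PEP 604 union syntax replaces Optional[...] in annotations.
def first_or_none(xs: tuple[int, ...]) -> int | None:
    return xs[0] if xs else None

assert square(3) == 9
assert first_or_none(()) is None
```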
26 changes: 13 additions & 13 deletions docs/autodidax.md
@@ -148,15 +148,15 @@ more descriptive.
 ```{code-cell}
 from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import Optional, Any
+from typing import Any
 class MainTrace(NamedTuple):
 level: int
 trace_type: type['Trace']
-global_data: Optional[Any]
+global_data: Any | None
 trace_stack: list[MainTrace] = []
-dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3
+dynamic_trace: MainTrace | None = None # to be employed in Part 3
 @contextmanager
 def new_main(trace_type: type['Trace'], global_data=None):
@@ -705,7 +705,7 @@ class Store:
 from collections.abc import Hashable, Iterable, Iterator
 import itertools as it
-from typing import Callable
+from collections.abc import Callable
 class NodeType(NamedTuple):
 name: str
@@ -1295,7 +1295,7 @@ transformation and a pretty-printer:
 ```{code-cell}
 from functools import lru_cache
-@lru_cache() # ShapedArrays are hashable
+@lru_cache # ShapedArrays are hashable
 def make_jaxpr_v1(f, *avals_in):
 avals_in, in_tree = tree_flatten(avals_in)
 f, out_tree = flatten_fun(f, in_tree)
@@ -1415,7 +1415,7 @@ def new_dynamic(main: MainTrace):
 finally:
 dynamic_trace = prev_dynamic_trace
-@lru_cache()
+@lru_cache
 def make_jaxpr(f: Callable, *avals_in: ShapedArray,
 ) -> tuple[Jaxpr, list[Any], PyTreeDef]:
 avals_in, in_tree = tree_flatten(avals_in)
@@ -1564,7 +1564,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
 return execute(*args)
 impl_rules[xla_call_p] = xla_call_impl
-@lru_cache()
+@lru_cache
 def xla_callable(hashable_jaxpr: IDHashable,
 hashable_consts: tuple[IDHashable, ...]):
 jaxpr: Jaxpr = hashable_jaxpr.val
@@ -1734,7 +1734,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
 return primals_out, tangents_out
 jvp_rules[xla_call_p] = xla_call_jvp_rule
-@lru_cache()
+@lru_cache
 def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
 def jvp_traceable(*primals_and_tangents):
 n = len(primals_and_tangents) // 2
@@ -1755,7 +1755,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
 return outs, [0] * len(outs)
 vmap_rules[xla_call_p] = xla_call_vmap_rule
-@lru_cache()
+@lru_cache
 def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
 ) -> tuple[Jaxpr, list[Any]]:
 vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
@@ -2065,7 +2065,7 @@ be either known or unknown:
 ```{code-cell}
 class PartialVal(NamedTuple):
 aval: ShapedArray
-const: Optional[Any]
+const: Any | None
 @classmethod
 def known(cls, val: Any):
@@ -2129,7 +2129,7 @@ JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
 ```{code-cell}
 class PartialEvalTracer(Tracer):
 pval: PartialVal
-recipe: Optional[JaxprRecipe]
+recipe: JaxprRecipe | None
 def __init__(self, trace, pval, recipe):
 self._trace = trace
@@ -2329,7 +2329,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
 partial_eval_rules[xla_call_p] = xla_call_partial_eval
 def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
-instantiate: Optional[list[bool]] = None,
+instantiate: list[bool] | None = None,
 ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
 env: dict[Var, bool] = {}
 residuals: set[Var] = set()
@@ -2586,7 +2586,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
 return [next(outs) if undef else None for undef in undef_primals]
 transpose_rules[xla_call_p] = xla_call_transpose_rule
-@lru_cache()
+@lru_cache
 def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
 ) -> tuple[Jaxpr, list[Any]]:
 avals_in, avals_out = typecheck_jaxpr(jaxpr)
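One more pattern shared by both autodidax diffs: Callable is now imported from collections.abc instead of typing, since the typing aliases for the abstract base classes are deprecated from Python 3.9 on. A small sketch with a hypothetical compose helper:

```python
from collections.abc import Callable

def compose(f: Callable[[int], int], g: Callable[[int], int]) -> Callable[[int], int]:
    # Returns a new function that applies g first, then f.
    return lambda x: f(g(x))

inc_then_double = compose(lambda x: 2 * x, lambda x: x + 1)
assert inc_then_double(3) == 8  # f(g(3)) = 2 * (3 + 1)
```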
(Diffs for the remaining changed files were not loaded on this page.)