From bc7f6a954a94e085b7a775ded219a0a2b44af4a1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 8 Jul 2023 18:14:29 +0000 Subject: [PATCH] Remove trace_only for scripted ops using internal functions | feat(torchlib) --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ++-- onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 391eb86ee..3d503cc3f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1345,7 +1345,7 @@ def _complex_conjugate(self: TFloat) -> TFloat: return conjugated -@torch_op("aten::conj", complex=True, trace_only=True) +@torch_op("aten::conj", complex=True) def aten_conj_complex(self: TFloat) -> TFloat: """conj(Tensor(a) self) -> Tensor(a)""" @@ -3749,7 +3749,7 @@ def aten_mH(self: TRealOrUInt8) -> TRealOrUInt8: return op.Einsum(self, equation="...ij->...ji") -@torch_op("aten::mH", complex=True, trace_only=True) +@torch_op("aten::mH", complex=True) def aten_mH_complex(self: TFloat) -> TFloat: """mH(Tensor(a) self) -> Tensor(a)""" diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index aa01690b7..148d7bfc4 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -521,7 +521,7 @@ def _where_input_wrangler( reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", ), TorchLibOpInfo("conj", core_ops.aten_conj), - TorchLibOpInfo("conj", core_ops.aten_conj_complex, complex=True, trace_only=True), + TorchLibOpInfo("conj", core_ops.aten_conj_complex, complex=True), TorchLibOpInfo("constant_pad_nd", core_ops.aten_constant_pad_nd), # TorchLibOpInfo("copy", core_ops.aten_copy), # copy is not in OPS_DB TorchLibOpInfo("cos", core_ops.aten_cos), @@ -679,7 +679,7 @@ def _where_input_wrangler( reason="this Aten overload can accept 2 inputs:(self, dim)", ), TorchLibOpInfo("mH", core_ops.aten_mH), - TorchLibOpInfo("mH", core_ops.aten_mH_complex, complex=True, trace_only=True), + TorchLibOpInfo("mH", core_ops.aten_mH_complex, complex=True), TorchLibOpInfo("mT", core_ops.aten_mT), TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True), TorchLibOpInfo("min_dim", core_ops.aten_min_dim)