From 1aa7a7017225f95621e9a2008041056661ff7977 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 20 Jun 2024 17:56:27 -0700 Subject: [PATCH] [torchlib] Fix more signatures (#1613) Fix more signatures in torchlib that were previously overlooked --- .../function_libs/torch_lib/ops/core.py | 566 ++++++++++-------- onnxscript/function_libs/torch_lib/ops/nn.py | 39 +- tests/function_libs/torch_lib/extra_opinfo.py | 81 --- tests/function_libs/torch_lib/ops_test.py | 35 +- .../function_libs/torch_lib/ops_test_data.py | 239 ++------ 5 files changed, 389 insertions(+), 571 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index bc20bb3f9..e50489c38 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -9,6 +9,7 @@ - All functions should not have the script() decorator. This is because we want to delay the compilation of the function. """ +# pylint: disable=unused-argument from __future__ import annotations @@ -93,7 +94,7 @@ def aten__log_softmax_half( def aten__log_softmax( self: TFloatHighPrecision, dim: int, - half_to_float: bool, # pylint: disable=unused-argument + half_to_float: bool, ) -> TFloatHighPrecision: """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" @@ -303,11 +304,11 @@ def aten_affine_grid_generator_backward( raise NotImplementedError() -@torch_op("aten::alias") +@torch_op("aten::alias", trace_only=True) def aten_alias(self: TTensor) -> TTensor: """alias(Tensor(a) self) -> Tensor(a)""" - return op.Identity(self) + return self def aten_alias_copy(self: TensorType) -> TensorType: @@ -398,7 +399,7 @@ def aten_allclose( other: TReal, rtol: float = 1e-05, atol: float = 1e-08, - equal_nan: bool = False, # pylint: disable=unused-argument + equal_nan: bool = False, ) -> BOOL: """allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool""" @@ -538,7 +539,13 @@ def _integral_to_be_adjusted(dtype: int) -> bool: @torch_op("aten::arange", trace_only=True) -def aten_arange(end: Union[DOUBLE, FLOAT, INT16, INT32, INT64], dtype: int = -1) -> TensorType: +def aten_arange( + end: Union[DOUBLE, FLOAT, INT16, INT32, INT64], + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" # NOTE: trace_only because both if branches need to be the same type, but we have @@ -568,7 +575,12 @@ def aten_arange(end: Union[DOUBLE, FLOAT, INT16, INT32, INT64], dtype: int = -1) @torch_op("aten::arange.start", trace_only=True) def aten_arange_start( - start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, dtype: int = -1 + start: TRealUnlessFloat16OrInt8, + end: TRealUnlessFloat16OrInt8, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TensorType: """arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" @@ -619,6 +631,9 @@ def aten_arange_start_step( end: TRealUnlessFloat16OrInt8, step: TRealUnlessFloat16OrInt8, dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TensorType: """arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor""" @@ -1446,7 +1461,7 @@ def aten_block_diag(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::bmm") +@torch_op("aten::bmm", traceable=True) def aten_bmm(self: TFloat, mat2: TFloat) -> TFloat: """bmm(Tensor self, Tensor mat2) -> Tensor""" @@ -1669,14 +1684,14 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal: return result -@torch_op("aten::clone") +@torch_op("aten::clone", trace_only=True) def aten_clone( self: TTensor, - memory_format: str = "", # pylint: disable=unused-argument + memory_format: str = "", ) -> TTensor: """clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor""" - return op.Identity(self) + return self def aten_coalesce(self: TensorType) -> TensorType: @@ -1730,11 +1745,11 @@ def aten_complex(real: TFloat, imag: TFloat) -> TFloat: return _aten_complex(real, imag) -@torch_op("aten::conj") +@torch_op("aten::conj", trace_only=True) def aten_conj(self: TTensor) -> TTensor: """conj(Tensor(a) self) -> Tensor(a)""" - return op.Identity(self) + return self @torch_op("aten::conj", complex=True, private=True) @@ -1802,15 +1817,15 @@ def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTens return op.Pad(self, onnx_padding, value) -@torch_op("aten::contiguous") +@torch_op("aten::contiguous", trace_only=True) def aten_contiguous( self: TTensor, - memory_format: str = "contiguous_format", # pylint: disable=unused-argument + memory_format: str = "contiguous_format", ) -> TTensor: """contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)""" # ONNX does not have the notion of memory_format. It is always treated as a no-op. - return op.Identity(self) + return self @torch_op("aten::conv1d", trace_only=True) @@ -2026,12 +2041,12 @@ def aten_convolution( return result -@torch_op("aten::convolution", private=True, traceable=True) +@torch_op("aten::convolution", private=True, trace_only=True) def _aten_convolution_onnx( input: TFloat, weight: TFloat, bias: TFloat, - transposed: BOOL, + transposed: bool, strides: Sequence[int], pads: Sequence[int], dilations: Sequence[int], @@ -2045,7 +2060,7 @@ def _aten_convolution_onnx( # Alternatively we could cast transposed to BOOL. # E.g. `if op.Cast(transposed, BOOL.dtype): ...` - no_batch = Rank(input) != Rank(weight) + no_batch = len(input.shape) != len(weight.shape) if no_batch: input = op.Unsqueeze(input, op.Constant(value_ints=[0])) @@ -2133,7 +2148,7 @@ def aten_convolution_overrideable( def aten_copy( self: TTensor, src: TTensor2, - non_blocking: bool = False, # pylint: disable=unused-argument + non_blocking: bool = False, ) -> TTensor: """copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor""" @@ -2144,16 +2159,16 @@ def aten_copy( def aten__to_copy( self: TTensor, dtype: int = -1, - layout: str = "", # pylint: disable=unused-argument - device: str = "", # pylint: disable=unused-argument - pin_memory: bool = False, # pylint: disable=unused-argument - non_blocking: bool = False, # pylint: disable=unused-argument - memory_format: str = "", # pylint: disable=unused-argument + layout: str = "", + device: str = "", + pin_memory: bool = False, + non_blocking: bool = False, + memory_format: str = "", ) -> TTensor: """_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? 
memory_format=None) -> Tensor""" if dtype == -1: - return op.Identity(self) + return self else: return common_ops.cast_to(self, dtype=dtype) @@ -2474,11 +2489,11 @@ def aten_dense_dim(self: TensorType) -> int: raise NotImplementedError() -@torch_op("aten::detach") +@torch_op("aten::detach", trace_only=True) def aten_detach(self: TensorType) -> TensorType: """detach(Tensor(a) self) -> Tensor(a)""" - return op.Identity(self) + return self def aten_detach_copy(self: TensorType) -> TensorType: @@ -2841,7 +2856,7 @@ def aten_dstack(tensors: Sequence[TensorType]) -> TensorType: def aten_einsum( equation: str, tensors: Sequence[TReal], - path: Optional[int] = None, # pylint: disable=unused-argument + path: Optional[int] = None, ) -> TReal: """einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor""" @@ -2849,14 +2864,14 @@ def aten_einsum( return op.Einsum(*tensors, equation=equation) -@torch_op("aten::embedding") +@torch_op("aten::embedding", traceable=True) def aten_embedding( weight: TTensor, indices: TInt, padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False, -): # pylint: disable=unused-argument +) -> TTensor: # embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor return op.Gather(weight, indices) @@ -2880,9 +2895,9 @@ def aten_embedding_bag( weight: TFloat, indices: INT64, offsets: INT64, - scale_grad_by_freq: bool = False, # pylint: disable=unused-argument + scale_grad_by_freq: bool = False, mode: int = 0, # [0,1,2] indicate ["sum", "mean", "max"] - sparse: bool = False, # pylint: disable=unused-argument + sparse: bool = False, per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, ) -> Tuple[TFloat, TFloat, TFloat, TFloat]: @@ -3014,9 +3029,9 @@ def aten_embedding_bag_padding_idx( weight: TFloat, indices: INT64, offsets: INT64, - scale_grad_by_freq: bool = False, # pylint: disable=unused-argument + scale_grad_by_freq: bool = False, mode: int = 0, # [0,1,2] indicate ["sum", "mean", "max"] - sparse: bool = False, # pylint: disable=unused-argument + sparse: bool = False, per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, padding_idx: int = -1, @@ -3202,10 +3217,18 @@ def aten_embedding_sparse_backward( raise NotImplementedError() -@torch_op(("aten::empty", "aten::empty.memory_format")) -def aten_empty(size: IntType, dtype: int = FLOAT.dtype) -> TTensor: # type: ignore[type-var] +@torch_op("aten::empty.memory_format", trace_only=True) +def aten_empty( + size: IntType, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, + memory_format: str = "", +) -> TensorType: # type: ignore[type-var] # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - + if dtype == -1: + dtype = FLOAT.dtype # using Zeros to simulate np.empty() size = op.Cast(size, to=INT64.dtype) zero = op.Constant(value_float=0.0) @@ -3246,7 +3269,10 @@ def aten_empty_quantized( @torch_op("aten::empty_strided") def aten_empty_strided( size: INT64, - stride: INT64, # pylint: disable=unused-argument + stride: INT64, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TTensor: # type: ignore[type-var] # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor @@ -3470,7 +3496,7 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType raise NotImplementedError() -@torch_op(("aten::fill", "aten::fill.Tensor")) +@torch_op("aten::fill.Tensor") def aten_fill(self: TTensor, value: TTensor) -> TTensor: """fill.Tensor(Tensor self, Tensor value) -> Tensor""" @@ -3583,32 +3609,40 @@ def aten_from_file( raise NotImplementedError() -@torch_op("aten::full") -def aten_full(size: INT64, fill_value: FLOAT, dtype: int = FLOAT.dtype): +@torch_op("aten::full", trace_only=True) +def aten_full( + size: INT64, + fill_value: FLOAT, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +): """full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" size = op.Cast(size, to=INT64.dtype) - fill_value = op.Cast(fill_value, to=dtype) + if dtype != -1: + fill_value = op.Cast(fill_value, to=dtype) return op.Expand(fill_value, size) -@torch_op("aten::full_like") -def aten_full_like(self: TTensor, fill_value: TTensor) -> TTensor: +@torch_op("aten::full_like", trace_only=True) +def aten_full_like( + self: TTensor, + fill_value: TTensor, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TTensor: """full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - fill_value = op.CastLike(fill_value, self) - self_shape = op.Shape(self) - - return op.Expand(fill_value, self_shape) - - -@torch_op("aten::full_like") -def aten_full_like_dtype(self: TTensor, fill_value: TTensor, dtype: int) -> TTensor: - """full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" + if dtype == -1: + fill_value = op.CastLike(fill_value, self) + else: + fill_value = op.Cast(fill_value, to=dtype) - fill_value = op.Cast(fill_value, to=dtype) self_shape = op.Shape(self) - return op.Expand(fill_value, self_shape) @@ -3637,7 +3671,7 @@ def aten_gather( self: TReal, dim: int, index: TInt, - sparse_grad: bool = False, # pylint: disable=unused-argument + sparse_grad: bool = False, ) -> TReal: """gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor""" @@ -4454,7 +4488,7 @@ def aten_isclose( other: TReal, rtol: float = 1e-05, atol: float = 1e-08, - equal_nan: bool = False, # pylint: disable=unused-argument + equal_nan: bool = False, ) -> BOOL: """isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor""" @@ -4662,11 +4696,11 @@ def aten_lift_fresh(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::lift_fresh_copy") +@torch_op("aten::lift_fresh_copy", trace_only=True) def aten_lift_fresh_copy(self: TensorType) -> TensorType: """lift_fresh_copy(Tensor self) -> Tensor""" - return op.Identity(self) + return self def aten_linear_backward( @@ -4683,6 +4717,9 @@ def aten_linspace( ) -> TensorType: """linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor""" + if dtype == -1: + dtype = FLOAT.dtype + # Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896 if steps == 0: return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype) @@ -5448,7 +5485,7 @@ def aten_mkldnn_max_pool3d_backward( raise NotImplementedError() -@torch_op("aten::mm") +@torch_op("aten::mm", traceable=True) def aten_mm( self: TRealUnlessInt16OrInt8, mat2: TRealUnlessInt16OrInt8 ) -> TRealUnlessInt16OrInt8: @@ -5516,7 +5553,7 @@ def aten_msort(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul")) +@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"), traceable=True) def aten_mul(self: TReal, other: TReal) -> TReal: """mul.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -5560,7 +5597,7 @@ def aten_mul_complex(self: TReal, other: TReal) -> TReal: def aten_multinomial( self: TFloat, num_samples: int, - replacement: bool = False, # pylint: disable=unused-argument + replacement: bool = False, ) -> TInt: """multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor""" # ONNX Multinomial doesn't support 1D input @@ -5976,9 +6013,9 @@ def aten_native_group_norm( input: TFloat, weight: Optional[TFloat] = None, bias: Optional[TFloat] = None, - N: Optional[INT64] = None, # pylint: disable=unused-argument - C: Optional[INT64] = None, # pylint: disable=unused-argument - HxW: Optional[INT64] = None, # pylint: disable=unused-argument + N: Optional[INT64] = None, + C: Optional[INT64] = None, + HxW: Optional[INT64] = None, group: int = 1, eps: float = 1e-05, ) -> Tuple[TFloat, TFloat, TFloat]: @@ -6136,111 +6173,94 @@ def aten_negative(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::new_empty") -def aten_new_empty(self: TTensor, size: INT64) -> TTensor: - """new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - - # using zero to simulate empty array - result = op.ConstantOfShape(size) - return op.CastLike(result, self) - - -@torch_op("aten::new_empty") -def aten_new_empty_dtype( - self: TTensor, # pylint: disable=unused-argument +@torch_op("aten::new_empty", trace_only=True) +def aten_new_empty( + self: TTensor, size: INT64, - dtype: int, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TTensor: """new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" # using zero to simulate empty array result = op.ConstantOfShape(size) + if dtype == -1: + return op.CastLike(result, self) return op.Cast(result, to=dtype) -@torch_op("aten::new_empty_strided") +@torch_op("aten::new_empty_strided", trace_only=True) def aten_new_empty_strided( self: TTensor, size: INT64, - stride: INT64, # pylint: disable=unused-argument -) -> TTensor: - """new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor""" - - # using zero to simulate empty array - zero = op.ConstantOfShape(size) - return op.CastLike(zero, self) - - -@torch_op("aten::new_empty_strided") -def aten_new_empty_strided_dtype( - self: TTensor, # pylint: disable=unused-argument - size: INT64, - stride: INT64, # pylint: disable=unused-argument - dtype: int, + stride: INT64, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TTensor: """new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" # using zero to simulate empty array zero = op.ConstantOfShape(size) + if dtype == -1: + return op.CastLike(zero, self) return op.Cast(zero, to=dtype) -@torch_op("aten::new_full") -def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: - # new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - - fill_value = op.CastLike(fill_value, self) - return op.Expand(fill_value, size) - - -@torch_op("aten::new_full") -def aten_new_full_dtype( - self: TTensor, # pylint: disable=unused-argument +@torch_op("aten::new_full", trace_only=True) +def aten_new_full( + self: TTensor, size: INT64, fill_value: TTensor, - dtype: int, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TTensor: # new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - fill_value = op.Cast(fill_value, to=dtype) + if dtype == -1: + fill_value = op.CastLike(fill_value, self) + else: + fill_value = op.Cast(fill_value, to=dtype) return op.Expand(fill_value, size) -@torch_op("aten::new_ones") -def aten_new_ones(self: TReal, size: INT64) -> TReal: # pylint: disable=unused-argument +@torch_op("aten::new_ones", trace_only=True) +def aten_new_ones( + self: TReal, + size: INT64, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TReal: """new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" one = op.Constant(value_float=1.0) result = op.Expand(one, size) - return op.CastLike(result, self) + if dtype == -1: + return op.CastLike(result, self) + return op.Cast(result, to=dtype) -@torch_op("aten::new_ones") -def aten_new_ones_dtype( - self: TReal, # pylint: disable=unused-argument +@torch_op("aten::new_zeros", trace_only=True) +def aten_new_zeros( + self: TReal, size: INT64, - dtype: int, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TReal: - one = op.Constant(value_float=1.0) - result = op.Expand(one, size) - return op.Cast(result, to=dtype) - - -@torch_op("aten::new_zeros") -def aten_new_zeros(self: TReal, size: INT64) -> TReal: """new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor""" result = op.ConstantOfShape(size) - return op.CastLike(result, self) - - -@torch_op("aten::new_zeros") -def aten_new_zeros_dtype( - self: TReal, # pylint: disable=unused-argument - size: INT64, - dtype: int, -) -> TReal: - result = op.ConstantOfShape(size) + if dtype == -1: + return op.CastLike(result, self) return op.Cast(result, to=dtype) @@ -6270,7 +6290,16 @@ def aten_norm_except_dim(v: TensorType, pow: int = 2, dim: int = 0) -> TensorTyp raise NotImplementedError() -@torch_op(("aten::normal", "aten::normal_functional"), traceable=True) +@torch_op( + ( + "aten::normal.Tensor_float", + "aten::normal.Tensor_Tensor", + "aten::normal.float_Tensor", + "aten::normal.float_float", + "aten::normal_functional", + ), + traceable=True, +) def aten_normal( self: TTensor, mean: float = 0.0, @@ -6285,12 +6314,14 @@ def aten_normal( return result -@torch_op("aten::normal.float_float") +@torch_op("aten::normal.float_float", trace_only=True) def aten_normal_float_float( mean: float, std: float, size: INT64, dtype: int = FLOAT.dtype ) -> TensorType: """normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" + if dtype == -1: + dtype = FLOAT.dtype # Create a dummy tensor for RandomNormalLike to get the shape dummy_tensor = op.ConstantOfShape(size) result = op.RandomNormalLike(dummy_tensor, mean=mean, scale=std) @@ -6337,10 +6368,17 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType: raise NotImplementedError() -@torch_op("aten::ones") -def aten_ones(size: IntType, dtype: int = FLOAT.dtype): +@torch_op("aten::ones", trace_only=True) +def aten_ones( + size: IntType, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +): """ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - + if dtype == -1: + dtype = FLOAT.dtype size = op.Cast(size, to=INT64.dtype) one = op.Constant(value_float=1.0) one = op.Cast(one, to=dtype) @@ -6348,7 +6386,13 @@ def aten_ones(size: IntType, dtype: int = FLOAT.dtype): @torch_op("aten::ones_like", trace_only=True) -def aten_ones_like(self: TTensor, dtype: int = -1) -> TTensor: +def aten_ones_like( + self: TTensor, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TTensor: """ones_like. Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype @@ -6769,30 +6813,41 @@ def aten_rad2deg(self: TFloat) -> TFloat: return op.Mul(self, op.CastLike(180.0 / _MATH_PI, self)) -@torch_op("aten::rand") -def aten_rand(size: INT64, dtype: int = FLOAT.dtype) -> TReal: +@torch_op("aten::rand", trace_only=True) +def aten_rand( + size: INT64, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TReal: """rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - + if dtype == -1: + dtype = FLOAT.dtype shaper = op.ConstantOfShape(size) return op.RandomUniformLike(shaper, dtype=dtype) -@torch_op("aten::rand_like") -def aten_rand_like(self: TFloat) -> TFloat: - """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor""" - - return op.RandomUniformLike(self) - - -@torch_op("aten::rand_like") -def aten_rand_like_dtype(self: TensorType, dtype: int) -> TensorType: +@torch_op("aten::rand_like", trace_only=True) +def aten_rand_like( + self: TFloat, dtype: int = -1, layout: str = "", device: str = "", pin_memory: bool = False +) -> TFloat: """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" + if dtype == -1: + return op.RandomUniformLike(self) return op.RandomUniformLike(self, dtype=dtype) -@torch_op("aten::randint") -def aten_randint(high: INT64, size: INT64, dtype: int = INT64.dtype) -> TensorType: +@torch_op("aten::randint", trace_only=True) +def aten_randint( + high: INT64, + size: INT64, + dtype: int = INT64.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" shaper = op.ConstantOfShape(size) @@ -6804,9 +6859,15 @@ def aten_randint(high: INT64, size: INT64, dtype: int = INT64.dtype) -> TensorTy return op.Cast(rand_int, to=dtype) -@torch_op("aten::randint.low") +@torch_op("aten::randint.low", trace_only=True) def aten_randint_low( - low: INT64, high: INT64, size: INT64, dtype: int = INT64.dtype + low: INT64, + high: INT64, + size: INT64, + dtype: int = INT64.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TensorType: """randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" @@ -6821,21 +6882,15 @@ def aten_randint_low( return op.Cast(rand_int, to=dtype) -@torch_op("aten::randint_like") -def aten_randint_like(self: TensorType, high: INT64) -> IntType: - """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - - self_float = op.Cast(self, to=FLOAT.dtype) - rand = op.RandomUniformLike(self_float) - # Scale to [0, high] first - rand_scaled = op.Mul(rand, op.CastLike(high, rand)) - # Round to ints - rand_int = op.Floor(rand_scaled) - return op.CastLike(rand_int, self) - - -@torch_op("aten::randint_like") -def aten_randint_like_dtype(self: TensorType, high: INT64, dtype: int) -> TensorType: +@torch_op("aten::randint_like", trace_only=True) +def aten_randint_like( + self: TensorType, + high: INT64, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> IntType: """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor""" self_float = op.Cast(self, to=FLOAT.dtype) @@ -6844,11 +6899,21 @@ def aten_randint_like_dtype(self: TensorType, high: INT64, dtype: int) -> Tensor rand_scaled = op.Mul(rand, op.CastLike(high, rand)) # Round to ints rand_int = op.Floor(rand_scaled) + if dtype == -1: + return op.CastLike(rand_int, self) return op.Cast(rand_int, to=dtype) -@torch_op("aten::randint_like.low_dtype") -def aten_randint_like_low_dtype(self: TensorType, low: INT64, high: INT64) -> IntType: +@torch_op("aten::randint_like.low_dtype", trace_only=True) +def aten_randint_like_low_dtype( + self: TensorType, + low: INT64, + high: INT64, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> IntType: """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor This is the TorchLib overload for aten::randint_like.low_dtype when dtype is None. @@ -6862,55 +6927,47 @@ def aten_randint_like_low_dtype(self: TensorType, low: INT64, high: INT64) -> In rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low) # Round to ints rand_int = op.Floor(rand_translated) - return op.CastLike(rand_int, self) - - -@torch_op("aten::randint_like.low_dtype") -def aten_randint_like_low_dtype_dtype( - self: TensorType, low: INT64, high: INT64, dtype: int -) -> TensorType: - """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - - self_float = op.Cast(self, to=FLOAT.dtype) - rand = op.RandomUniformLike(self_float) - # Translate to [low, high] first - high = op.Cast(high, to=FLOAT.dtype) - low = op.Cast(low, to=FLOAT.dtype) - rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low) - # Round to ints - rand_int = op.Floor(rand_translated) + if dtype == -1: + return op.CastLike(rand_int, self) return op.Cast(rand_int, to=dtype) -@torch_op("aten::randn") -def aten_randn(size: INT64, dtype: int = FLOAT.dtype) -> TReal: +@torch_op("aten::randn", trace_only=True) +def aten_randn( + size: INT64, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TReal: """randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" shaper = op.ConstantOfShape(size) return op.RandomNormalLike(shaper, dtype=dtype) -@torch_op("aten::randn_like") -def aten_randn_like(self: TFloat) -> TFloat: - """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - - return op.RandomNormalLike(self) - - -@torch_op("aten::randn_like") -def aten_randn_like_dtype(self: TensorType, dtype: int) -> TensorType: +@torch_op("aten::randn_like", trace_only=True) +def aten_randn_like( + self: TFloat, dtype: int = -1, layout: str = "", device: str = "", pin_memory: bool = False +) -> TFloat: """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor""" + if dtype == -1: + return op.RandomNormalLike(self) return op.RandomNormalLike(self, dtype=dtype) -def aten_randperm(n: int) -> TensorType: +def aten_randperm( + n: int, layout: str = "", device: str = "", pin_memory: bool = False +) -> TensorType: """randperm(int n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" raise NotImplementedError() -def aten_range(start: float, end: float) -> TensorType: +def aten_range( + start: float, end: float, layout: str = "", device: str = "", pin_memory: bool = False +) -> TensorType: """range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" raise NotImplementedError() @@ -7021,18 +7078,18 @@ def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::resolve_conj") +@torch_op("aten::resolve_conj", trace_only=True) def aten_resolve_conj(self: TTensor) -> TTensor: """resolve_conj(Tensor(a) self) -> Tensor(a)""" - return op.Identity(self) + return self -@torch_op("aten::resolve_neg") +@torch_op("aten::resolve_neg", trace_only=True) def aten_resolve_neg(self: TTensor) -> TTensor: """resolve_neg(Tensor(a) self) -> Tensor(a)""" - return op.Identity(self) + return self def aten_result_type(tensor: TensorType, other: TensorType) -> int: @@ -7241,14 +7298,14 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16: return op.Reciprocal(op.Sqrt(self)) -@torch_op(("aten::rsub", "aten::rsub.Scalar")) +@torch_op(("aten::rsub.Tensor", "aten::rsub.Scalar")) def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" return op.Sub(other, op.Mul(self, alpha)) -@torch_op(("aten::rsub", "aten::rsub.Scalar"), trace_only=True, complex=True) +@torch_op(("aten::rsub.Tensor", "aten::rsub.Scalar"), trace_only=True, complex=True) def aten_rsub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" @@ -7259,12 +7316,13 @@ def aten_rsub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: def aten_scalar_tensor( s: float, dtype: int = FLOAT.dtype, - layout: str = "", # pylint: disable=unused-argument - device: str = "", # pylint: disable=unused-argument - pin_memory: bool = False, # pylint: disable=unused-argument + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> RealType: """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - + if dtype == -1: + dtype = FLOAT.dtype # Set trace_only=True because different if branches return different dtypes # which is not supported in an ONNX function return common_ops.cast_to(s, dtype=dtype) @@ -7272,12 +7330,18 @@ def aten_scalar_tensor( @torch_op("aten::scalar_tensor", trace_only=True, complex=True) def aten_scalar_tensor_complex( - s: Union[FLOAT, DOUBLE], dtype: int = COMPLEX64.dtype + s: Union[FLOAT, DOUBLE], + dtype: int = COMPLEX64.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> RealType: """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" # NOTE: When the input is originally in complex, this function is invoked. # On the other hand, when the input is originally in real, aten_scalar_tensor is used. # is invoked. 
+ if dtype == -1: + dtype = COMPLEX64.dtype if dtype == COMPLEX128.dtype: result = op.Cast(s, to=DOUBLE.dtype) elif dtype == COMPLEX64.dtype: @@ -7290,9 +7354,16 @@ def aten_scalar_tensor_complex( @torch_op("aten::scalar_tensor", trace_only=True) -def aten_scalar_tensor_sym_number(s: RealType, dtype: int = FLOAT.dtype) -> RealType: +def aten_scalar_tensor_sym_number( + s: RealType, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> RealType: """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - + if dtype == -1: + dtype = FLOAT.dtype # Set trace_only=True because different if branches return different dtypes # which is not supported in an ONNX function return common_ops.cast_to(s, dtype=dtype) @@ -7318,7 +7389,7 @@ def aten_scatter_reduce( index: TInt, src: TReal, reduce: str, - include_self: bool = True, # pylint: disable=unused-argument + include_self: bool = True, ): """scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor""" @@ -7330,18 +7401,7 @@ def aten_scatter_reduce( "amax": "max", } onnx_reduce = reduce_mode[reduce] - return _aten_scatter_reduce_onnx(self, index, src, dim, onnx_reduce) - - -@torch_op("aten::scatter_reduce", private=True) -def _aten_scatter_reduce_onnx( - self: TReal, - index: TInt, - src: TReal, - dim: int, - onnx_reduce: str, -): - self_is_scalar = IsScalar(self) + self_is_scalar = len(self.shape) == 0 if self_is_scalar: # assert (index_rank == 0 and rank_src == 0) neg_1 = op.Constant(value_ints=[-1]) self = op.Reshape(self, neg_1) @@ -7381,7 +7441,7 @@ def aten_segment_reduce( raise NotImplementedError() -@torch_op(("aten::select", "aten::select.int")) +@torch_op("aten::select.int", traceable=True) def aten_select(self: TTensor, dim: int, index: int) -> TTensor: """select(Tensor self, int dim, int index) -> Tensor""" @@ -7461,7 +7521,7 @@ def aten_sinh(self: TFloat) -> TFloat: return op.Sinh(self) -@torch_op(("aten::slice", "aten::slice.Tensor"), trace_only=True) +@torch_op(("aten::slice.Tensor"), trace_only=True) def aten_slice( self: TTensor, dim: int = 0, @@ -7589,7 +7649,7 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"), trace_only=True) +@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True) def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16: """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor""" @@ -7606,7 +7666,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB return result -@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"), traceable=True) +@torch_op(("aten::softmax.int", "aten::special_softmax"), traceable=True) def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: """softmax(Tensor self, int dim, ScalarType? 
dtype=None) -> Tensor""" @@ -7886,7 +7946,7 @@ def aten_stft( return result -@torch_op(("aten::sub", "aten::sub.Tensor", "aten::subtract", "_operator::sub")) +@torch_op(("aten::sub.Tensor", "aten::subtract.Tensor", "_operator::sub")) def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" alpha = op.CastLike(alpha, other) @@ -7896,7 +7956,7 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: @torch_op( - ("aten::sub", "aten::sub.Tensor", "aten::subtract", "_operator::sub"), + ("aten::sub.Tensor", "aten::subtract.Tensor", "_operator::sub"), trace_only=True, complex=True, ) @@ -8208,7 +8268,7 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType: raise NotImplementedError() -@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True) +@torch_op("aten::transpose.int", trace_only=True) def aten_transpose(self: TTensor, dim0: int, dim1: int) -> TTensor: """transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)""" @@ -8228,7 +8288,7 @@ def aten_transpose(self: TTensor, dim0: int, dim1: int) -> TTensor: return result -@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True, complex=True) +@torch_op("aten::transpose.int", trace_only=True, complex=True) def aten_transpose_complex(self: TTensor, dim0: int, dim1: int) -> TTensor: """transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)""" @@ -8323,7 +8383,7 @@ def aten_type_as(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::unbind", "aten::unbind.int")) +@torch_op("aten::unbind.int") def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" @@ -8677,7 +8737,7 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::view", "aten::_unsafe_view")) +@torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True) def aten_view(self: TTensor, size: IntType) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" @@ -8702,40 +8762,40 @@ def aten_view_as(self: TTensor, other: TTensor2) -> TTensor: return op.Reshape(self, size) -@torch_op("aten::view_as_complex") +@torch_op("aten::view_as_complex", trace_only=True) def aten_view_as_complex(self: TTensor) -> TTensor: """view_as_complex(Tensor(a) self) -> Tensor(a)""" # We always operate on the real representation of a complex number in torchlib # So this is a no-op - return op.Identity(self) + return self -@torch_op("aten::view_as_complex_copy") +@torch_op("aten::view_as_complex_copy", trace_only=True) def aten_view_as_complex_copy(self: TTensor) -> TTensor: """view_as_complex_copy(Tensor self) -> Tensor""" # We always operate on the real representation of a complex number in torchlib # So this is a no-op - return op.Identity(self) + return self -@torch_op("aten::view_as_real", complex=True) +@torch_op("aten::view_as_real", complex=True, trace_only=True) def aten_view_as_real(self: TTensor) -> TTensor: """view_as_real(Tensor(a) self) -> Tensor(a)""" # We always operate on the real representation of a complex number in torchlib # So this is a no-op - return op.Identity(self) + return self -@torch_op("aten::view_as_real_copy", complex=True) +@torch_op("aten::view_as_real_copy", complex=True, trace_only=True) def aten_view_as_real_copy(self: TTensor) -> TTensor: """view_as_real_copy(Tensor self) -> Tensor""" # We always operate on the real representation 
of a complex number in torchlib # So this is a no-op - return op.Identity(self) + return self @torch_op("aten::view_copy") @@ -8777,10 +8837,17 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::zeros") -def aten_zeros(size: IntType, dtype: int = FLOAT.dtype): +@torch_op("aten::zeros", trace_only=True) +def aten_zeros( + size: IntType, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - + if dtype == -1: + dtype = FLOAT.dtype size = op.Cast(size, to=INT64.dtype) zero = op.Constant(value_float=0.0) zero = op.Cast(zero, to=dtype) @@ -8800,10 +8867,5 @@ def aten_zeros_like(self: TTensor, dtype: int = -1) -> TTensor: else: zero = op.Cast(0, to=dtype) - return _aten_zeros_like_onnx(self, zero) - - -@torch_op("aten::zeros_like", private=True) -def _aten_zeros_like_onnx(self: TTensor, zero) -> TTensor: shape = op.Shape(self) return op.Expand(zero, shape) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 85fc4597c..7fb06fed6 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -206,7 +206,7 @@ def aten_avg_pool2d( padding: Sequence[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, - divisor_override: Optional[int] = None, # pylint: disable=unused-argument + divisor_override: Optional[int] = None, ) -> TFloat: """avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor""" @@ -267,7 +267,7 @@ def aten_avg_pool3d( padding: Sequence[int] = (0, 0, 0), ceil_mode: bool = False, count_include_pad: bool = True, - divisor_override: Optional[int] = None, # pylint: disable=unused-argument + divisor_override: Optional[int] = None, ) -> TFloat: """avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor""" @@ -1742,7 +1742,7 @@ def aten__scaled_dot_product_flash_attention( value: TFloat, dropout_p: float = 0.0, is_causal: bool = False, - return_debug_mask: bool = False, # pylint: disable=unused-argument + return_debug_mask: bool = False, scale: Optional[float] = None, ) -> Tuple[TFloat, FLOAT, INT64, INT64, INT64, INT64, INT64, INT64, FLOAT]: """_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) @@ -1813,12 +1813,43 @@ def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs( return logsum_exp, empty_tensor_int +@torch_op("aten::_scaled_dot_product_flash_attention_for_cpu", trace_only=True) +def aten__scaled_dot_product_flash_attention_for_cpu( + query: TFloat, + key: TFloat, + value: TFloat, + dropout_p: float = 0.0, + is_causal: bool = False, + attn_mask: Optional[TFloat] = None, + scale: Optional[float] = None, +) -> Tuple[TFloat, FLOAT]: + """_scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? 
scale=None) -> (Tensor output, Tensor logsumexp)""" + result = aten_scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + query_shape = op.Shape(query) + query_first_dims = op.Slice(query_shape, [0], [1]) + query_second_dims = op.Slice(query_shape, [1], [2]) + num_heads = op.Slice(query_shape, [-2], [-1]) + logsumexp_dim = op.Cast( + op.Ceil(op.Cast(query_second_dims, to=FLOAT.dtype) / 32.0) * 32.0, to=INT64.dtype + ) + logsum_exp = op.Expand(0.0, op.Concat(query_first_dims, num_heads, logsumexp_dim, axis=0)) + return result, logsum_exp + + @torch_op("aten::_scaled_dot_product_efficient_attention", trace_only=True) def aten__scaled_dot_product_efficient_attention( query: TFloat, key: TFloat, value: TFloat, - attn_bias: Optional[TFloat], # pylint: disable=unused-argument + attn_bias: Optional[TFloat], compute_log_sumexp: bool, dropout_p: float = 0.0, is_causal: bool = False, diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index d61803e30..ea7b2034a 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -852,18 +852,6 @@ def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): ((S, S), {}), ((0, S, 0), {}), ((S,), {}), - ] - for shape, kwargs in inputs: - t = torch_testing.make_tensor( - shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad - ) - yield opinfo_core.SampleInput(t, **kwargs) - - -def sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): - del self # Unused - - inputs = [ ((S,), {"dtype": dtype}), # Hard-code some dtypes/devices. We want to test cases where the # (dtype, device) is different from the input's (dtype, device) @@ -1165,26 +1153,6 @@ def sample_inputs_rand_like(op_info, device, dtype, requires_grad, **kwargs): yield opinfo_core.SampleInput(make_arg(shape)) -def sample_inputs_rand_like_dtype(op_info, device, dtype, requires_grad, **kwargs): - del op_info # Unused - del kwargs # Unused - - make_arg = functools.partial( - torch_testing.make_tensor, - device=device, - dtype=torch.float32, - requires_grad=requires_grad, - ) - shapes = ( - (M,), - (S, S), - (S, S, S), - ) - - for shape in shapes: - yield opinfo_core.SampleInput(make_arg(shape), kwargs=dict(dtype=dtype)) - - def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs): high = 10 @@ -1212,14 +1180,6 @@ def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs): yield opinfo_core.SampleInput(sample.input, high, *sample.args, **sample.kwargs) -def sample_inputs_randint_like_dtype(self, device, dtype, requires_grad, **kwargs): - high = 10 - - for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): - # With low and high - yield opinfo_core.SampleInput(sample.input, high, *sample.args, **sample.kwargs) - - def sample_inputs_randint_like_low_dtype(self, device, dtype, requires_grad, **kwargs): low = 2 high = 10 @@ -1229,15 +1189,6 @@ def sample_inputs_randint_like_low_dtype(self, device, dtype, requires_grad, **k yield opinfo_core.SampleInput(sample.input, low, high, *sample.args, **sample.kwargs) -def sample_inputs_randint_like_low_dtype_dtype(self, device, dtype, requires_grad, **kwargs): - low = 2 - high = 10 - - for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): - # With low and high - yield opinfo_core.SampleInput(sample.input, low, 
high, *sample.args, **sample.kwargs) - - def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs): del op # Unused del device # Unused @@ -2201,14 +2152,6 @@ def __init__(self): sample_inputs_func=sample_inputs_rand_like, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.rand_like__dtype", - op=torch.ops.aten.rand_like, - aten_name="rand_like", - dtypes=common_dtype.floating_types_and(torch.bfloat16), - sample_inputs_func=sample_inputs_rand_like_dtype, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten.randint", aten_name="randint", @@ -2230,14 +2173,6 @@ def __init__(self): sample_inputs_func=sample_inputs_randint_like, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.randint_like__dtype", - op=torch.ops.aten.randint_like, - aten_name="randint_like", - dtypes=common_dtype.integral_types(), - sample_inputs_func=sample_inputs_randint_like_dtype, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten.randint_like.low_dtype", aten_name="randint_like.low_dtype", @@ -2245,14 +2180,6 @@ def __init__(self): sample_inputs_func=sample_inputs_randint_like_low_dtype, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.randint_like.low_dtype__dtype", - op=torch.ops.aten.randint_like.low_dtype, - aten_name="randint_like.low_dtype", - dtypes=common_dtype.integral_types(), - sample_inputs_func=sample_inputs_randint_like_low_dtype_dtype, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten.randn", aten_name="randn", @@ -2267,14 +2194,6 @@ def __init__(self): sample_inputs_func=sample_inputs_like_fns, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.randn_like_dtype", - op=torch.ops.aten.randn_like, - aten_name="randn", - dtypes=common_dtype.floating_types_and(torch.bfloat16), - sample_inputs_func=sample_inputs_like_fns_dtype, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten.reflection_pad1d", aten_name="ops.aten.reflection_pad1d", diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index f12f9024e..4acaa7861 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -39,7 +39,6 @@ from torch.utils import _pytree as pytree import onnxscript -import onnxscript.evaluator from tests.function_libs.torch_lib import ( error_reproduction, ops_test_common, @@ -98,42 +97,14 @@ def _should_skip_xfail_test_sample( class TestFunctionValidity(unittest.TestCase): - def test_all_script_functions_are_onnx_functions(self): - for info in ops_test_data.TESTED_TORCHLIB_OPS: - if info.trace_only: - continue - with self.subTest(name=info.op_info_name): - func = info.op - if not isinstance(func, onnxscript.OnnxFunction): - raise TypeError( - f"'{func}' is not an OnnxFunction. Was it decorated with '@torch_op'? " - "If the function is trace_only, please specify trace_only=True " - "in the TorchLibOpInfo entry." - ) - - def test_all_trace_only_functions_are_not_onnx_functions(self): - for info in ops_test_data.TESTED_TORCHLIB_OPS: - if not info.trace_only: - continue - with self.subTest(name=info.op_info_name): - func = info.op - if not isinstance(func, onnxscript.TracedOnnxFunction): - raise TypeError( - f"'{func.name}' is not a TracedOnnxFunction. " - "If the function is not trace_only, please remove trace_only=True " - "in the TorchLibOpInfo entry." 
- ) - @parameterized.parameterized.expand( - [ - (info.op.name, info) - for info in ops_test_data.TESTED_TORCHLIB_OPS - if not info.trace_only - ] + [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] ) def test_script_function_passes_checker( self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo ): + if not isinstance(torchlib_op_info.op, onnxscript.OnnxFunction): + self.skipTest("Traced functions does not have a function proto") function_proto = torchlib_op_info.op.to_function_proto() onnx.checker.check_function(function_proto) # type: ignore[attr-defined] diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index ab3e204af..3e898c781 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -13,8 +13,7 @@ 1. To enable test cases for an operator Add a `TorchLibOpInfo` entry to `TORCH_LIB_OPINFO` in `ops_test_data.py`. - Explicitly specify `trace_only` if the op is trace_only. Specify `complex` - if the function is designed for complex inputs. + Specify `complex` if the function is designed for complex inputs. The `op_info_name` in `TorchLibOpInfo` needs to be unique in the TORCH_LIB_OPINFO list, but complex=True ops can share the same name with non-complex ops @@ -74,8 +73,6 @@ class TorchLibOpInfo: op_info_name: str # The torchlib ONNX Function to test op: Callable[..., Any] - # Explicitly specify when the op is trace_only - trace_only: bool = False # The input wrangler function to adjust the input to fit the aten signature input_wrangler: Optional[ Callable[[list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]] @@ -447,14 +444,12 @@ def _where_input_wrangler( "ops.aten._fft_c2c", # Custom from extra_opinfo fft_ops.aten__fft_c2c, tolerance={torch.complex64: (3e-3, 1.8e-4)}, - trace_only=True, complex=True, ), TorchLibOpInfo( "ops.aten._fft_c2r", # Custom from extra_opinfo fft_ops.aten__fft_c2r, tolerance={torch.complex64: (3e-3, 1.8e-4)}, - trace_only=True, complex=True, ).xfail( dtypes=(torch.complex64,), @@ -464,7 +459,6 @@ def _where_input_wrangler( "ops.aten._fft_r2c", # Custom from extra_opinfo fft_ops.aten__fft_r2c, tolerance={torch.float64: (2e-6, 2e-6), torch.float32: (3e-2, 3e-4)}, - trace_only=True, ), TorchLibOpInfo( "ops.aten._local_scalar_dense", @@ -474,7 +468,6 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten._log_softmax_half", core_ops.aten__log_softmax_half, - trace_only=True, tolerance={torch.float16: (1e-3, 1e-3)}, ) .xfail( @@ -488,8 +481,8 @@ def _where_input_wrangler( reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", test_class_name="TestOutputConsistencyFullGraph", ), - TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax, trace_only=True), - TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True) + TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax), + TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half) .xfail( reason="PyTorch does not implement _softmax for float16 on CPU", dtypes=(torch.float16,), @@ -506,7 +499,7 @@ def _where_input_wrangler( or isinstance(sample.kwargs.get("dim"), tuple), reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. 
dim must be an integer", ), - TorchLibOpInfo("all_dims", core_ops.aten_all_dims, trace_only=True).skip( + TorchLibOpInfo("all_dims", core_ops.aten_all_dims).skip( matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple), reason="this overload requires dim to be a tuple", ), @@ -523,7 +516,7 @@ def _where_input_wrangler( TorchLibOpInfo("acos", core_ops.aten_acos), TorchLibOpInfo("acosh", core_ops.aten_acosh), TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}), - TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True, trace_only=True), + TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True), TorchLibOpInfo( "addbmm", core_ops.aten_addbmm, @@ -595,7 +588,7 @@ def _where_input_wrangler( or isinstance(sample.kwargs.get("dim"), tuple), reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer", ), - TorchLibOpInfo("any_dims", core_ops.aten_any_dims, trace_only=True).skip( + TorchLibOpInfo("any_dims", core_ops.aten_any_dims).skip( matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple), reason="this overload requires dim to be a tuple", ), @@ -705,11 +698,11 @@ def _where_input_wrangler( TorchLibOpInfo("bitwise_xor", core_ops.aten_bitwise_xor), TorchLibOpInfo("bmm", core_ops.aten_bmm), TorchLibOpInfo("broadcast_to", core_ops.aten_broadcast_to), - TorchLibOpInfo("cat", core_ops.aten_cat, trace_only=True).skip( + TorchLibOpInfo("cat", core_ops.aten_cat).skip( matcher=lambda sample: sample.input[0].equal(torch.tensor([])), reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", ), - TorchLibOpInfo("cat", core_ops.aten_cat_complex, trace_only=True, complex=True).skip( + TorchLibOpInfo("cat", core_ops.aten_cat_complex, complex=True).skip( matcher=lambda sample: sample.input[0].equal(torch.tensor([])), reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", ), @@ -738,17 +731,17 @@ def _where_input_wrangler( reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", ), TorchLibOpInfo("clone", core_ops.aten_clone), - TorchLibOpInfo("complex", core_ops.aten_complex, trace_only=True), - TorchLibOpInfo("concat", core_ops.aten_cat, trace_only=True).skip( + TorchLibOpInfo("complex", core_ops.aten_complex), + TorchLibOpInfo("concat", core_ops.aten_cat).skip( matcher=lambda sample: sample.input[0].equal(torch.tensor([])), reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", ), - TorchLibOpInfo("concatenate", core_ops.aten_cat, trace_only=True).skip( + TorchLibOpInfo("concatenate", core_ops.aten_cat).skip( matcher=lambda sample: sample.input[0].equal(torch.tensor([])), reason="fixme: ORT aborts with zero-dim tensors. 
https://github.com/microsoft/onnxruntime/issues/16619", ), TorchLibOpInfo("conj", core_ops.aten_conj), - TorchLibOpInfo("conj", core_ops.aten_conj_complex, complex=True, trace_only=True), + TorchLibOpInfo("conj", core_ops.aten_conj_complex, complex=True), TorchLibOpInfo("constant_pad_nd", core_ops.aten_constant_pad_nd), # TorchLibOpInfo("copy", core_ops.aten_copy), # copy is not in OPS_DB TorchLibOpInfo("cos", core_ops.aten_cos), @@ -756,15 +749,15 @@ def _where_input_wrangler( TorchLibOpInfo("cross", core_ops.aten_cross, tolerance={torch.float16: (6e-3, 3e-3)}), TorchLibOpInfo("deg2rad", core_ops.aten_deg2rad), # TorchLibOpInfo("detach", core_ops.aten_detach), # detach is not in OP-TEST-DB - TorchLibOpInfo("diagonal", core_ops.aten_diagonal, trace_only=True), - TorchLibOpInfo("diagonal_bool", core_ops.aten_diagonal_bool, trace_only=True), + TorchLibOpInfo("diagonal", core_ops.aten_diagonal), + TorchLibOpInfo("diagonal_bool", core_ops.aten_diagonal_bool), TorchLibOpInfo("div", core_ops.aten_div).skip( matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, reason="this variation does not take the rounding_mode argument", ), TorchLibOpInfo("true_divide", core_ops.aten_div), TorchLibOpInfo("true_divide", core_ops.aten_div_complex, complex=True), - TorchLibOpInfo("div_mode", core_ops.aten_div_mode, trace_only=True) + TorchLibOpInfo("div_mode", core_ops.aten_div_mode) .skip( variant_name="no_rounding_mode", reason="this variation requires the rounding_mode argument", @@ -781,7 +774,7 @@ def _where_input_wrangler( test_class_name="TestOutputConsistencyEager", reason="fixme: off-by-one and inverted inf. https://github.com/microsoft/onnxscript/issues/989", ), - TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int, trace_only=True).skip( + TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int).skip( variant_name="no_rounding_mode", reason="this variation requires the rounding_mode argument", ), @@ -792,9 +785,7 @@ def _where_input_wrangler( input_wrangler=_empty_input_wrangler, nondeterministic=True, ), - TorchLibOpInfo( - "einsum", core_ops.aten_einsum, trace_only=True, input_wrangler=_einsum_input_wrangler - ) + TorchLibOpInfo("einsum", core_ops.aten_einsum, input_wrangler=_einsum_input_wrangler) .xfail( reason="fixme: PyTorch produces int64 output with int32 input", dtypes=(torch.int32,), @@ -828,19 +819,9 @@ def _where_input_wrangler( TorchLibOpInfo("fmod", core_ops.aten_fmod), TorchLibOpInfo("frac", core_ops.aten_frac), TorchLibOpInfo("full", core_ops.aten_full), - TorchLibOpInfo( - "full_like_dtype", - core_ops.aten_full_like_dtype, - ).skip( - matcher=lambda sample: "dtype" not in sample.kwargs, - reason="this Aten overload only support dtype in kwargs", - ), TorchLibOpInfo( "full_like", core_ops.aten_full_like, - ).skip( - matcher=lambda sample: ("dtype" in sample.kwargs), - reason="this Aten overload only support dtype not in kwargs", ), TorchLibOpInfo("gather", core_ops.aten_gather).skip( enabled_if=not version_utils.torch_older_than("2.4"), @@ -852,8 +833,8 @@ def _where_input_wrangler( TorchLibOpInfo("gt_bool", core_ops.aten_gt_bool), # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB - TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index, trace_only=True), - TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool, trace_only=True), + TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), + 
TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool), TorchLibOpInfo( "index_put_bool", core_ops.aten_index_put_bool, @@ -889,7 +870,6 @@ def _where_input_wrangler( TorchLibOpInfo( "linalg.vector_norm", linalg_ops.aten_linalg_vector_norm, - trace_only=True, tolerance={torch.float16: (2e-3, 2e-3)}, input_wrangler=_linalg_vector_norm_input_wrangler, ).skip( @@ -900,7 +880,6 @@ def _where_input_wrangler( TorchLibOpInfo( "linspace", core_ops.aten_linspace, - trace_only=True, tolerance={torch.float16: (2e-2, 2e-3)}, ) .xfail( @@ -921,7 +900,6 @@ def _where_input_wrangler( TorchLibOpInfo( "log_softmax", special_ops.aten_special_log_softmax, - trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (4e-4, 6e-3)}, ) .xfail( @@ -992,7 +970,7 @@ def _where_input_wrangler( reason="this Aten overload can accept 2 inputs:(self, dim)", ), TorchLibOpInfo("mH", core_ops.aten_mH), - TorchLibOpInfo("mH", core_ops.aten_mH_complex, complex=True, trace_only=True), + TorchLibOpInfo("mH", core_ops.aten_mH_complex, complex=True), TorchLibOpInfo("min_dim", core_ops.aten_min_dim) .skip( variant_name="reduction_with_dim", @@ -1041,79 +1019,27 @@ def _where_input_wrangler( TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout), TorchLibOpInfo("ne", core_ops.aten_ne), TorchLibOpInfo("neg", core_ops.aten_neg), - TorchLibOpInfo( - "new_empty_dtype", - core_ops.aten_new_empty_dtype, - nondeterministic=True, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is None, - reason="this Aten overload must have 3 inputs:(self, size, dtype)", - ), TorchLibOpInfo( "new_empty", core_ops.aten_new_empty, nondeterministic=True, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is not None, - reason="this Aten overload only accept 2 inputs:(self, size)", - ), - TorchLibOpInfo( - "new_empty_strided_dtype", - core_ops.aten_new_empty_strided_dtype, - nondeterministic=True, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is None, - reason="this Aten overload must have 4 inputs:(self, size, stride, dtype)", ), TorchLibOpInfo( "new_empty_strided", core_ops.aten_new_empty_strided, nondeterministic=True, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is not None, - reason="this Aten overload only accept 3 inputs:(self, size, stride)", - ), - TorchLibOpInfo( - "new_full_dtype", - core_ops.aten_new_full_dtype, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is None, - reason="this Aten overload must have 4 inputs:(self, size, fill_value, dtype)", ), TorchLibOpInfo( "new_full", core_ops.aten_new_full, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is not None, - reason="this Aten overload only accept 3 inputs:(self, size, fill_value)", - ), - TorchLibOpInfo( - "new_ones_dtype", - core_ops.aten_new_ones_dtype, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is None, - reason="", ), TorchLibOpInfo( "new_ones", core_ops.aten_new_ones, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is not None, - reason="", - ), - TorchLibOpInfo( - "new_zeros_dtype", - core_ops.aten_new_zeros_dtype, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is None, - reason="", ), TorchLibOpInfo( "new_zeros", core_ops.aten_new_zeros, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is not None, - reason="", ), TorchLibOpInfo( "nn.functional.adaptive_avg_pool1d", @@ -1174,13 +1100,11 @@ def _where_input_wrangler( "ops.aten.embedding_bag", core_ops.aten_embedding_bag, 
tolerance={torch.float16: (1e-2, 1e-2)}, - trace_only=True, compare_shape_only_for_output=(1, 2, 3), ), TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx", core_ops.aten_embedding_bag_padding_idx, - trace_only=True, tolerance={torch.float16: (1e-2, 1e-2)}, compare_shape_only_for_output=(1, 2, 3), ), @@ -1379,39 +1303,24 @@ def _where_input_wrangler( "permute", core_ops.aten_permute, input_wrangler=_permute_input_wrangler, - trace_only=True, ), TorchLibOpInfo("polar", core_ops.aten_polar), TorchLibOpInfo("pow", core_ops.aten_pow), TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True), TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True), - TorchLibOpInfo( - "ops.aten.rand_like__dtype", core_ops.aten_rand_like_dtype, nondeterministic=True - ), TorchLibOpInfo("ops.aten.randint", core_ops.aten_randint, nondeterministic=True), TorchLibOpInfo("ops.aten.randint.low", core_ops.aten_randint_low, nondeterministic=True), TorchLibOpInfo("ops.aten.randint_like", core_ops.aten_randint_like, nondeterministic=True), - TorchLibOpInfo( - "ops.aten.randint_like__dtype", core_ops.aten_randint_like_dtype, nondeterministic=True - ), TorchLibOpInfo( "ops.aten.randint_like.low_dtype", core_ops.aten_randint_like_low_dtype, nondeterministic=True, ), - TorchLibOpInfo( - "ops.aten.randint_like.low_dtype__dtype", - core_ops.aten_randint_like_low_dtype_dtype, - nondeterministic=True, - ), TorchLibOpInfo("ops.aten.randn", core_ops.aten_randn, nondeterministic=True).xfail( dtypes=(torch.float16,), reason="fixme: Shape inference error", ), TorchLibOpInfo("ops.aten.randn_like", core_ops.aten_randn_like, nondeterministic=True), - TorchLibOpInfo( - "ops.aten.randn_like_dtype", core_ops.aten_randn_like_dtype, nondeterministic=True - ), TorchLibOpInfo("rad2deg", core_ops.aten_rad2deg), TorchLibOpInfo("reciprocal", core_ops.aten_reciprocal), TorchLibOpInfo( @@ -1443,24 +1352,21 @@ def _where_input_wrangler( TorchLibOpInfo("round_decimals", core_ops.aten_round_decimals), TorchLibOpInfo("rsqrt", core_ops.aten_rsqrt), TorchLibOpInfo("rsub", core_ops.aten_rsub), - TorchLibOpInfo("rsub", core_ops.aten_rsub_complex, complex=True, trace_only=True), + TorchLibOpInfo("rsub", core_ops.aten_rsub_complex, complex=True), TorchLibOpInfo( "scalar_tensor", core_ops.aten_scalar_tensor, input_wrangler=_scalar_tensor_input_wrangler, - trace_only=True, ), TorchLibOpInfo( "scalar_tensor", core_ops.aten_scalar_tensor, input_wrangler=_scalar_tensor_input_wrangler, - trace_only=True, complex=True, ), TorchLibOpInfo( "ops.aten.scalar_tensor", core_ops.aten_scalar_tensor_complex, - trace_only=True, complex=True, ), TorchLibOpInfo( @@ -1487,7 +1393,6 @@ def _where_input_wrangler( TorchLibOpInfo( "softmax", core_ops.aten_softmax, - trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (3e-4, 4e-4)}, ) .xfail( @@ -1564,7 +1469,6 @@ def _where_input_wrangler( "squeeze_dim", core_ops.aten_squeeze_dim_complex, complex=True, - trace_only=True, ).skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", @@ -1577,9 +1481,9 @@ def _where_input_wrangler( reason="this Aten overload only support one tensor as input by design", ), TorchLibOpInfo("stack", core_ops.aten_stack), - TorchLibOpInfo("stack", core_ops.aten_stack_complex, complex=True, trace_only=True), + TorchLibOpInfo("stack", core_ops.aten_stack_complex, complex=True), TorchLibOpInfo("sub", 
core_ops.aten_sub), - TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True, trace_only=True), + TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True), # TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB TorchLibOpInfo( "t", @@ -1634,8 +1538,8 @@ def _where_input_wrangler( matcher=lambda sample: any(dim == 0 for dim in sample.input.shape), reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", ), - TorchLibOpInfo("unfold", core_ops.aten_unfold, trace_only=True), - TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold, trace_only=True), + TorchLibOpInfo("unfold", core_ops.aten_unfold), + TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold), TorchLibOpInfo("unsqueeze", core_ops.aten_unsqueeze), TorchLibOpInfo("view", core_ops.aten_view), TorchLibOpInfo("view", core_ops.aten_view_complex, complex=True), @@ -1661,7 +1565,6 @@ def _where_input_wrangler( TorchLibOpInfo( "arange_start_step", core_ops.aten_arange_start_step, - trace_only=True, ).xfail( matcher=lambda sample: len(sample.args) != 2, reason="arange_start_step overload takes three arguments (input, start, step)", @@ -1669,7 +1572,6 @@ def _where_input_wrangler( TorchLibOpInfo( "arange_start", core_ops.aten_arange_start, - trace_only=True, ).skip( matcher=lambda sample: len(sample.args) != 1, reason="arange_start overload takes two arguments (input, start)", @@ -1677,7 +1579,6 @@ def _where_input_wrangler( TorchLibOpInfo( "arange", core_ops.aten_arange, - trace_only=True, ) .xfail( dtypes=(torch.int32,), @@ -1691,7 +1592,7 @@ def _where_input_wrangler( matcher=lambda sample: sample.kwargs.get("end") is not None, reason="arange overload does not support positional 'end' argument", ), - TorchLibOpInfo("argmax", core_ops.aten_argmax, trace_only=True) + TorchLibOpInfo("argmax", core_ops.aten_argmax) .skip( matcher=lambda sample: len(sample.input.shape) == 0, enabled_if=version_utils.onnxruntime_older_than("1.16"), @@ -1701,7 +1602,7 @@ def _where_input_wrangler( dtypes=(torch.int64,), reason="fixme: ORT did not implement ArgMax for int64. https://github.com/microsoft/onnxruntime/issues/16654", ), - TorchLibOpInfo("argmin", core_ops.aten_argmin, trace_only=True) + TorchLibOpInfo("argmin", core_ops.aten_argmin) .skip( matcher=lambda sample: len(sample.input.shape) == 0, enabled_if=version_utils.onnxruntime_older_than("1.16"), @@ -1714,12 +1615,11 @@ def _where_input_wrangler( TorchLibOpInfo( "as_strided", core_ops.aten_as_strided, - trace_only=True, ).xfail( variant_name="partial_views", reason="ONNX doesn't have partial view for tensor", ), - TorchLibOpInfo("clamp", core_ops.aten_clamp, trace_only=True).skip( + TorchLibOpInfo("clamp", core_ops.aten_clamp).skip( matcher=lambda sample: len(sample.input.shape) == 0, enabled_if=version_utils.onnxruntime_older_than("1.16"), reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", @@ -1727,12 +1627,11 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.col2im", nn_ops.aten_col2im, - trace_only=True, ).xfail( dtypes=(torch.float16,), reason="fixme: Tensor-likes are not close. 
https://github.com/microsoft/onnxruntime/issues/16007", ), - TorchLibOpInfo("cumsum", core_ops.aten_cumsum, trace_only=True).xfail( + TorchLibOpInfo("cumsum", core_ops.aten_cumsum).xfail( dtypes=(torch.int32,), reason="fixme: torch.cumsum with int32 inputs uses int64 as the output type", ), @@ -1740,16 +1639,12 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.convolution", core_ops.aten_convolution, - trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4)}, ), - TorchLibOpInfo( - "empty_like", core_ops.aten_empty_like, nondeterministic=True, trace_only=True - ), + TorchLibOpInfo("empty_like", core_ops.aten_empty_like, nondeterministic=True), TorchLibOpInfo( "grid_sampler_2d", core_ops.aten_grid_sampler_2d, - trace_only=True, ).skip( # Torch implemented this using the cubic convolution algorithm with alhpa=-0.75, might be different than ORT matcher=lambda sample: sample.args[1] == 2, @@ -1767,7 +1662,6 @@ def _where_input_wrangler( "nn.functional.grid_sample", core_ops.aten_grid_sampler, input_wrangler=_grid_sample_input_wrangler, - trace_only=True, ).skip( # Torch implemented this using the cubic convolution algorithm with alhpa=-0.75, might be different than ORT matcher=lambda sample: sample.kwargs.get("mode") == "bicubic" @@ -1777,15 +1671,12 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.layer_norm", core_ops.aten_layer_norm, - trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4)}, ).xfail( dtypes=(torch.int64,), reason="fixme: ORT `LayerNormKernelImpl` not implemented for int64", ), - TorchLibOpInfo( - "logit", core_ops.aten_logit, trace_only=True, tolerance={torch.float16: (1e-1, 7e-4)} - ), + TorchLibOpInfo("logit", core_ops.aten_logit, tolerance={torch.float16: (1e-1, 7e-4)}), TorchLibOpInfo("max_dim", core_ops.aten_max_dim) .skip( variant_name="reduction_with_dim", @@ -1821,18 +1712,15 @@ def _where_input_wrangler( # Custom from extra_opinfo "ops.aten.max_pool1d", nn_ops.aten_max_pool1d, - trace_only=True, ), TorchLibOpInfo( # Custom from extra_opinfo "ops.aten.max_pool2d", nn_ops.aten_max_pool2d, - trace_only=True, ), TorchLibOpInfo( "ops.aten.max_pool3d", # Custom from extra_opinfo nn_ops.aten_max_pool3d, - trace_only=True, ).xfail( variant_name="empty_strides", reason="fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). 
https://github.com/microsoft/onnxscript/issues/975", @@ -1840,7 +1728,6 @@ def _where_input_wrangler( TorchLibOpInfo( "native_batch_norm", core_ops.aten_native_batch_norm, - trace_only=True, tolerance={torch.float16: (1e-2, 7e-3)}, ) .skip( @@ -1856,7 +1743,6 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten._native_batch_norm_legit", core_ops.aten_native_batch_norm, - trace_only=True, tolerance={torch.float16: (1e-2, 7e-3)}, ).skip( device_type="cpu", @@ -1866,12 +1752,10 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten._native_batch_norm_legit.no_stats", core_ops.aten__native_batch_norm_no_stats, - trace_only=True, ), TorchLibOpInfo( "ops.aten._native_batch_norm_legit_functional", core_ops.aten__native_batch_norm_legit_functional, - trace_only=True, tolerance={torch.float16: (1e-2, 7e-3)}, ) .skip( @@ -1889,7 +1773,6 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.native_group_norm", core_ops.aten_native_group_norm, - trace_only=True, tolerance={torch.float16: (1e-2, 7e-3)}, ).xfail( dtypes=(torch.float16,), @@ -1899,7 +1782,6 @@ def _where_input_wrangler( TorchLibOpInfo( "native_layer_norm", core_ops.aten_native_layer_norm, - trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (1e-1, 7e-4)}, ) .xfail( @@ -1917,7 +1799,6 @@ def _where_input_wrangler( "nn.functional.avg_pool1d", nn_ops.aten_avg_pool1d, input_wrangler=_avg_pool_input_wrangler, - trace_only=True, ) .xfail( matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) @@ -1938,7 +1819,6 @@ def _where_input_wrangler( "nn.functional.avg_pool2d", nn_ops.aten_avg_pool2d, input_wrangler=_avg_pool_input_wrangler, - trace_only=True, ).xfail( matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) or (sample.kwargs.get("divisor_override") is not None), @@ -1948,7 +1828,6 @@ def _where_input_wrangler( "nn.functional.avg_pool3d", nn_ops.aten_avg_pool3d, input_wrangler=_avg_pool_input_wrangler, - trace_only=True, ) .xfail( matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) @@ -1962,7 +1841,6 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.conv1d", core_ops.aten_conv1d, - trace_only=True, ).xfail( matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str), reason="String padding is not accepted by aten::conv1d", @@ -1970,7 +1848,6 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.conv2d", core_ops.aten_conv2d, - trace_only=True, tolerance={torch.float32: (2e-5, 3e-5)}, ).xfail( matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str), @@ -1979,19 +1856,16 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.instance_norm", core_ops.aten_instance_norm, - trace_only=True, tolerance={torch.float16: (1e-2, 1e-3)}, ), TorchLibOpInfo( "ops.aten.conv3d", core_ops.aten_conv3d, - trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4)}, ), TorchLibOpInfo( "nn.functional.gelu", nn_ops.aten_gelu, - trace_only=True, tolerance={torch.float16: (8e-2, 1e-4)}, ), TorchLibOpInfo("nn.functional.linear", nn_ops.aten_linear).skip( @@ -2012,7 +1886,6 @@ def _where_input_wrangler( "nn.functional.max_pool1d", nn_ops.aten_max_pool1d, input_wrangler=_max_pool_input_wrangler, - trace_only=True, ).skip( matcher=lambda sample: sample.kwargs.get("return_indices") is True, reason="this aten overload assume return_indices=False", @@ -2021,7 +1894,6 @@ def _where_input_wrangler( "nn.functional.max_pool1d_with_indices", nn_ops.aten_max_pool1d_with_indices, input_wrangler=_max_pool_input_wrangler, - 
trace_only=True, ).skip( matcher=lambda sample: sample.kwargs.get("return_indices") is False, reason="this aten overload assume return_indices=True", @@ -2030,7 +1902,6 @@ def _where_input_wrangler( "nn.functional.max_pool2d", nn_ops.aten_max_pool2d, input_wrangler=_max_pool_input_wrangler, - trace_only=True, ).skip( matcher=lambda sample: sample.kwargs.get("return_indices") is True, reason="this aten overload assume return_indices=False", @@ -2039,7 +1910,6 @@ def _where_input_wrangler( "nn.functional.max_pool2d_with_indices", nn_ops.aten_max_pool2d_with_indices, input_wrangler=_max_pool_input_wrangler, - trace_only=True, ).skip( matcher=lambda sample: sample.kwargs.get("return_indices") is False, reason="this aten overload assume return_indices=True", @@ -2048,7 +1918,6 @@ def _where_input_wrangler( "nn.functional.max_pool3d", nn_ops.aten_max_pool3d, input_wrangler=_max_pool_input_wrangler, - trace_only=True, ) .skip( matcher=lambda sample: sample.kwargs.get("ceil_mode") is True @@ -2063,7 +1932,6 @@ def _where_input_wrangler( "nn.functional.max_pool3d_with_indices", nn_ops.aten_max_pool3d_with_indices, input_wrangler=_max_pool_input_wrangler, - trace_only=True, ) .skip( matcher=lambda sample: sample.kwargs.get("ceil_mode") is True @@ -2077,7 +1945,6 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.scaled_dot_product_attention", nn_ops.aten_scaled_dot_product_attention, - trace_only=True, tolerance={torch.float32: (3e-4, 1.5e-5)}, ) .skip( @@ -2102,7 +1969,6 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten._scaled_dot_product_flash_attention", nn_ops.aten__scaled_dot_product_flash_attention, - trace_only=True, tolerance={torch.float32: (3e-4, 1.5e-5)}, # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, @@ -2119,7 +1985,6 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten._scaled_dot_product_efficient_attention", nn_ops.aten__scaled_dot_product_efficient_attention, - trace_only=True, tolerance={torch.float32: (3e-4, 1.5e-5)}, # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, @@ -2136,7 +2001,6 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.scaled_dot_product_attention_bool_mask", nn_ops.aten_scaled_dot_product_attention_bool_mask, - trace_only=True, tolerance={torch.float32: (3e-4, 1.5e-5)}, ) .skip( @@ -2161,7 +2025,6 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.upsample_bilinear2d.default", nn_ops.aten_upsample_bilinear2d, - trace_only=True, ).xfail( matcher=lambda sample: sample.args[1] is False and sample.kwargs.get("scales_h") is not None, @@ -2170,12 +2033,10 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.upsample_bilinear2d.vec", nn_ops.aten_upsample_bilinear2d_vec, - trace_only=True, ), TorchLibOpInfo( "ops.aten.upsample_bicubic2d.default", nn_ops.aten_upsample_bicubic2d, - trace_only=True, ).xfail( matcher=lambda sample: sample.args[1] is False and sample.kwargs.get("scales_h") is not None, @@ -2184,12 +2045,10 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.upsample_bicubic2d.vec", nn_ops.aten_upsample_bicubic2d_vec, - trace_only=True, ), TorchLibOpInfo( "ops.aten.upsample_linear1d", nn_ops.aten_upsample_linear1d, - trace_only=True, ).xfail( matcher=lambda sample: sample.args[1] is False and sample.kwargs.get("scales") is not None, @@ -2198,47 +2057,39 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.upsample_nearest1d", nn_ops.aten_upsample_nearest1d, - trace_only=True, ), TorchLibOpInfo( 
"ops.aten.upsample_nearest2d", nn_ops.aten_upsample_nearest2d, - trace_only=True, ), TorchLibOpInfo( "ops.aten.upsample_nearest3d", nn_ops.aten_upsample_nearest3d, - trace_only=True, ), TorchLibOpInfo( "ops.aten.upsample_trilinear3d.default", nn_ops.aten_upsample_trilinear3d, - trace_only=True, ), TorchLibOpInfo( "ops.aten.upsample_trilinear3d.vec", nn_ops.aten_upsample_trilinear3d_vec, - trace_only=True, ), - TorchLibOpInfo("ones_like", core_ops.aten_ones_like, trace_only=True), + TorchLibOpInfo("ones_like", core_ops.aten_ones_like), TorchLibOpInfo( "roll", core_ops.aten_roll, - trace_only=True, input_wrangler=_roll_input_wrangler, ), TorchLibOpInfo( "roll", core_ops.aten_roll_complex, input_wrangler=_roll_input_wrangler, - trace_only=True, complex=True, ), TorchLibOpInfo( "scatter_reduce", core_ops.aten_scatter_reduce, input_wrangler=_scatter_reduce_input_wrangler, - trace_only=True, ) .xfail( variant_name="mean", @@ -2265,12 +2116,11 @@ def _where_input_wrangler( variant_name="sum", reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'", ), - TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter, trace_only=True), - TorchLibOpInfo("slice", core_ops.aten_slice, trace_only=True), + TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), + TorchLibOpInfo("slice", core_ops.aten_slice), TorchLibOpInfo( "ops.aten.stft", # Custom from extra_opinfo core_ops.aten_stft, - trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4)}, ).xfail( dtypes=(torch.float16,), @@ -2280,7 +2130,6 @@ def _where_input_wrangler( "sum", core_ops.aten_sum_dim_IntList, input_wrangler=_sum_input_wrangler, - trace_only=True, ).xfail( dtypes=(torch.int32,), reason="fixme: torch.sum uses int64 as the accumulator for int32 inputs", @@ -2295,14 +2144,11 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.tensor.int", core_ops.aten_tensor_int ), # Custom from extra_opinfo - TorchLibOpInfo("transpose", core_ops.aten_transpose, trace_only=True), - TorchLibOpInfo( - "transpose", core_ops.aten_transpose_complex, trace_only=True, complex=True - ), + TorchLibOpInfo("transpose", core_ops.aten_transpose), + TorchLibOpInfo("transpose", core_ops.aten_transpose_complex, complex=True), TorchLibOpInfo( "var_mean", core_ops.aten_var_mean, - trace_only=True, ).xfail( # kwargs is empty matcher=lambda sample: len(sample.kwargs) > 0, @@ -2311,7 +2157,6 @@ def _where_input_wrangler( TorchLibOpInfo( "var_mean_dim", core_ops.aten_var_mean_dim, - trace_only=True, ).xfail( # kwargs["dim"] must exist, kwargs["correction"] must not exist matcher=lambda sample: not ( @@ -2323,7 +2168,6 @@ def _where_input_wrangler( TorchLibOpInfo( "var_mean_correction", core_ops.aten_var_mean_correction, - trace_only=True, ).skip( # Don't accept input[1]=bool and 'correction' must be in kwargs matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs, @@ -2332,7 +2176,6 @@ def _where_input_wrangler( TorchLibOpInfo( "var", core_ops.aten_var, - trace_only=True, ).xfail( # kwargs must be empty matcher=lambda sample: len(sample.kwargs) > 0, @@ -2341,7 +2184,6 @@ def _where_input_wrangler( TorchLibOpInfo( "var_dim", core_ops.aten_var_dim, - trace_only=True, ).xfail( # kwargs["dim"] must exist, kwargs["correction"] must not exist matcher=lambda sample: not ( @@ -2353,13 +2195,12 @@ def _where_input_wrangler( TorchLibOpInfo( "var_correction", core_ops.aten_var_correction, - trace_only=True, ).skip( # Don't accept input[1]=bool and 'correction' must be in kwargs 
         matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs,
         reason="this Aten overload only support when correction attribute exists",
     ),
-    TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
+    TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like),
     TorchLibOpInfo("torchvision.ops.nms", vision_ops.torchvision_nms),
 )
 
@@ -2393,7 +2234,6 @@ def _where_input_wrangler(
 ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",))
 ops_test_common.duplicate_opinfo(OPS_DB, "diagonal", ("diagonal_bool",))
 ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode", "div_mode_int"))
-ops_test_common.duplicate_opinfo(OPS_DB, "full_like", ("full_like_dtype",))
 ops_test_common.duplicate_opinfo(OPS_DB, "ge", ("ge_bool",))
 ops_test_common.duplicate_opinfo(OPS_DB, "gt", ("gt_bool",))
 ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
@@ -2404,11 +2244,6 @@ def _where_input_wrangler(
 ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",))
-ops_test_common.duplicate_opinfo(OPS_DB, "new_empty", ("new_empty_dtype",))
-ops_test_common.duplicate_opinfo(OPS_DB, "new_empty_strided", ("new_empty_strided_dtype",))
-ops_test_common.duplicate_opinfo(OPS_DB, "new_full", ("new_full_dtype",))
-ops_test_common.duplicate_opinfo(OPS_DB, "new_ones", ("new_ones_dtype",))
-ops_test_common.duplicate_opinfo(OPS_DB, "new_zeros", ("new_zeros_dtype",))
 ops_test_common.duplicate_opinfo(
     OPS_DB, "nn.functional.linear", ("nn.functional.linear_bias",)
 )
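
Many of the rows above gate samples with matcher lambdas rather than keeping separate per-dtype or trace-only registrations. The snippet below is a minimal, self-contained sketch of how such a matcher decides whether a sample is skipped; FakeSample and should_skip are hypothetical stand-ins, not the actual ops_test_data/ops_test_common machinery, and the lambda is the all_dims/any_dims matcher shown in the table.

# Minimal sketch of a TorchLibOpInfo-style skip matcher.
# "FakeSample" and "should_skip" are illustrative stand-ins only.
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class FakeSample:
    """Stand-in for an OpInfo SampleInput: an input plus args/kwargs."""
    input: Any = None
    args: tuple = ()
    kwargs: dict = field(default_factory=dict)


def should_skip(sample: FakeSample, matcher: Callable[[FakeSample], bool]) -> bool:
    # A sample is skipped when the matcher returns True, mirroring the
    # matcher=lambda sample: ... entries in the table above.
    return matcher(sample)


# The all_dims/any_dims overloads only accept a tuple for "dim":
dim_must_be_tuple = lambda sample: not isinstance(sample.kwargs.get("dim"), tuple)

print(should_skip(FakeSample(kwargs={"dim": 1}), dim_must_be_tuple))       # True  -> skipped
print(should_skip(FakeSample(kwargs={"dim": (0, 1)}), dim_must_be_tuple))  # False -> tested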
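
The removed full_like_dtype, new_empty_dtype, new_empty_strided_dtype, new_full_dtype, new_ones_dtype and new_zeros_dtype rows, together with their "dtype in kwargs" skips, indicate that each dtype overload is now handled by a single base function. The sketch below only illustrates that folding pattern with a NumPy stand-in and an assumed -1 sentinel for "dtype not provided"; it is not the torchlib implementation.

import numpy as np


def full_like_sketch(self_tensor, fill_value, dtype=-1):
    # -1 is an assumed sentinel meaning "keep the input's dtype"; the real
    # torchlib functions use ONNX dtype codes and emit ONNX ops, not NumPy.
    if dtype == -1:
        target = self_tensor.dtype      # old full_like behaviour (no dtype kwarg)
    else:
        target = np.dtype(dtype)        # old full_like_dtype behaviour
    return np.full(self_tensor.shape, fill_value, dtype=target)


x = np.zeros((2, 3), dtype=np.float32)
print(full_like_sketch(x, 7).dtype)                 # float32
print(full_like_sketch(x, 7, dtype="int64").dtype)  # int64

With one function covering both cases, a single test row is enough, which is presumably why the duplicated rows and their dtype skips could be deleted above.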
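
The ops_test_common.duplicate_opinfo calls at the end register extra names such as "div_mode" and "div_mode_int" so that several torchlib overloads can be exercised against one torch OpInfo's samples. The helper below is a rough stand-alone sketch of that idea, not the actual ops_test_common code; MiniOpInfo and duplicate_opinfo_sketch are hypothetical names.

# Rough sketch: duplicate an OpInfo-like record under new names so multiple
# torchlib overloads can reuse one sample generator.
import copy
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class MiniOpInfo:
    name: str
    sample_inputs_func: Callable[[], list]


def duplicate_opinfo_sketch(db: List[MiniOpInfo], name: str, aliases: tuple) -> None:
    source = next(info for info in db if info.name == name)
    for alias in aliases:
        clone = copy.deepcopy(source)
        clone.name = alias
        db.append(clone)


db = [MiniOpInfo("div", lambda: [((6.0, 4.0), {"rounding_mode": "floor"})])]
duplicate_opinfo_sketch(db, "div", ("div_mode", "div_mode_int"))
print([info.name for info in db])  # ['div', 'div_mode', 'div_mode_int']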