Fix train dtype for native_dropout | fix(torchlib) (#869)
Previously the `train` attribute in `native_dropout` was not cast to
BOOL. We need to do that because the underlying attribute type for a
Python bool is INT64 when it is promoted to an input. This change also
adds a test that catches the old error.
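For background, ONNX's `Dropout` constrains its optional `training_mode` input to BOOL, so the INT64 input produced by promoting a Python bool fails model validation. The sketch below is an assumed example, not part of this commit; it reproduces the failure with a handwritten one-node model:

# Minimal sketch (assumption: a stock `onnx` install): a Dropout model whose
# training_mode input is INT64, the type a promoted Python bool would get.
import onnx
from onnx import TensorProto, helper

node = helper.make_node("Dropout", ["x", "ratio", "train"], ["y", "mask"])
graph = helper.make_graph(
    [node],
    "dropout_training_mode_check",
    inputs=[
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3]),
        helper.make_tensor_value_info("ratio", TensorProto.FLOAT, []),
        # INT64 violates Dropout's T2 type constraint, which allows only BOOL.
        helper.make_tensor_value_info("train", TensorProto.INT64, []),
    ],
    outputs=[
        helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3]),
        helper.make_tensor_value_info("mask", TensorProto.BOOL, [2, 3]),
    ],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

try:
    onnx.checker.check_model(model, full_check=True)
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
    print("rejected as expected:", e)  # training_mode must be tensor(bool)

With the `Cast` to BOOL in place, as the diff below adds, the same model validates.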
justinchuby committed Jul 13, 2023
1 parent 1fc87c3 commit 97604f6
Showing 3 changed files with 38 additions and 1 deletion.
4 changes: 4 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4609,6 +4609,10 @@ def aten_native_dropout(
 ) -> Tuple[TFloatOrBFloat16, BOOL]:
     """native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""
 
+    # Python bool attributes need to be explicitly converted to BOOL
+    # because the underlying attribute type is int
+    # TODO(#872): Allow ONNX Script to handle this conversion
+    train = op.Cast(train, to=BOOL.dtype)
     result, mask = op.Dropout(input, p, train)
     return result, mask

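A side note on the added line: in ONNX Script, `BOOL.dtype` is the `TensorProto.BOOL` enum value, so `op.Cast(train, to=BOOL.dtype)` lowers to a plain ONNX `Cast` node. A small sanity check, assuming standard `onnxscript` and `onnx` installs:

# Assumption: onnxscript exposes its tensor types at the package top level.
from onnx import TensorProto
from onnxscript import BOOL

# The value written into the Cast node's `to` attribute.
assert BOOL.dtype == TensorProto.BOOL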
33 changes: 33 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -16,6 +16,8 @@
 )
 from torch.testing._internal.opinfo import core as opinfo_core
 
+S = 5
+
 
 def sample_inputs__local_scalar_dense(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
@@ -446,6 +448,28 @@ def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
     yield opinfo_core.SampleInput(tensor, args=(output_size, kernel_size), kwargs=kwargs)
 
 
+def sample_inputs_native_dropout(
+    op_info, device, dtype, requires_grad, *, valid_input_dim=None, **kwargs
+):
+    del op_info  # Unused
+    del kwargs  # Unused
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    if valid_input_dim:
+        cases = ((S,) * i for i in valid_input_dim)
+    else:
+        cases = ((S, S), (S,), ())
+    # ONNX requires 0 <= p < 1
+    p_vals = [0.0]
+
+    training_vals = [True, False]
+
+    for case, p, training in itertools.product(cases, p_vals, training_vals):
+        yield opinfo_core.SampleInput(make_arg(case), p=p, train=training)
+
+
 def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     del kwargs
@@ -561,6 +585,15 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
         dtypes=common_dtype.all_types(),
         sample_inputs_func=sample_inputs__local_scalar_dense,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.native_dropout",
+        aten_name="native_dropout",
+        dtypes=common_dtype.all_types_and_half(),
+        sample_inputs_func=sample_inputs_native_dropout,
+        supports_out=False,
+    ),
+    # TODO(#870): Rename all with the op.aten. prefix so OpInfo will find
+    # the op automatically.
     opinfo_core.OpInfo(
         "col2im",
         op=torch.ops.aten.col2im,
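To see how the new entry is exercised, here is a hypothetical sketch that pulls the OpInfo out of the database and runs each generated sample through the eager ATen op; the `OP_DB` name and the loop are assumptions, not this repo's actual test driver:

# Hypothetical usage sketch; assumes extra_opinfo exposes its OpInfo list
# as OP_DB. Runs every sample through torch eagerly as a smoke test.
import torch
from onnxscript.tests.function_libs.torch_lib import extra_opinfo

op_info = next(
    info for info in extra_opinfo.OP_DB if info.name == "ops.aten.native_dropout"
)
for sample in op_info.sample_inputs("cpu", torch.float32, requires_grad=False):
    # Each sample carries p and train as kwargs (see sample_inputs_native_dropout).
    out, mask = torch.ops.aten.native_dropout(
        sample.input, sample.kwargs["p"], sample.kwargs["train"]
    )
    assert mask.dtype == torch.bool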
2 changes: 1 addition & 1 deletion onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -748,7 +748,7 @@ def _where_input_wrangler(
TorchLibOpInfo("mm", core_ops.aten_mm),
TorchLibOpInfo("mul", core_ops.aten_mul),
TorchLibOpInfo("narrow", core_ops.aten_narrow),
# TorchLibOpInfo("native_dropout", core_ops.aten_native_dropout), # native_dropout is not in OPS_DB
TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout),
TorchLibOpInfo("ne", core_ops.aten_ne),
TorchLibOpInfo("neg", core_ops.aten_neg),
TorchLibOpInfo(
