
fix(torchlib): Fix train dtype for native_dropout #869

Merged
4 commits merged into main from justinchu/fix-native-dropout on Jul 13, 2023

Conversation

justinchuby (Collaborator)

Previously the `train` attribute in native_dropout was not cast to BOOL. We need to do that because the underlying attribute type for a Python bool is INT64 when it is promoted to an input. This change also adds a test that catches the old error.
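
For context, here is a minimal sketch of the kind of cast described, following the signature visible in the review diff below; the imports and body are an approximation, not necessarily the merged code:

```python
from typing import Tuple

from onnxscript import BOOL
from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloatOrBFloat16


def aten_native_dropout(
    input: TFloatOrBFloat16, p: float, train: bool = True
) -> Tuple[TFloatOrBFloat16, BOOL]:
    """native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""
    # A Python bool attribute has underlying type INT64 when promoted to a
    # graph input, but Dropout's training_mode input must be BOOL, so cast it.
    train = op.Cast(train, to=BOOL.dtype)
    result, mask = op.Dropout(input, p, train)
    return result, mask
```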
@justinchuby added the `topic: torch_lib` label (Related to the torch/aten function lib in development) on Jul 13, 2023
codecov bot commented on Jul 13, 2023

Codecov Report

Merging #869 (a56db2d) into main (1fc87c3) will increase coverage by 0.02%.
The diff coverage is 84.61%.

@@            Coverage Diff             @@
##             main     #869      +/-   ##
==========================================
+ Coverage   76.55%   76.57%   +0.02%     
==========================================
  Files         112      112              
  Lines       13394    13407      +13     
  Branches     1347     1350       +3     
==========================================
+ Hits        10254    10267      +13     
+ Misses       2806     2805       -1     
- Partials      334      335       +1     
| Impacted Files | Coverage Δ |
|---|---|
| ...ipt/tests/function_libs/torch_lib/ops_test_data.py | 96.72% <ø> (ø) |
| ...ript/tests/function_libs/torch_lib/extra_opinfo.py | 97.34% <83.33%> (-0.96%) ⬇️ |
| onnxscript/function_libs/torch_lib/ops/core.py | 77.00% <100.00%> (+0.09%) ⬆️ |

@titaiwangms (Contributor) left a comment


Is it like the difference between using it as an attribute versus as an input? It looks like it will be used as an input of op.Dropout, so we need to promote it to BOOL.

@@ -4609,6 +4609,10 @@ def aten_native_dropout(
) -> Tuple[TFloatOrBFloat16, BOOL]:
@titaiwangms (Contributor) commented:

This comment is for the line above, `train: bool = True`.

I encountered another op with this recently too. It seems `train` is created as a Constant with INT type. Is it possible to fix the root cause so that it is created as a BOOL-typed Constant?

@justinchuby (Collaborator, Author) replied:

I think we will need to update the translator. Created #872

@justinchuby (Collaborator, Author) commented:

> Is it like the difference between using it as an attribute versus as an input? It looks like it will be used as an input of op.Dropout, so we need to promote it to BOOL.

That's correct.
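
To make the attribute-versus-input distinction concrete, here is a hypothetical illustration (the variable names are assumptions for exposition, not code from this PR):

```python
# `to=BOOL.dtype` is an ONNX *attribute* of Cast: it stays an attribute in
# the graph and never becomes a tensor, so no promotion happens here.
train_bool = op.Cast(train, to=BOOL.dtype)

# `train_bool` is passed positionally, so it is an op *input*. Had the raw
# Python bool `train` been passed here instead, it would be promoted to a
# graph input with element type INT64, violating Dropout's type constraint
# that `training_mode` must be BOOL.
result, mask = op.Dropout(input, p, train_bool)
```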

@justinchuby merged commit 97604f6 into main on Jul 13, 2023 (30 of 33 checks passed).
@justinchuby deleted the justinchu/fix-native-dropout branch on Jul 13, 2023 at 16:30.
Labels: `topic: torch_lib` (Related to the torch/aten function lib in development)
3 participants