-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix train
dtype for native_dropout | fix(torchlib)
#869
Conversation
Previously the `train` attribute in native_dropout was not casted to BOOL. We need to do that because the underlying attribute type for python bool is INT64 when promoted to input. This change also added a test that will catch the old error.
Codecov Report
@@ Coverage Diff @@
## main #869 +/- ##
==========================================
+ Coverage 76.55% 76.57% +0.02%
==========================================
Files 112 112
Lines 13394 13407 +13
Branches 1347 1350 +3
==========================================
+ Hits 10254 10267 +13
+ Misses 2806 2805 -1
- Partials 334 335 +1
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it like the difference between using it as an attribute or an input? Looks like it's going to be used as input of op.Dropout, so we need to promote it to BOOL.
@@ -4609,6 +4609,10 @@ def aten_native_dropout( | |||
) -> Tuple[TFloatOrBFloat16, BOOL]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment is for 1 line above, train: bool = True
I encountered another op with this recently too. It seems train
is created as Constant w/ INT type. Is it possible to fix the root to have it create as BOOL type Constant?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we will need to update the translator. Created #872
That's correct. |
Previously the
train
attribute in native_dropout was not casted to BOOL. We need to do that because the underlying attribute type for python bool is INT64 when promoted to input. This change also added a test that will catch the old error.