
use devices/dtypes based on passed in tensors #68

Open
wants to merge 5 commits into master
Conversation

adgilbert

Hi,

Thanks for putting together this repository. I'm using the loss functions on their own as part of another project, switching between CPU and GPU at different times, and I sometimes use half-precision training. This PR makes the loss functions use the device/dtype of the passed-in tensors rather than always using GPU/torch.float32. Since the training code still uses get_torch_device() and float32 tensors, this should change behavior only when someone uses the loss functions separately (my use case).
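To illustrate the pattern this PR applies, here is a minimal sketch of a loss function written this way. The function name and loss formula are hypothetical (not taken from this repository); the point is that any auxiliary tensors are created with `device=pred.device, dtype=pred.dtype` instead of a hard-coded GPU device and `torch.float32`:

```python
import torch

def soft_dice_loss(pred, target, eps=1e-6):
    # Hypothetical example loss. Auxiliary tensors inherit the device
    # and dtype of the inputs, so the same function works on CPU, GPU,
    # and under half-precision training without modification.
    weights = torch.ones(pred.shape[1], device=pred.device, dtype=pred.dtype)
    intersection = (pred * target).sum(dim=(2, 3))
    union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
    dice = (2 * intersection + eps) / (union + eps)
    return 1 - (weights * dice).mean()
```

Called with CPU float32 tensors the loss stays on CPU in float32; called with half-precision tensors it stays in float16, with no `.cuda()` or dtype cast required inside the loss.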

Creating this PR in case it's useful to others. Obviously, feel free to reject it if you want to keep things as they are.
