diff --git a/trajnetbaselines/lstm/gridbased_pooling.py b/trajnetbaselines/lstm/gridbased_pooling.py index e1d89ed..30a1450 100644 --- a/trajnetbaselines/lstm/gridbased_pooling.py +++ b/trajnetbaselines/lstm/gridbased_pooling.py @@ -250,7 +250,7 @@ def occupancy(self, obs, other_values=None, past_obs=None): ## if only primary pedestrian present if num_tracks == 1: - return self.constant*torch.ones(1, self.pooling_dim, self.n, self.n, device=obs.device) + return self.constant*torch.ones(batch_size, self.pooling_dim, self.n, self.n, device=obs.device) ## Get relative position ## [batch_size, num_tracks, 2] --> [batch_size, num_tracks, num_tracks, 2]