Skip to content

Commit

Permalink
Modified polar component operators so they work on tensors.
Browse files Browse the repository at this point in the history
  • Loading branch information
lecoanet committed Jun 30, 2024
1 parent 51ca757 commit 8c92ec6
Showing 1 changed file with 30 additions and 38 deletions.
68 changes: 30 additions & 38 deletions dedalus/core/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5974,28 +5974,24 @@ class PolarAzimuthalComponent(operators.AzimuthalComponent):
basis_type = IntervalBasis

def subproblem_matrix(self, subproblem):
# I'm not sure how to generalize this to higher order tensors, since we do
# not have spin_weights for the S1 basis.
matrix = np.array([[1,0]])
operand = self.args[0]
input_dim = len(operand.tensorsig)
output_dim = len(self.tensorsig)
matrix = []
for output in range(2**output_dim):
index_out = np.unravel_index(output, [2]*output_dim)
matrix_row = []
for input in range(2**input_dim):
index_in = np.unravel_index(input, [2]*input_dim)
if tuple(index_in[:self.index] + index_in[self.index+1:]) == index_out and index_in[self.index] == 0:
matrix_row.append(1)
else:
matrix_row.append(0)
matrix.append(matrix_row)
matrix = np.array(matrix)
if self.dtype == np.float64:
# Block-diag for sin/cos parts for real dtype
matrix = sparse.kron(matrix, sparse.eye(2))

# operand = self.args[0]
# basis = self.domain.get_basis(self.coordsys)
# S_in = basis.spin_weights(operand.tensorsig)
# S_out = basis.spin_weights(self.tensorsig)
#
# matrix = []
# for spinindex_out, spintotal_out in np.ndenumerate(S_out):
# matrix_row = []
# for spinindex_in, spintotal_in in np.ndenumerate(S_in):
# if tuple(spinindex_in[:self.index] + spinindex_in[self.index+1:]) == spinindex_out and spinindex_in[self.index] == 2:
# matrix_row.append( 1 )
# else:
# matrix_row.append( 0 )
# matrix.append(matrix_row)
# matrix = np.array(matrix)
return matrix

def operate(self, out):
Expand All @@ -6012,28 +6008,24 @@ class PolarRadialComponent(operators.RadialComponent):
basis_type = IntervalBasis

def subproblem_matrix(self, subproblem):
# I'm not sure how to generalize this to higher order tensors, since we do
# not have spin_weights for the S1 basis.
matrix = np.array([[0,1]])
operand = self.args[0]
input_dim = len(operand.tensorsig)
output_dim = len(self.tensorsig)
matrix = []
for output in range(2**output_dim):
index_out = np.unravel_index(output, [2]*output_dim)
matrix_row = []
for input in range(2**input_dim):
index_in = np.unravel_index(input, [2]*input_dim)
if tuple(index_in[:self.index] + index_in[self.index+1:]) == index_out and index_in[self.index] == 1:
matrix_row.append(1)
else:
matrix_row.append(0)
matrix.append(matrix_row)
matrix = np.array(matrix)
if self.dtype == np.float64:
# Block-diag for sin/cos parts for real dtype
matrix = sparse.kron(matrix, sparse.eye(2))

# operand = self.args[0]
# basis = self.domain.get_basis(self.coordsys)
# S_in = basis.spin_weights(operand.tensorsig)
# S_out = basis.spin_weights(self.tensorsig)
#
# matrix = []
# for spinindex_out, spintotal_out in np.ndenumerate(S_out):
# matrix_row = []
# for spinindex_in, spintotal_in in np.ndenumerate(S_in):
# if tuple(spinindex_in[:self.index] + spinindex_in[self.index+1:]) == spinindex_out and spinindex_in[self.index] == 2:
# matrix_row.append( 1 )
# else:
# matrix_row.append( 0 )
# matrix.append(matrix_row)
# matrix = np.array(matrix)
return matrix

def operate(self, out):
Expand Down

0 comments on commit 8c92ec6

Please sign in to comment.