Compute DFT from timeseries data wrong output shape #1638

Open
cstiewegit opened this issue Jun 19, 2024 · 0 comments
Labels
question Further information is requested

Comments

@cstiewegit

behavior

I created a model to compute the FFT coefficients of a one-dimensional time series.
The output of the DFT operator is 3-dimensional; the (complex) coefficients are gathered from it.
I then want to compute the absolute value of these complex numbers.
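
For comparison, the quantity described above can be sketched without ONNX at all. This is a minimal pure-Python reference, not part of the original report: the function name `onesided_dft_abs` is hypothetical, it assumes the usual forward DFT convention (negative exponent), and it uses a naive O(n^2) sum rather than an FFT. For a real signal of length 10, a one-sided spectrum keeps 10 // 2 + 1 = 6 bins.

```python
import cmath
import math


def onesided_dft_abs(signal):
    """Magnitudes of the one-sided DFT of a real 1-D signal (naive O(n^2) sum)."""
    n = len(signal)
    n_bins = n // 2 + 1  # one-sided spectrum keeps bins 0 .. n // 2
    magnitudes = []
    for k in range(n_bins):
        # Forward DFT coefficient for bin k (assumed sign convention: exp(-2*pi*i*k*t/n)).
        coeff = sum(
            x * cmath.exp(-2j * math.pi * k * t / n)
            for t, x in enumerate(signal)
        )
        # Absolute value of the complex coefficient: sqrt(re^2 + im^2).
        magnitudes.append(math.hypot(coeff.real, coeff.imag))
    return magnitudes


# A constant signal has all of its energy in bin 0.
reference = onesided_dft_abs([1.0] * 10)
```

Values like these can be compared against the eager-mode result to check that the real and imaginary parts are being combined as intended.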

The function works fine in onnxscript eager mode, but running inference with onnxruntime raises an error:

Fail: [ONNXRuntimeError] : 1 : FAIL : Node (n11) Op (Gather) [ShapeInferenceError] axis must be in [-r, r-1]

After shape inference, the model looks like this:
[Image: test]

It looks like the output shape of the DFT-Operator is wrong?
Am I missing something?

code to reproduce

from onnxscript import FLOAT, script
from onnxscript import opset20 as op 
from onnxruntime import InferenceSession
import numpy as np
import onnx
    
@script(ir_version=9)
def onnx_dft_abs(X: FLOAT[10]) -> FLOAT[None]:
    X_as_column = op.Reshape(data=X, shape=[-1, 1])
    dft = op.DFT(input=X_as_column, onesided=1, axis=0, dft_length=10)
    # complex coefficients
    dft_complex = dft[:, 0, :]
    # compute absolute value
    squared = op.Mul(dft_complex, dft_complex)
    abs_val = op.Sqrt(op.Sum(squared[:, 0], squared[:, 1]))
    return abs_val

x = np.random.randn(10).astype(np.float32)

model = onnx_dft_abs.to_model_proto()
model = onnx.shape_inference.infer_shapes(model)
sess = InferenceSession(model.SerializeToString())
input_name = sess.get_inputs()[0].name
result = sess.run(None, {input_name: x})[0]
@justinchuby justinchuby added the question Further information is requested label Jun 19, 2024