From 9f26afffe5c310f8b9d520f037e4f583613e2ca0 Mon Sep 17 00:00:00 2001 From: Robert Nagy Date: Tue, 7 May 2024 10:33:15 +0200 Subject: [PATCH] fix: request abort signal (#3209) * fix: request abort signal * fixup * fixup * fixup --- lib/api/api-request.js | 36 ++++++++++++++++---- test/issue-2590.js | 4 +-- test/request-signal.js | 76 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 8 deletions(-) create mode 100644 test/request-signal.js diff --git a/lib/api/api-request.js b/lib/api/api-request.js index e5d598aa6dd..f70f351f2dc 100644 --- a/lib/api/api-request.js +++ b/lib/api/api-request.js @@ -2,11 +2,10 @@ const assert = require('node:assert') const { Readable } = require('./readable') -const { InvalidArgumentError } = require('../core/errors') +const { InvalidArgumentError, RequestAbortedError } = require('../core/errors') const util = require('../core/util') const { getResolveErrorBodyCallback } = require('./util') const { AsyncResource } = require('node:async_hooks') -const { addSignal, removeSignal } = require('./abort-signal') class RequestHandler extends AsyncResource { constructor (opts, callback) { @@ -56,6 +55,9 @@ class RequestHandler extends AsyncResource { this.onInfo = onInfo || null this.throwOnError = throwOnError this.highWaterMark = highWaterMark + this.signal = signal + this.reason = null + this.removeAbortListener = null if (util.isStream(body)) { body.on('error', (err) => { @@ -63,7 +65,23 @@ class RequestHandler extends AsyncResource { }) } - addSignal(this, signal) + if (this.signal) { + if (this.signal.aborted) { + this.reason = this.signal.reason ?? new RequestAbortedError() + } else { + this.removeAbortListener = util.addAbortListener(this.signal, () => { + this.removeAbortListener?.() + this.removeAbortListener = null + + this.reason = this.signal.reason ?? new RequestAbortedError() + if (this.res) { + util.destroy(this.res, this.reason) + } else if (this.abort) { + this.abort(this.reason) + } + }) + } + } } onConnect (abort, context) { @@ -95,6 +113,13 @@ class RequestHandler extends AsyncResource { const contentLength = parsedHeaders['content-length'] const body = new Readable({ resume, abort, contentType, contentLength, highWaterMark }) + if (this.removeAbortListener) { + // TODO (fix): 'close' is sufficient but breaks tests. + body + .on('end', this.removeAbortListener) + .on('error', this.removeAbortListener) + } + this.callback = null this.res = body if (callback !== null) { @@ -123,8 +148,6 @@ class RequestHandler extends AsyncResource { onComplete (trailers) { const { res } = this - removeSignal(this) - util.parseHeaders(trailers, this.trailers) res.push(null) @@ -133,7 +156,8 @@ class RequestHandler extends AsyncResource { onError (err) { const { res, callback, body, opaque } = this - removeSignal(this) + this.removeAbortListener?.() + this.removeAbortListener = null if (callback) { // TODO: Does this need queueMicrotask? diff --git a/test/issue-2590.js b/test/issue-2590.js index c5499bf4513..1da0b23f20a 100644 --- a/test/issue-2590.js +++ b/test/issue-2590.js @@ -27,12 +27,12 @@ test('aborting request with custom reason', async (t) => { await t.rejects( request(`http://localhost:${server.address().port}`, { signal: ac.signal }), - /Request aborted/ + /Error: aborted/ ) await t.rejects( request(`http://localhost:${server.address().port}`, { signal: ac2.signal }), - { code: 'UND_ERR_ABORTED' } + { name: 'AbortError' } ) await t.completed diff --git a/test/request-signal.js b/test/request-signal.js new file mode 100644 index 00000000000..fd4d2f885a5 --- /dev/null +++ b/test/request-signal.js @@ -0,0 +1,76 @@ +'use strict' + +const { createServer } = require('node:http') +const { test, after } = require('node:test') +const { tspl } = require('@matteo.collina/tspl') +const { request } = require('..') + +test('pre abort signal w/ reason', async (t) => { + t = tspl(t, { plan: 1 }) + + const server = createServer((req, res) => { + res.end('asd') + }) + after(() => server.close()) + + server.listen(0, async () => { + const ac = new AbortController() + const _err = new Error() + ac.abort(_err) + try { + await request(`http://0.0.0.0:${server.address().port}`, { signal: ac.signal }) + } catch (err) { + t.equal(err, _err) + } + }) + await t.completed +}) + +test('post abort signal', async (t) => { + t = tspl(t, { plan: 1 }) + + const server = createServer((req, res) => { + res.end('asd') + }) + after(() => server.close()) + + server.listen(0, async () => { + const ac = new AbortController() + const ures = await request(`http://0.0.0.0:${server.address().port}`, { signal: ac.signal }) + ac.abort() + try { + /* eslint-disable-next-line no-unused-vars */ + for await (const chunk of ures.body) { + // Do nothing... + } + } catch (err) { + t.equal(err.name, 'AbortError') + } + }) + await t.completed +}) + +test('post abort signal w/ reason', async (t) => { + t = tspl(t, { plan: 1 }) + + const server = createServer((req, res) => { + res.end('asd') + }) + after(() => server.close()) + + server.listen(0, async () => { + const ac = new AbortController() + const _err = new Error() + const ures = await request(`http://0.0.0.0:${server.address().port}`, { signal: ac.signal }) + ac.abort(_err) + try { + /* eslint-disable-next-line no-unused-vars */ + for await (const chunk of ures.body) { + // Do nothing... + } + } catch (err) { + t.equal(err, _err) + } + }) + await t.completed +})