From 0eedc64fc1417189d5a65d5e38eb21416d8efd16 Mon Sep 17 00:00:00 2001 From: Udit Samani Date: Thu, 27 Jun 2024 12:01:17 +0100 Subject: [PATCH] Add Streaming Client interaction test (#4132) With this change, we are adding tests for interaction between the producer and consumer client. The test scenarion being added is where we kill the consumer client and make sure that producer client kills the stream as well --- pkg/nats/proxy/compute_handler.go | 11 +- pkg/nats/stream/consumer_client.go | 2 +- pkg/nats/stream/producer_client.go | 113 ++++++----- .../streaming_client_interactions_test.go | 185 ++++++++++++++++++ 4 files changed, 258 insertions(+), 53 deletions(-) create mode 100644 pkg/nats/stream/streaming_client_interactions_test.go diff --git a/pkg/nats/proxy/compute_handler.go b/pkg/nats/proxy/compute_handler.go index 6de200481f..1727f00995 100644 --- a/pkg/nats/proxy/compute_handler.go +++ b/pkg/nats/proxy/compute_handler.go @@ -92,7 +92,8 @@ func handleRequest(msg *nats.Msg, handler *ComputeHandler) { // processAndRespond processes the request and sends a response. func processAndRespond[Request, Response any]( - ctx context.Context, conn *nats.Conn, msg *nats.Msg, f handlerWithResponse[Request, Response]) { + ctx context.Context, conn *nats.Conn, msg *nats.Msg, f handlerWithResponse[Request, Response], +) { response, err := processRequest(ctx, msg, f) if err != nil { log.Ctx(ctx).Error().Err(err) @@ -109,7 +110,8 @@ func processAndRespond[Request, Response any]( // processRequest decodes the request, invokes the handler, and returns the response. func processRequest[Request, Response any]( - ctx context.Context, msg *nats.Msg, f handlerWithResponse[Request, Response]) (*Response, error) { + ctx context.Context, msg *nats.Msg, f handlerWithResponse[Request, Response], +) (*Response, error) { request := new(Request) err := json.Unmarshal(msg.Data, request) if err != nil { @@ -135,7 +137,8 @@ func sendResponse[Response any](conn *nats.Conn, reply string, result *concurren } func processAndStream[Request, Response any](ctx context.Context, streamingClient *stream.ProducerClient, msg *nats.Msg, - f handlerWithResponse[Request, <-chan *concurrency.AsyncResult[Response]]) { + f handlerWithResponse[Request, <-chan *concurrency.AsyncResult[Response]], +) { if msg.Reply == "" { log.Ctx(ctx).Error().Msgf("streaming request on %s has no reply subject", msg.Subject) return @@ -173,7 +176,7 @@ func processAndStream[Request, Response any](ctx context.Context, streamingClien cancel, ) - defer streamingClient.RemoveStream(streamRequest.ConsumerID, streamRequest.ConsumerID) + defer streamingClient.RemoveStream(streamRequest.ConsumerID, streamRequest.StreamID) //nolint:errcheck if err != nil { _ = writer.CloseWithCode(stream.CloseInternalServerErr, fmt.Sprintf("error in handler %s: %s", reflect.TypeOf(request).Name(), err)) diff --git a/pkg/nats/stream/consumer_client.go b/pkg/nats/stream/consumer_client.go index 588fb6bbf3..f996763482 100644 --- a/pkg/nats/stream/consumer_client.go +++ b/pkg/nats/stream/consumer_client.go @@ -345,7 +345,7 @@ func (nc *ConsumerClient) getNotActiveStreamIds(activeStreamIDsAtProducer map[st return time.Since(bucket.createdAt) < nc.config.StreamCancellationBufferDuration }) - // If no non recent buckets, means all are active streams + // If no non-recent buckets, means all are active streams if len(nonRecentBuckets) == 0 { continue } diff --git a/pkg/nats/stream/producer_client.go b/pkg/nats/stream/producer_client.go index 2845a977b1..041413bc41 100644 --- a/pkg/nats/stream/producer_client.go +++ b/pkg/nats/stream/producer_client.go @@ -9,6 +9,7 @@ import ( "github.com/nats-io/nats.go" "github.com/rs/zerolog/log" + "github.com/samber/lo" ) type ProducerClientParams struct { @@ -18,22 +19,39 @@ type ProducerClientParams struct { type ProducerClient struct { Conn *nats.Conn - mu sync.RWMutex // Protects access to activeStreamInfo and activeConnHeartBeatRequestSubjects - - // A map of ConsumerID to StreamId that are active - activeStreamInfo map[string]map[string]StreamInfo - // A map of ConsumerID to the subject where a heartBeatRequest needs to be sent. - activeConnHeartBeatRequestSubjects map[string]string - heartBeatCancelFunc context.CancelFunc - config StreamProducerClientConfig + mu sync.RWMutex // Protects access to activeConsumers + + activeConsumers map[string]consumerInfo + heartBeatCancelFunc context.CancelFunc + config StreamProducerClientConfig +} + +type consumerInfo struct { + // Heartbeat request subject to which consumer info subscribes to respond + // with non-active stream ids + HeartbeatRequestSub string + // A map holding information about active streams alive at consumer + ActiveStreamInfo map[string]StreamInfo +} + +func (c *consumerInfo) getActiveStreamIds() []string { + return lo.Keys(c.ActiveStreamInfo) +} + +func (c *consumerInfo) getActiveStreamIdsByRequestSubject() map[string][]string { + activeStreamIdsByReqSubj := make(map[string][]string) + + for streamID, streamInfo := range c.ActiveStreamInfo { + activeStreamIdsByReqSubj[streamInfo.RequestSub] = append(activeStreamIdsByReqSubj[streamInfo.RequestSub], streamID) + } + return activeStreamIdsByReqSubj } func NewProducerClient(ctx context.Context, params ProducerClientParams) (*ProducerClient, error) { nc := &ProducerClient{ - Conn: params.Conn, - activeStreamInfo: make(map[string]map[string]StreamInfo), - activeConnHeartBeatRequestSubjects: make(map[string]string), - config: params.Config, + Conn: params.Conn, + activeConsumers: make(map[string]consumerInfo), + config: params.Config, } go nc.heartBeat(ctx) @@ -51,6 +69,17 @@ func (pc *ProducerClient) AddStream( pc.mu.Lock() defer pc.mu.Unlock() + if _, ok := pc.activeConsumers[consumerID]; !ok { + pc.activeConsumers[consumerID] = consumerInfo{ + HeartbeatRequestSub: heartBeatRequestSub, + ActiveStreamInfo: make(map[string]StreamInfo), + } + } + + if _, ok := pc.activeConsumers[consumerID].ActiveStreamInfo[streamID]; ok { + return fmt.Errorf("cannot create request with same streamId %s again", streamID) + } + streamInfo := StreamInfo{ ID: streamID, RequestSub: requestSub, @@ -58,39 +87,35 @@ func (pc *ProducerClient) AddStream( Cancel: cancelFunc, } - if pc.activeStreamInfo[consumerID] == nil { - pc.activeStreamInfo[consumerID] = make(map[string]StreamInfo) - } - - if _, ok := pc.activeStreamInfo[consumerID][streamID]; ok { - return fmt.Errorf("cannot create request with same streamId %s again", streamID) - } - - pc.activeStreamInfo[consumerID][streamID] = streamInfo - pc.activeConnHeartBeatRequestSubjects[consumerID] = heartBeatRequestSub - + pc.activeConsumers[consumerID].ActiveStreamInfo[streamID] = streamInfo return nil } -func (pc *ProducerClient) RemoveStream(consumerID string, streamID string) { +func (pc *ProducerClient) RemoveStream(consumerID string, streamID string) error { pc.mu.Lock() defer pc.mu.Unlock() - activeStreamIdsForConn, ok := pc.activeStreamInfo[consumerID] + consumer, ok := pc.activeConsumers[consumerID] if !ok { - return + return fmt.Errorf("consumer %s not found", consumerID) + } + + activeStreamIdsForConn := consumer.ActiveStreamInfo + if activeStreamIdsForConn == nil { + return fmt.Errorf("active stream Ids for consumer %s is nil", consumerID) } if _, ok := activeStreamIdsForConn[streamID]; !ok { - return + return fmt.Errorf("no stream with id %s found for consumer %s", streamID, consumerID) } delete(activeStreamIdsForConn, streamID) if len(activeStreamIdsForConn) == 0 { - delete(pc.activeStreamInfo, consumerID) - delete(pc.activeConnHeartBeatRequestSubjects, consumerID) + delete(pc.activeConsumers, consumerID) } + + return nil } func (pc *ProducerClient) heartBeat(ctx context.Context) { @@ -110,18 +135,9 @@ func (pc *ProducerClient) heartBeat(ctx context.Context) { nonActiveStreamIds := make(map[string][]string) pc.mu.RLock() - for c, v := range pc.activeConnHeartBeatRequestSubjects { - // Create an empty slice for activeStreamIds - activeStreamIds := make(map[string][]string) - - if streamInfoMap, ok := pc.activeStreamInfo[c]; ok { - for _, streamInfo := range streamInfoMap { - activeStreamIds[streamInfo.RequestSub] = append(activeStreamIds[streamInfo.RequestSub], streamInfo.ID) - } - } - + for c, v := range pc.activeConsumers { heartBeatRequest := HeartBeatRequest{ - ActiveStreamIds: activeStreamIds, + ActiveStreamIds: v.getActiveStreamIdsByRequestSubject(), } data, err := json.Marshal(heartBeatRequest) @@ -130,9 +146,10 @@ func (pc *ProducerClient) heartBeat(ctx context.Context) { continue } - msg, err := pc.Conn.Request(v, data, pc.config.HeartBeatRequestTimeout) + msg, err := pc.Conn.Request(v.HeartbeatRequestSub, data, pc.config.HeartBeatRequestTimeout) if err != nil { - log.Ctx(ctx).Err(err).Msg("error while sending heart beat request from NATS streaming producer client") + log.Ctx(ctx).Err(err).Msg("heartbeat request to consumer client timed out") + nonActiveStreamIds[c] = v.getActiveStreamIds() continue } @@ -165,17 +182,17 @@ func (pc *ProducerClient) updateActiveStreamInfo(nonActiveStreamIds map[string][ nonActiveMap[id] = true } - if streamInfo, ok := pc.activeStreamInfo[connID]; ok { - for streamID := range streamInfo { + if consumer, ok := pc.activeConsumers[connID]; ok { + for streamID := range consumer.ActiveStreamInfo { if nonActiveMap[streamID] { - streamInfo := pc.activeStreamInfo[connID][streamID] + streamInfo := consumer.ActiveStreamInfo[streamID] streamInfo.Cancel() - delete(pc.activeStreamInfo[connID], streamID) + delete(consumer.ActiveStreamInfo, streamID) } } // If after deletion, there's no stream left for this connection, delete the connection - if len(pc.activeStreamInfo[connID]) == 0 { - delete(pc.activeStreamInfo, connID) + if len(consumer.ActiveStreamInfo) == 0 { + delete(pc.activeConsumers, connID) } } } diff --git a/pkg/nats/stream/streaming_client_interactions_test.go b/pkg/nats/stream/streaming_client_interactions_test.go new file mode 100644 index 0000000000..eac5931613 --- /dev/null +++ b/pkg/nats/stream/streaming_client_interactions_test.go @@ -0,0 +1,185 @@ +//go:build unit || !integration + +package stream + +import ( + "context" + "encoding/json" + "github.com/bacalhau-project/bacalhau/pkg/lib/network" + nats_helper "github.com/bacalhau-project/bacalhau/pkg/nats" + "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/suite" + "testing" + "time" +) + +const subjectName = "topic.stream" +const testString = "Hello from bacalhau" + +type StreamingClientInteractionTestSuite struct { + suite.Suite + + ctx context.Context + natServer *server.Server + pc *ProducerClient + cc *ConsumerClient +} + +type testData struct { + contextCancelled bool + streamReplySub string + heartBeatRequestSub string +} + +func (s *StreamingClientInteractionTestSuite) SetupSuite() { + s.ctx = context.Background() + s.natServer = s.createNatsServer() + s.pc = s.createProducerClient() + s.cc = s.createConsumerClient() + +} + +func (s *StreamingClientInteractionTestSuite) TearDownSuite() { + s.cc.Conn.Close() + s.pc.Conn.Close() + s.natServer.Shutdown() +} + +func (s *StreamingClientInteractionTestSuite) createNatsServer() *server.Server { + ctx := context.Background() + port, err := network.GetFreePort() + s.Require().NoError(err) + + serverOpts := server.Options{ + Port: port, + } + + ns, err := nats_helper.NewServerManager(ctx, nats_helper.ServerManagerParams{ + Options: &serverOpts, + }) + s.Require().NoError(err) + return ns.Server +} + +func (s *StreamingClientInteractionTestSuite) createProducerClient() *ProducerClient { + clientManager, err := nats_helper.NewClientManager(s.ctx, s.natServer.ClientURL(), nats.Name("streaming-test")) + s.Require().NoError(err) + + pc, err := NewProducerClient(s.ctx, ProducerClientParams{ + Conn: clientManager.Client, + Config: StreamProducerClientConfig{ + HeartBeatIntervalDuration: 100 * time.Millisecond, + HeartBeatRequestTimeout: 50 * time.Millisecond, + StreamCancellationBufferDuration: 100 * time.Millisecond, + }, + }) + + s.Require().NoError(err) + return pc +} + +func (s *StreamingClientInteractionTestSuite) createConsumerClient() *ConsumerClient { + + clientManager, err := nats_helper.NewClientManager(s.ctx, s.natServer.ClientURL(), nats.Name("streaming-test")) + s.Require().NoError(err) + + cc, err := NewConsumerClient(ConsumerClientParams{ + Conn: clientManager.Client, + Config: StreamConsumerClientConfig{ + StreamCancellationBufferDuration: 50 * time.Millisecond, + }, + }) + + s.Require().NoError(err) + return cc +} + +func TestStreamingClientTestSuit(t *testing.T) { + suite.Run(t, new(StreamingClientInteractionTestSuite)) +} + +func (s *StreamingClientInteractionTestSuite) TestStreamConsumerClientGoingDown() { + + // Set up for the test + td := &testData{} + clientManager, err := nats_helper.NewClientManager(s.ctx, s.natServer.ClientURL(), nats.Name("stream-testing-consumer-going-down")) + s.Require().NoError(err) + + // Produce some data once asked for + ctx, cancel := context.WithCancel(s.ctx) + _, err = clientManager.Client.Subscribe(subjectName, func(msg *nats.Msg) { + + s.Require().NotNil(msg) + + var streamRequest Request + err := json.Unmarshal(msg.Data, &streamRequest) + s.Require().NoError(err) + + err = s.pc.AddStream( + streamRequest.ConsumerID, + streamRequest.StreamID, + msg.Subject, + streamRequest.HeartBeatRequestSub, + cancel, + ) + s.Require().NoError(err) + + td.heartBeatRequestSub = streamRequest.HeartBeatRequestSub + go func() { + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + s.Require().NotNil(ctx.Err()) + td.contextCancelled = true + return + case <-ticker.C: + data, err := json.Marshal(testString) + s.Require().NoError(err) + + sMsg := StreamingMsg{ + Type: 1, + Data: data, + } + + sMsgData, err := json.Marshal(sMsg) + s.Require().NoError(err) + + clientManager.Client.Publish(msg.Reply, sMsgData) + } + } + + }() + }) + s.Require().NoError(err) + data, err := json.Marshal(testString) + s.Require().NoError(err) + + _, err = s.cc.OpenStream(s.ctx, subjectName, data) + s.Require().NoError(err) + + // Close the Consumer Client After Certain Time + go func() { + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.cc.Conn.Close() + return + } + } + + }() + + // Validate that producer client does the cleanup + s.Eventually(func() bool { + return td.contextCancelled + }, 1800*time.Millisecond, 100*time.Millisecond) + +}