Skip to content

Commit

Permalink
Add Streaming Client interaction test (#4132)
Browse files Browse the repository at this point in the history
With this change, we are adding tests for interaction between the
producer and consumer client.
The test scenarion being added is where we kill the consumer client and
make sure that producer client kills the stream as well
  • Loading branch information
udsamani committed Jun 27, 2024
1 parent 960ceb1 commit 0eedc64
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 53 deletions.
11 changes: 7 additions & 4 deletions pkg/nats/proxy/compute_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ func handleRequest(msg *nats.Msg, handler *ComputeHandler) {

// processAndRespond processes the request and sends a response.
func processAndRespond[Request, Response any](
ctx context.Context, conn *nats.Conn, msg *nats.Msg, f handlerWithResponse[Request, Response]) {
ctx context.Context, conn *nats.Conn, msg *nats.Msg, f handlerWithResponse[Request, Response],
) {
response, err := processRequest(ctx, msg, f)
if err != nil {
log.Ctx(ctx).Error().Err(err)
Expand All @@ -109,7 +110,8 @@ func processAndRespond[Request, Response any](

// processRequest decodes the request, invokes the handler, and returns the response.
func processRequest[Request, Response any](
ctx context.Context, msg *nats.Msg, f handlerWithResponse[Request, Response]) (*Response, error) {
ctx context.Context, msg *nats.Msg, f handlerWithResponse[Request, Response],
) (*Response, error) {
request := new(Request)
err := json.Unmarshal(msg.Data, request)
if err != nil {
Expand All @@ -135,7 +137,8 @@ func sendResponse[Response any](conn *nats.Conn, reply string, result *concurren
}

func processAndStream[Request, Response any](ctx context.Context, streamingClient *stream.ProducerClient, msg *nats.Msg,
f handlerWithResponse[Request, <-chan *concurrency.AsyncResult[Response]]) {
f handlerWithResponse[Request, <-chan *concurrency.AsyncResult[Response]],
) {
if msg.Reply == "" {
log.Ctx(ctx).Error().Msgf("streaming request on %s has no reply subject", msg.Subject)
return
Expand Down Expand Up @@ -173,7 +176,7 @@ func processAndStream[Request, Response any](ctx context.Context, streamingClien
cancel,
)

defer streamingClient.RemoveStream(streamRequest.ConsumerID, streamRequest.ConsumerID)
defer streamingClient.RemoveStream(streamRequest.ConsumerID, streamRequest.StreamID) //nolint:errcheck
if err != nil {
_ = writer.CloseWithCode(stream.CloseInternalServerErr,
fmt.Sprintf("error in handler %s: %s", reflect.TypeOf(request).Name(), err))
Expand Down
2 changes: 1 addition & 1 deletion pkg/nats/stream/consumer_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ func (nc *ConsumerClient) getNotActiveStreamIds(activeStreamIDsAtProducer map[st
return time.Since(bucket.createdAt) < nc.config.StreamCancellationBufferDuration
})

// If no non recent buckets, means all are active streams
// If no non-recent buckets, means all are active streams
if len(nonRecentBuckets) == 0 {
continue
}
Expand Down
113 changes: 65 additions & 48 deletions pkg/nats/stream/producer_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/nats-io/nats.go"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
)

type ProducerClientParams struct {
Expand All @@ -18,22 +19,39 @@ type ProducerClientParams struct {

type ProducerClient struct {
Conn *nats.Conn
mu sync.RWMutex // Protects access to activeStreamInfo and activeConnHeartBeatRequestSubjects

// A map of ConsumerID to StreamId that are active
activeStreamInfo map[string]map[string]StreamInfo
// A map of ConsumerID to the subject where a heartBeatRequest needs to be sent.
activeConnHeartBeatRequestSubjects map[string]string
heartBeatCancelFunc context.CancelFunc
config StreamProducerClientConfig
mu sync.RWMutex // Protects access to activeConsumers

activeConsumers map[string]consumerInfo
heartBeatCancelFunc context.CancelFunc
config StreamProducerClientConfig
}

type consumerInfo struct {
// Heartbeat request subject to which consumer info subscribes to respond
// with non-active stream ids
HeartbeatRequestSub string
// A map holding information about active streams alive at consumer
ActiveStreamInfo map[string]StreamInfo
}

func (c *consumerInfo) getActiveStreamIds() []string {
return lo.Keys(c.ActiveStreamInfo)
}

func (c *consumerInfo) getActiveStreamIdsByRequestSubject() map[string][]string {
activeStreamIdsByReqSubj := make(map[string][]string)

for streamID, streamInfo := range c.ActiveStreamInfo {
activeStreamIdsByReqSubj[streamInfo.RequestSub] = append(activeStreamIdsByReqSubj[streamInfo.RequestSub], streamID)
}
return activeStreamIdsByReqSubj
}

func NewProducerClient(ctx context.Context, params ProducerClientParams) (*ProducerClient, error) {
nc := &ProducerClient{
Conn: params.Conn,
activeStreamInfo: make(map[string]map[string]StreamInfo),
activeConnHeartBeatRequestSubjects: make(map[string]string),
config: params.Config,
Conn: params.Conn,
activeConsumers: make(map[string]consumerInfo),
config: params.Config,
}

go nc.heartBeat(ctx)
Expand All @@ -51,46 +69,53 @@ func (pc *ProducerClient) AddStream(
pc.mu.Lock()
defer pc.mu.Unlock()

if _, ok := pc.activeConsumers[consumerID]; !ok {
pc.activeConsumers[consumerID] = consumerInfo{
HeartbeatRequestSub: heartBeatRequestSub,
ActiveStreamInfo: make(map[string]StreamInfo),
}
}

if _, ok := pc.activeConsumers[consumerID].ActiveStreamInfo[streamID]; ok {
return fmt.Errorf("cannot create request with same streamId %s again", streamID)
}

streamInfo := StreamInfo{
ID: streamID,
RequestSub: requestSub,
CreatedAt: time.Now(),
Cancel: cancelFunc,
}

if pc.activeStreamInfo[consumerID] == nil {
pc.activeStreamInfo[consumerID] = make(map[string]StreamInfo)
}

if _, ok := pc.activeStreamInfo[consumerID][streamID]; ok {
return fmt.Errorf("cannot create request with same streamId %s again", streamID)
}

pc.activeStreamInfo[consumerID][streamID] = streamInfo
pc.activeConnHeartBeatRequestSubjects[consumerID] = heartBeatRequestSub

pc.activeConsumers[consumerID].ActiveStreamInfo[streamID] = streamInfo
return nil
}

func (pc *ProducerClient) RemoveStream(consumerID string, streamID string) {
func (pc *ProducerClient) RemoveStream(consumerID string, streamID string) error {
pc.mu.Lock()
defer pc.mu.Unlock()

activeStreamIdsForConn, ok := pc.activeStreamInfo[consumerID]
consumer, ok := pc.activeConsumers[consumerID]
if !ok {
return
return fmt.Errorf("consumer %s not found", consumerID)
}

activeStreamIdsForConn := consumer.ActiveStreamInfo
if activeStreamIdsForConn == nil {
return fmt.Errorf("active stream Ids for consumer %s is nil", consumerID)
}

if _, ok := activeStreamIdsForConn[streamID]; !ok {
return
return fmt.Errorf("no stream with id %s found for consumer %s", streamID, consumerID)
}

delete(activeStreamIdsForConn, streamID)

if len(activeStreamIdsForConn) == 0 {
delete(pc.activeStreamInfo, consumerID)
delete(pc.activeConnHeartBeatRequestSubjects, consumerID)
delete(pc.activeConsumers, consumerID)
}

return nil
}

func (pc *ProducerClient) heartBeat(ctx context.Context) {
Expand All @@ -110,18 +135,9 @@ func (pc *ProducerClient) heartBeat(ctx context.Context) {
nonActiveStreamIds := make(map[string][]string)
pc.mu.RLock()

for c, v := range pc.activeConnHeartBeatRequestSubjects {
// Create an empty slice for activeStreamIds
activeStreamIds := make(map[string][]string)

if streamInfoMap, ok := pc.activeStreamInfo[c]; ok {
for _, streamInfo := range streamInfoMap {
activeStreamIds[streamInfo.RequestSub] = append(activeStreamIds[streamInfo.RequestSub], streamInfo.ID)
}
}

for c, v := range pc.activeConsumers {
heartBeatRequest := HeartBeatRequest{
ActiveStreamIds: activeStreamIds,
ActiveStreamIds: v.getActiveStreamIdsByRequestSubject(),
}

data, err := json.Marshal(heartBeatRequest)
Expand All @@ -130,9 +146,10 @@ func (pc *ProducerClient) heartBeat(ctx context.Context) {
continue
}

msg, err := pc.Conn.Request(v, data, pc.config.HeartBeatRequestTimeout)
msg, err := pc.Conn.Request(v.HeartbeatRequestSub, data, pc.config.HeartBeatRequestTimeout)
if err != nil {
log.Ctx(ctx).Err(err).Msg("error while sending heart beat request from NATS streaming producer client")
log.Ctx(ctx).Err(err).Msg("heartbeat request to consumer client timed out")
nonActiveStreamIds[c] = v.getActiveStreamIds()
continue
}

Expand Down Expand Up @@ -165,17 +182,17 @@ func (pc *ProducerClient) updateActiveStreamInfo(nonActiveStreamIds map[string][
nonActiveMap[id] = true
}

if streamInfo, ok := pc.activeStreamInfo[connID]; ok {
for streamID := range streamInfo {
if consumer, ok := pc.activeConsumers[connID]; ok {
for streamID := range consumer.ActiveStreamInfo {
if nonActiveMap[streamID] {
streamInfo := pc.activeStreamInfo[connID][streamID]
streamInfo := consumer.ActiveStreamInfo[streamID]
streamInfo.Cancel()
delete(pc.activeStreamInfo[connID], streamID)
delete(consumer.ActiveStreamInfo, streamID)
}
}
// If after deletion, there's no stream left for this connection, delete the connection
if len(pc.activeStreamInfo[connID]) == 0 {
delete(pc.activeStreamInfo, connID)
if len(consumer.ActiveStreamInfo) == 0 {
delete(pc.activeConsumers, connID)
}
}
}
Expand Down
Loading

0 comments on commit 0eedc64

Please sign in to comment.