diff --git a/server/broadcast_test.go b/server/broadcast_test.go index b558484cb4..bdca214a80 100644 --- a/server/broadcast_test.go +++ b/server/broadcast_test.go @@ -1289,7 +1289,7 @@ func TestRefreshSession(t *testing.T) { defer func() { getOrchestratorInfoRPC = oldGetOrchestratorInfoRPC }() // trigger parse URL error - sess := StubBroadcastSession(string(0x7f)) + sess := StubBroadcastSession(string(rune(0x7f))) newSess, err := refreshSession(sess) assert.Nil(newSess) assert.Error(err) diff --git a/server/mediaserver.go b/server/mediaserver.go index 62ab73486e..5cab5e25ee 100644 --- a/server/mediaserver.go +++ b/server/mediaserver.go @@ -88,10 +88,11 @@ type LivepeerServer struct { // Thread sensitive fields. All accesses to the // following fields should be protected by `connectionLock` - rtmpConnections map[core.ManifestID]*rtmpConnection - lastHLSStreamID core.StreamID - lastManifestID core.ManifestID - connectionLock *sync.RWMutex + rtmpConnections map[core.ManifestID]*rtmpConnection + internalManifests map[core.ManifestID]core.ManifestID + lastHLSStreamID core.StreamID + lastManifestID core.ManifestID + connectionLock *sync.RWMutex } type authWebhookResponse struct { @@ -146,7 +147,8 @@ func NewLivepeerServer(rtmpAddr string, lpNode *core.LivepeerNode, httpIngest bo } server := lpmscore.New(&opts) ls := &LivepeerServer{RTMPSegmenter: server, LPMS: server, LivepeerNode: lpNode, HTTPMux: opts.HttpMux, connectionLock: &sync.RWMutex{}, - rtmpConnections: make(map[core.ManifestID]*rtmpConnection), + rtmpConnections: make(map[core.ManifestID]*rtmpConnection), + internalManifests: make(map[core.ManifestID]core.ManifestID), } if lpNode.NodeType == core.BroadcasterNode && httpIngest { opts.HttpMux.HandleFunc("/live/", ls.HandlePush) @@ -419,7 +421,7 @@ func endRTMPStreamHandler(s *LivepeerServer) func(url *url.URL, rtmpStrm stream. } //Remove RTMP stream - err := removeRTMPStream(s, params.ManifestID) + err := removeRTMPStream(s, params.ManifestID, params.ManifestID) if err != nil { return err } @@ -510,7 +512,7 @@ func (s *LivepeerServer) registerConnection(rtmpStrm stream.RTMPVideoStream) (*r return cxn, nil } -func removeRTMPStream(s *LivepeerServer, mid core.ManifestID) error { +func removeRTMPStream(s *LivepeerServer, mid, extmid core.ManifestID) error { s.connectionLock.Lock() defer s.connectionLock.Unlock() cxn, ok := s.rtmpConnections[mid] @@ -523,6 +525,9 @@ func removeRTMPStream(s *LivepeerServer, mid core.ManifestID) error { cxn.pl.Cleanup() glog.Infof("Ended stream with id=%s", mid) delete(s.rtmpConnections, mid) + if mid != extmid { + delete(s.internalManifests, extmid) + } if monitor.Enabled { monitor.StreamEnded(cxn.nonce) @@ -668,12 +673,15 @@ func (s *LivepeerServer) HandlePush(w http.ResponseWriter, r *http.Request) { http.Error(w, `Bad URL`, http.StatusBadRequest) return } - s.connectionLock.Lock() + s.connectionLock.RLock() + if intmid, exists := s.internalManifests[mid]; exists { + mid = intmid + } cxn, exists := s.rtmpConnections[mid] if exists && cxn != nil { cxn.lastUsed = now } - s.connectionLock.Unlock() + s.connectionLock.RUnlock() // Check for presence and register if a fresh cxn if !exists { @@ -701,7 +709,7 @@ func (s *LivepeerServer) HandlePush(w http.ResponseWriter, r *http.Request) { // Start a watchdog to remove session after a period of inactivity ticker := time.NewTicker(httpPushTimeout) - go func(s *LivepeerServer, mid core.ManifestID) { + go func(s *LivepeerServer, intmid, extmid core.ManifestID) { defer ticker.Stop() for range ticker.C { var lastUsed time.Time @@ -711,11 +719,18 @@ func (s *LivepeerServer) HandlePush(w http.ResponseWriter, r *http.Request) { } s.connectionLock.RUnlock() if time.Since(lastUsed) > httpPushTimeout { - _ = removeRTMPStream(s, mid) + _ = removeRTMPStream(s, intmid, extmid) return } } - }(s, mid) + }(s, cxn.mid, mid) + if cxn.mid != mid { + // AuthWebhook provided different ManifestID + s.connectionLock.Lock() + s.internalManifests[mid] = cxn.mid + s.connectionLock.Unlock() + mid = cxn.mid + } } fname := path.Base(r.URL.Path) @@ -749,11 +764,11 @@ func (s *LivepeerServer) HandlePush(w http.ResponseWriter, r *http.Request) { return case <-tick.Done(): glog.V(common.VERBOSE).Infof("watchdog reset mid=%s seq=%d dur=%v started=%v", mid, seq, duration, now) - s.connectionLock.Lock() + s.connectionLock.RLock() if cxn, exists := s.rtmpConnections[mid]; exists { cxn.lastUsed = time.Now() } - s.connectionLock.Unlock() + s.connectionLock.RUnlock() } } }() diff --git a/server/push_test.go b/server/push_test.go index aa55281e40..10d3db736a 100644 --- a/server/push_test.go +++ b/server/push_test.go @@ -493,14 +493,16 @@ func TestPush_SetVideoProfileFormats(t *testing.T) { assert.Equal(ffmpeg.FormatNone, BroadcastJobVideoProfiles[i].Format) } + hookCalled := 0 // Sanity check that default profile with webhook is copied // Checking since there is special handling for the default set of profiles // within the webhook hander. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth := authWebhookResponse{ManifestID: "web"} + auth := authWebhookResponse{ManifestID: "intweb"} val, err := json.Marshal(auth) assert.Nil(err, "invalid auth webhook response") w.Write(val) + hookCalled++ })) defer ts.Close() oldURL := AuthWebhookURL @@ -512,9 +514,12 @@ func TestPush_SetVideoProfileFormats(t *testing.T) { h.ServeHTTP(w, req) resp = w.Result() defer resp.Body.Close() + assert.Equal(1, hookCalled) assert.Len(s.rtmpConnections, 3) cxn, ok = s.rtmpConnections["web"] + assert.False(ok, "stream should not exist") + cxn, ok = s.rtmpConnections["intweb"] assert.True(ok, "stream did not exist") assert.Equal(ffmpeg.FormatMP4, cxn.profile.Format) assert.Len(cxn.params.Profiles, 2) @@ -523,6 +528,69 @@ func TestPush_SetVideoProfileFormats(t *testing.T) { assert.Equal(ffmpeg.FormatMP4, p.Format) assert.Equal(ffmpeg.FormatNone, BroadcastJobVideoProfiles[i].Format) } + // Server has empty sessions list, so it will return 503 + assert.Equal(503, resp.StatusCode) + + h, r, w = requestSetup(s) + req = httptest.NewRequest("POST", "/live/web/1.mp4", r) + h.ServeHTTP(w, req) + resp = w.Result() + defer resp.Body.Close() + // webhook should not be called again + assert.Equal(1, hookCalled) + + assert.Len(s.rtmpConnections, 3) + cxn, ok = s.rtmpConnections["web"] + assert.False(ok, "stream should not exist") + cxn, ok = s.rtmpConnections["intweb"] + assert.True(ok, "stream did not exist") + assert.Equal(503, resp.StatusCode) +} + +func TestPush_ShouldRemoveSessionAfterTimeoutIfInternalMIDIsUsed(t *testing.T) { + defer goleak.VerifyNone(t, ignoreRoutines()...) + + oldRI := httpPushTimeout + httpPushTimeout = 2 * time.Millisecond + defer func() { httpPushTimeout = oldRI }() + assert := assert.New(t) + s, cancel := setupServerWithCancel() + + hookCalled := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := authWebhookResponse{ManifestID: "intmid"} + val, err := json.Marshal(auth) + assert.Nil(err, "invalid auth webhook response") + w.Write(val) + hookCalled++ + })) + defer ts.Close() + oldURL := AuthWebhookURL + defer func() { AuthWebhookURL = oldURL }() + AuthWebhookURL = ts.URL + + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/live/extmid1/1.ts", nil) + s.HandlePush(w, req) + resp := w.Result() + resp.Body.Close() + assert.Equal(1, hookCalled) + s.connectionLock.Lock() + _, exists := s.rtmpConnections["intmid"] + _, existsExt := s.rtmpConnections["extmid1"] + intmid := s.internalManifests["extmid1"] + s.connectionLock.Unlock() + assert.Equal("intmid", string(intmid)) + assert.True(exists) + assert.False(existsExt) + time.Sleep(50 * time.Millisecond) + s.connectionLock.Lock() + _, exists = s.rtmpConnections["intmid"] + _, extEx := s.internalManifests["extmid1"] + s.connectionLock.Unlock() + cancel() + assert.False(exists) + assert.False(extEx) } func ignoreRoutines() []goleak.Option { @@ -747,7 +815,7 @@ func TestPush_StorageError(t *testing.T) { drivers.NodeStorage = nil req := httptest.NewRequest("POST", "/live/seg.ts", reader) mid := parseManifestID(req.URL.Path) - err := removeRTMPStream(s, mid) + err := removeRTMPStream(s, mid, mid) handler.ServeHTTP(w, req) resp := w.Result()