diff --git a/go.mod b/go.mod
index f80b2bfb..fe4cc2e4 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,8 @@
 module github.com/xtaci/kcp-go/v5
 
 require (
+	github.com/OneOfOne/xxhash v1.2.2
+	github.com/cespare/xxhash v1.1.0 // indirect
 	github.com/klauspost/cpuid v1.3.1 // indirect
 	github.com/klauspost/reedsolomon v1.9.9
 	github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104 // indirect
diff --git a/go.sum b/go.sum
index ebb0cc1f..f5bdb3de 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,7 @@
+github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE=
+github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
+github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
+github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
 github.com/klauspost/cpuid v1.2.4/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
 github.com/klauspost/cpuid v1.3.1 h1:5JNjFYYQrZeKRJ0734q51WCEEn2huer72Dc7K+R/b6s=
 github.com/klauspost/cpuid v1.3.1/go.mod h1:bYW4mA6ZgKPob1/Dlai2LviZJO7KGI3uoWLd42rAQw4=
@@ -7,6 +11,8 @@ github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104 h1:ULR/QWMgcgRiZLU
 github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104/go.mod h1:wqKykBG2QzQDJEzvRkcS8x6MiSJkF52hXZsXcjaB3ls=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=
+github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
 github.com/templexxx/cpu v0.0.1 h1:hY4WdLOgKdc8y13EYklu9OUTXik80BkxHoWvTO6MQQY=
 github.com/templexxx/cpu v0.0.1/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk=
 github.com/templexxx/cpu v0.0.7 h1:pUEZn8JBy/w5yzdYWgx+0m0xL9uk6j4K91C5kOViAzo=
diff --git a/map.go b/map.go
new file mode 100644
index 00000000..38521ed3
--- /dev/null
+++ b/map.go
@@ -0,0 +1,130 @@
+package kcp
+
+import (
+	"sync"
+	"sync/atomic"
+
+	"github.com/OneOfOne/xxhash"
+)
+
+const (
+	segmentCount = 256
+	segmentMask  = 255
+)
+
+// ConcurrentMap is a string-keyed map sharded into segmentCount
+// buckets, each guarded by its own RWMutex, to reduce lock
+// contention under heavy concurrent access.
+type ConcurrentMap struct {
+	segments [segmentCount]map[string]interface{}
+	locks    [segmentCount]sync.RWMutex
+	length   int64 // total entry count, maintained atomically
+}
+
+// NewConcurrentMap returns an empty ConcurrentMap ready for use.
+func NewConcurrentMap() *ConcurrentMap {
+	return &ConcurrentMap{}
+}
+
+// segmentIndex returns the shard index for key.
+func segmentIndex(key string) uint32 {
+	return xxhash.Checksum32([]byte(key)) & segmentMask
+}
+
+// Put stores val under key, allocating the shard lazily.
+func (m *ConcurrentMap) Put(key string, val interface{}) {
+	idx := segmentIndex(key)
+	m.locks[idx].Lock()
+	defer m.locks[idx].Unlock()
+	seg := m.segments[idx]
+	if seg == nil {
+		seg = make(map[string]interface{})
+		m.segments[idx] = seg
+	}
+	if _, ok := seg[key]; !ok {
+		atomic.AddInt64(&m.length, 1)
+	}
+	seg[key] = val
+}
+
+// Get returns the value stored under key and whether it was present.
+func (m *ConcurrentMap) Get(key string) (interface{}, bool) {
+	idx := segmentIndex(key)
+	m.locks[idx].RLock()
+	defer m.locks[idx].RUnlock()
+	val, ok := m.segments[idx][key] // lookup on a nil shard safely yields the zero value
+	return val, ok
+}
+
+// Del removes key and reports whether it was present.
+func (m *ConcurrentMap) Del(key string) bool {
+	idx := segmentIndex(key)
+	m.locks[idx].Lock()
+	defer m.locks[idx].Unlock()
+	seg := m.segments[idx]
+	if _, ok := seg[key]; !ok {
+		return false
+	}
+	delete(seg, key)
+	atomic.AddInt64(&m.length, -1)
+	return true
+}
+
+// IterFun is the callback used by Iterate and Filter; returning
+// false stops an Iterate early, or marks an entry for removal in
+// Filter.
+type IterFun func(key string, val interface{}) bool
+
+// Iterate calls f on every entry until f returns false. Only one
+// shard is read-locked at a time, so entries added or removed in
+// other shards during iteration may or may not be observed.
+func (m *ConcurrentMap) Iterate(f IterFun) {
+	for i := 0; i < segmentCount; i++ {
+		if !m.iterateSegment(i, f) {
+			return
+		}
+	}
+}
+
+// iterateSegment walks one shard under its read lock; it reports
+// false when f asked to stop.
+func (m *ConcurrentMap) iterateSegment(i int, f IterFun) bool {
+	m.locks[i].RLock()
+	defer m.locks[i].RUnlock()
+	for k, v := range m.segments[i] {
+		if !f(k, v) {
+			return false
+		}
+	}
+	return true
+}
+
+// Length returns the current number of entries.
+func (m *ConcurrentMap) Length() int {
+	return int(atomic.LoadInt64(&m.length))
+}
+
+// Filter removes every entry for which f returns false.
+func (m *ConcurrentMap) Filter(f IterFun) {
+	for i := 0; i < segmentCount; i++ {
+		m.filterSegment(i, f)
+	}
+}
+
+// filterSegment removes rejected entries from one shard under its
+// write lock, keeping the shared length counter in sync.
+func (m *ConcurrentMap) filterSegment(i int, f IterFun) {
+	m.locks[i].Lock()
+	defer m.locks[i].Unlock()
+	seg := m.segments[i]
+	var doomed []string
+	for k, v := range seg {
+		if !f(k, v) {
+			doomed = append(doomed, k)
+		}
+	}
+	for _, k := range doomed {
+		delete(seg, k)
+	}
+	atomic.AddInt64(&m.length, -int64(len(doomed)))
+}
diff --git a/map_test.go b/map_test.go
new file mode 100644
index 00000000..5fae10b2
--- /dev/null
+++ b/map_test.go
@@ -0,0 +1,246 @@
+package kcp
+
+import (
+	"math/rand"
+	"strconv"
+	"sync"
+	"testing"
+)
+
+func BenchmarkConcurrentMap_Write(b *testing.B) {
+	dataLen := b.N
+	concurrency := 8
+	workLoad := dataLen / concurrency
+	newMap := NewConcurrentMap()
+	wait := &sync.WaitGroup{}
+
+	b.ResetTimer()
+	for i := 0; i < concurrency; i++ {
+		wait.Add(1)
+		go func(m *ConcurrentMap, start, end int, wait *sync.WaitGroup) {
+			for i := start; i < end; i++ {
+				m.Put(strconv.Itoa(i), i)
+			}
+			wait.Done()
+		}(newMap, i*workLoad, (i+1)*workLoad, wait)
+	}
+	wait.Wait()
+}
+
+func BenchmarkSyncMap_Write(b *testing.B) {
+	oldMap := &sync.Map{}
+	dataLen := b.N
+	concurrency := 8
+	workLoad := dataLen / concurrency
+	wait := &sync.WaitGroup{}
+	b.ResetTimer()
+	for i := 0; i < concurrency; i++ {
+		wait.Add(1)
+		go func(m *sync.Map, start, end int, wait *sync.WaitGroup) {
+			for i := start; i < end; i++ {
+				m.Store(strconv.Itoa(i), i)
+			}
+			wait.Done()
+		}(oldMap, i*workLoad, (i+1)*workLoad, wait)
+	}
+	wait.Wait()
+}
+
+func BenchmarkConcurrentMap_ReadWrite(b *testing.B) {
+	m := NewConcurrentMap()
+	dataLen := 65536
+	for i := 0; i < dataLen; i++ {
+		m.Put(strconv.Itoa(i), i)
+	}
+	concurrency := 8
+	wait := &sync.WaitGroup{}
+
+	b.ResetTimer()
+	for i := 0; i < concurrency; i++ {
+		wait.Add(1)
+		go func(w *sync.WaitGroup, m *ConcurrentMap) {
+			for j := 0; j < b.N; j++ {
+				l := rand.Intn(dataLen)
+				if l%2 == 0 {
+					m.Get(strconv.Itoa(l))
+				} else {
+					m.Put(strconv.Itoa(l), l+1)
+				}
+			}
+			w.Done()
+		}(wait, m)
+	}
+	wait.Wait()
+}
+
+func BenchmarkSyncMap_ReadWrite(b *testing.B) {
+	m := &sync.Map{}
+	dataLen := 65536
+	for i := 0; i < dataLen; i++ {
+		m.Store(strconv.Itoa(i), i)
+	}
+	concurrency := 8
+	wait := &sync.WaitGroup{}
+
+	b.ResetTimer()
+	for i := 0; i < concurrency; i++ {
+		wait.Add(1)
+		go func(w *sync.WaitGroup, m *sync.Map) {
+			for j := 0; j < b.N; j++ {
+				l := rand.Intn(dataLen)
+				if l%2 == 0 {
+					m.Load(strconv.Itoa(l))
+				} else {
+					m.Store(strconv.Itoa(l), l+1)
+				}
+			}
+			w.Done()
+		}(wait, m)
+	}
+	wait.Wait()
+}
+
+func BenchmarkMapMutex_ReadWrite(b *testing.B) {
+	m := make(map[string]int)
+	lock := sync.Mutex{}
+	dataLen := 65536
+	for i := 0; i < dataLen; i++ {
+		m[strconv.Itoa(i)] = i
+	}
+	concurrency := 8
+	data := make([]string, dataLen)
+	for i := 0; i < dataLen; i++ {
+		data[i] = strconv.Itoa(i)
+	}
+	wait := &sync.WaitGroup{}
+
+	b.ResetTimer()
+	for i := 0; i < concurrency; i++ {
+		wait.Add(1)
+		go func(w *sync.WaitGroup, m map[string]int) {
+			for j := 0; j < b.N; j++ {
+				l := rand.Intn(dataLen)
+				lock.Lock()
+				if l%2 == 0 {
+					_ = m[strconv.Itoa(l)]
+				} else {
+					m[strconv.Itoa(l)] = l + 1
+				}
+				lock.Unlock()
+			}
+			w.Done()
+		}(wait, m)
+	}
+	wait.Wait()
+}
+
+func TestConcurrentMap_DelAndLength(t *testing.T) {
+	m := NewConcurrentMap()
+	w := &sync.WaitGroup{}
+	l := 1000
+	c := 8
+	for i := 0; i < c; i++ {
+		w.Add(1)
+		go func(start int, wait *sync.WaitGroup) {
+			for j := start * l; j < (start+1)*l; j++ {
+				m.Put(strconv.Itoa(j), j)
+			}
+			wait.Done()
+		}(i, w)
+	}
+	w.Wait()
+
+	if m.Length() != l*c {
+		t.Errorf("Length not expected. Length returns %d, expected %d", m.Length(), l*c)
+	}
+
+	for i := 0; i < c; i++ {
+		w.Add(1)
+		go func(start int, wait *sync.WaitGroup) {
+			for j := start * l; j < (start+1)*l; j++ {
+				m.Del(strconv.Itoa(j))
+			}
+			wait.Done()
+		}(i, w)
+	}
+	w.Wait()
+	if m.Length() != 0 {
+		t.Errorf("Length not expected. Length returns %d, expected %d", m.Length(), 0)
+	}
+}
+
+func TestConcurrentMap_PutAndGet(t *testing.T) {
+	m := NewConcurrentMap()
+	w := &sync.WaitGroup{}
+	l := 10000
+	c := 8
+	workLoad := l / c
+
+	for i := 0; i < c; i++ {
+		w.Add(1)
+		go func(w *sync.WaitGroup, start, end int) {
+			for j := start; j < end; j++ {
+				m.Put(strconv.Itoa(j), j)
+			}
+			w.Done()
+		}(w, i*workLoad, (i+1)*workLoad)
+	}
+	w.Wait()
+
+	for i := 0; i < c; i++ {
+		w.Add(1)
+		go func(w *sync.WaitGroup, start, end int) {
+			for j := start; j < end; j++ {
+				tmp, ok := m.Get(strconv.Itoa(j))
+				if !ok || tmp.(int) != j {
+					t.Errorf("Get not expected. Get return: %v, expected:%d", tmp, j)
+				}
+			}
+			w.Done()
+		}(w, i*workLoad, (i+1)*workLoad)
+	}
+	w.Wait()
+}
+
+func TestConcurrentMap_Iterate(t *testing.T) {
+	m := NewConcurrentMap()
+	l := 10000
+	type Value struct {
+		Num int
+	}
+	for i := 0; i < l; i++ {
+		m.Put(strconv.Itoa(i), &Value{Num: i})
+	}
+	m.Iterate(func(key string, val interface{}) bool {
+		val.(*Value).Num++
+		return true
+	})
+	for i := 0; i < l; i++ {
+		tmp, _ := m.Get(strconv.Itoa(i))
+		if tmp.(*Value).Num != i+1 {
+			t.Errorf("Iterate not expected.")
+			break
+		}
+	}
+}
+
+func TestConcurrentMap_Filter(t *testing.T) {
+	m := NewConcurrentMap()
+	l := 10000
+	type Value struct {
+		Num int
+	}
+	for i := 0; i < l; i++ {
+		m.Put(strconv.Itoa(i), &Value{Num: i})
+	}
+	m.Filter(func(key string, val interface{}) bool {
+		return val.(*Value).Num%2 == 0
+	})
+	for i := 0; i < l; i++ {
+		_, ok := m.Get(strconv.Itoa(i))
+		isEven := i%2 == 0
+		if (isEven && !ok) || (!isEven && ok) {
+			t.Errorf("Filter not expected")
+		}
+	}
+}
diff --git a/sess.go b/sess.go
index 7bce1263..fd11a6e7 100644
--- a/sess.go
+++ b/sess.go
@@ -775,8 +775,7 @@ type (
 		conn    net.PacketConn // the underlying packet connection
 		ownConn bool           // true if we created conn internally, false if provided by caller
 
-		sessions    map[string]*UDPSession // all sessions accepted by this Listener
-		sessionLock sync.Mutex
+		sessions *ConcurrentMap // all sessions accepted by this Listener
 		chAccepts       chan *UDPSession // Listen() backlog
 		chSessionClosed chan net.Addr    // session close queue
 		headerSize      int              // the additional header to a KCP frame
@@ -811,10 +810,8 @@ func (l *Listener) packetInput(data []byte, addr net.Addr) {
 	}
 
 	if dataValid {
-		l.sessionLock.Lock()
-		s, ok := l.sessions[addr.String()]
-		l.sessionLock.Unlock()
-
+		tmp, ok := l.sessions.Get(addr.String())
+		s, _ := tmp.(*UDPSession) // s stays nil when the key is absent, matching the old map lookup
 		var conv, sn uint32
 		convValid := false
 		if l.fecDecoder != nil {
@@ -843,9 +840,7 @@ func (l *Listener) packetInput(data []byte, addr net.Addr) {
 			if len(l.chAccepts) < cap(l.chAccepts) { // do not let the new sessions overwhelm accept queue
 				s := newUDPSession(conv, l.dataShards, l.parityShards, l, l.conn, false, addr, l.block)
 				s.kcpInput(data)
-				l.sessionLock.Lock()
-				l.sessions[addr.String()] = s
-				l.sessionLock.Unlock()
+				l.sessions.Put(addr.String(), s)
 				l.chAccepts <- s
 			}
 		}
@@ -858,11 +853,10 @@ func (l *Listener) notifyReadError(err error) {
 		close(l.chSocketReadError)
 
 		// propagate read error to all sessions
-		l.sessionLock.Lock()
-		for _, s := range l.sessions {
-			s.notifyReadError(err)
-		}
-		l.sessionLock.Unlock()
+		l.sessions.Iterate(func(key string, val interface{}) bool {
+			val.(*UDPSession).notifyReadError(err)
+			return true
+		})
 	})
 }
@@ -969,12 +963,8 @@ func (l *Listener) Close() error {
 
 // closeSession notify the listener that a session has closed
 func (l *Listener) closeSession(remote net.Addr) (ret bool) {
-	l.sessionLock.Lock()
-	defer l.sessionLock.Unlock()
-	if _, ok := l.sessions[remote.String()]; ok {
-		delete(l.sessions, remote.String())
-		return true
-	}
-	return false
+	// Del reports whether the session existed, preserving the
+	// original true-on-removal semantics for callers.
+	return l.sessions.Del(remote.String())
 }
 
@@ -1013,7 +1003,7 @@ func serveConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketCo
 	l := new(Listener)
 	l.conn = conn
 	l.ownConn = ownConn
-	l.sessions = make(map[string]*UDPSession)
+	l.sessions = NewConcurrentMap()
 	l.chAccepts = make(chan *UDPSession, acceptBacklog)
 	l.chSessionClosed = make(chan net.Addr)
 	l.die = make(chan struct{})