diff --git a/cache/cache.go b/cache/cache.go index 5831f3c..84d99bb 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -1,137 +1,119 @@ package cache import ( + "context" "log" "sync" "time" ) -type item struct { - sync.Mutex - Value any - Duration time.Duration - Expiration int64 - Regenerate func() (any, error) -} +var valueKey int -func (i *item) Expired() bool { - if i.Duration == 0 { - return false +func newContext(value any, lifecycle time.Duration) (ctx context.Context, cancel context.CancelFunc) { + ctx = context.WithValue(context.Background(), &valueKey, value) + if lifecycle > 0 { + ctx, cancel = context.WithTimeout(ctx, lifecycle) } - - return time.Now().UnixNano() > i.Expiration + return } -// Cache is cache struct. -type Cache struct { - cache sync.Map - autoClean bool +type item[T any] struct { + sync.Mutex + context.Context + cancel context.CancelFunc + lifecycle time.Duration + fn func() (T, error) } -// New creates a new cache with auto clean or not. -func New(autoClean bool) *Cache { - c := &Cache{autoClean: autoClean} +func (i *item[T]) value() T { + i.Lock() + defer i.Unlock() + return i.Value(&valueKey).(T) +} - if autoClean { - go c.check() +func (i *item[T]) renew() T { + v, err := i.fn() + if err != nil { + log.Print(err) + v = i.value() } - - return c + i.Lock() + defer i.Unlock() + i.Context, i.cancel = newContext(v, i.lifecycle) + return v } -// Set sets cache value for a key, if f is presented, this value will regenerate when expired. -func (c *Cache) Set(key, value any, d time.Duration, f func() (any, error)) { - c.cache.Store(key, &item{ - Value: value, - Duration: d, - Expiration: time.Now().Add(d).UnixNano(), - Regenerate: f, - }) +// Cache is cache struct. +type Cache[Key, Value any] struct { + cache sync.Map + autoRenew bool } -func (c *Cache) regenerate(i *item) { - i.Expiration = 0 - i.Unlock() - - go func() { - value, err := i.Regenerate() - - i.Lock() - defer i.Unlock() +// New creates a new cache with auto clean or not. +func New[Key, Value any](autoRenew bool) *Cache[Key, Value] { + return &Cache[Key, Value]{autoRenew: autoRenew} +} - if err != nil { - log.Print(err) - } else { - i.Value = value - } - i.Expiration = time.Now().Add(i.Duration).UnixNano() - }() +// Set sets cache value for a key, if fn is presented, this value will regenerate when expired. +func (c *Cache[Key, Value]) Set(key Key, value Value, lifecycle time.Duration, fn func() (Value, error)) { + i := &item[Value]{lifecycle: lifecycle, fn: fn} + i.Context, i.cancel = newContext(value, lifecycle) + if c.autoRenew && lifecycle > 0 { + go func() { + for { + <-i.Done() + if err := i.Err(); err == context.DeadlineExceeded { + if i.fn != nil { + i.renew() + } else { + c.Delete(key) + } + } else { + return + } + } + }() + } + c.cache.Store(key, i) } // Get gets cache value by key and whether value was found. -func (c *Cache) Get(key any) (any, bool) { - value, ok := c.cache.Load(key) +func (c *Cache[Key, Value]) Get(key Key) (Value, bool) { + v, ok := c.cache.Load(key) if !ok { - return nil, false + return *new(Value), false } - - i := value.(*item) - i.Lock() - - if i.Expired() && !c.autoClean { - if i.Regenerate == nil { - c.cache.Delete(key) - i.Unlock() - - return nil, false + if i := v.(*item[Value]); !c.autoRenew && i.Err() == context.DeadlineExceeded { + if i.fn == nil { + c.Delete(key) + return *new(Value), false } - - defer c.regenerate(i) - - return i.Value, true + return i.renew(), true + } else { + return i.value(), true } - - i.Unlock() - - return i.Value, true } // Delete deletes the value for a key. -func (c *Cache) Delete(key any) { - c.cache.Delete(key) +func (c *Cache[Key, Value]) Delete(key Key) { + if v, ok := c.cache.LoadAndDelete(key); ok { + if v, ok := v.(*item[Value]); ok { + if v.cancel != nil { + v.cancel() + } + } + } } // Empty deletes all values in cache. -func (c *Cache) Empty() { - c.cache.Range(func(key, _ any) bool { - c.cache.Delete(key) +func (c *Cache[Key, Value]) Empty() { + c.cache.Range(func(k, v any) bool { + c.cache.Delete(k) + if v, ok := v.(*item[Value]); ok { + if v.cancel != nil { + v.cancel() + } + } return true }) } - -func (c *Cache) check() { - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - - for range ticker.C { - c.cache.Range(func(key, value any) bool { - i := value.(*item) - i.Lock() - - if i.Expired() { - if i.Regenerate == nil { - c.cache.Delete(key) - i.Unlock() - } else { - defer c.regenerate(i) - } - - return true - } - - i.Unlock() - - return true - }) - } -} diff --git a/cache/cache_test.go b/cache/cache_test.go index 2682a2c..f548fd8 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -6,89 +6,67 @@ import ( ) func TestSetGetDelete(t *testing.T) { - cache := New(false) - + cache := New[string, string](false) cache.Set("key", "value", 0, nil) - - value, ok := cache.Get("key") - if !ok { + if value, ok := cache.Get("key"); !ok { t.Fatal("expected ok; got not") - } - if value != "value" { + } else if value != "value" { t.Errorf("expected value; got %q", value) } - cache.Delete("key") - _, ok = cache.Get("key") - if ok { + if _, ok := cache.Get("key"); ok { t.Error("expected not ok; got ok") } } func TestEmpty(t *testing.T) { - cache := New(false) - + cache := New[string, int](false) cache.Set("a", 1, 0, nil) cache.Set("b", 2, 0, nil) cache.Set("c", 3, 0, nil) - for _, i := range []string{"a", "b", "c"} { - _, ok := cache.Get(i) - if !ok { + if _, ok := cache.Get(i); !ok { t.Error("expected ok; got not") } } - cache.Empty() - for _, i := range []string{"a", "b", "c"} { - _, ok := cache.Get(i) - if ok { + if _, ok := cache.Get(i); ok { t.Error("expected not ok; got ok") } } } -func TestAutoCleanRegenerate(t *testing.T) { - cache := New(true) - - done := make(chan bool) - cache.Set("regenerate", "old", 2*time.Second, func() (any, error) { - defer func() { done <- true }() +func TestRenew(t *testing.T) { + cache := New[string, string](true) + expire := make(chan struct{}) + cache.Set("renew", "old", 2*time.Second, func() (string, error) { + defer func() { close(expire) }() return "new", nil }) cache.Set("expire", "value", 1*time.Second, nil) - - value, ok := cache.Get("expire") - if !ok { + if value, ok := cache.Get("expire"); !ok { t.Fatal("expected ok; got not") - } - if expect := "value"; value != expect { + } else if expect := "value"; value != expect { t.Errorf("expected %q; got %q", expect, value) } - - value, ok = cache.Get("regenerate") - if !ok { + if value, ok := cache.Get("renew"); !ok { t.Fatal("expected ok; got not") - } - if expect := "old"; value != expect { + } else if expect := "old"; value != expect { t.Errorf("expected %q; got %q", expect, value) } - ticker := time.NewTicker(4 * time.Second) defer ticker.Stop() - select { - case <-done: + case <-expire: + time.Sleep(100 * time.Millisecond) if _, ok := cache.Get("expire"); ok { t.Error("expected not ok; got ok") } - - value, ok := cache.Get("regenerate") + value, ok := cache.Get("renew") if !ok { t.Fatal("expected ok; got not") - } - if expect := "new"; value != expect { + } else if expect := "new"; value != expect { t.Errorf("expected %q; got %q", expect, value) } case <-ticker.C: diff --git a/counter/listener.go b/counter/listener.go index 616ddc4..eb08185 100644 --- a/counter/listener.go +++ b/counter/listener.go @@ -1,8 +1,6 @@ package counter -import ( - "net" -) +import "net" var ( _ net.Listener = &Listener{} diff --git a/httpsvr/httpsvr.go b/httpsvr/httpsvr.go index d6713e7..4ab2c83 100644 --- a/httpsvr/httpsvr.go +++ b/httpsvr/httpsvr.go @@ -16,7 +16,7 @@ import ( "github.com/sunshineplan/utils/log" ) -var certCache = cache.New(false) +var certCache = cache.New[string, *tls.Certificate](false) var defaultReload = 24 * time.Hour @@ -124,7 +124,7 @@ func (s *Server) run() error { return nil } -func (s *Server) loadCertificate() (any, error) { +func (s *Server) loadCertificate() (*tls.Certificate, error) { cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile) if err != nil { return nil, err @@ -135,7 +135,7 @@ func (s *Server) loadCertificate() (any, error) { func (s *Server) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { v, ok := certCache.Get("cert") if ok { - return v.(*tls.Certificate), nil + return v, nil } cert, err := s.loadCertificate() @@ -148,7 +148,7 @@ func (s *Server) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error } certCache.Set("cert", cert, s.reload, s.loadCertificate) - return cert.(*tls.Certificate), nil + return cert, nil } // Run runs an HTTP server which can be gracefully shut down.