From 69e59822f09d1d958fb190f59191b73441f59ecd Mon Sep 17 00:00:00 2001 From: siddharth ravikumar Date: Mon, 26 Dec 2022 16:40:40 -0500 Subject: cache: make it concurrency safe --- cache/cache.go | 15 +++++++++++++++ cache/cache_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/cache/cache.go b/cache/cache.go index 8f29ee4..8d5e8f8 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -14,12 +14,14 @@ type item struct { // A key-value cache store. type Cache struct { + sema chan int // Semaphore for read/write access to cache. store map[string]item } // Returns a new empty cache store. func NewCache() *Cache { c := new(Cache) + c.sema = make(chan int, 1) c.store = make(map[string]item) return c } @@ -30,6 +32,12 @@ func NewCache() *Cache { // Cache.Get will return an empty string once `expires` is past the // current time. func (c *Cache) Set(key string, value []byte, expires time.Time) { + // Get sema token before accessing the cache. + c.sema <- 1 + defer func() { + // Give up sema token. + <-c.sema + }() c.store[key] = item{ value: value, expires: expires, @@ -41,6 +49,13 @@ func (c *Cache) Set(key string, value []byte, expires time.Time) { // An empty []byte will be returned when if the key does not exist or // if the item corresponding to the key has expired. func (c *Cache) Get(key string) []byte { + // Get sema token before accessing the cache. + c.sema <- 1 + defer func() { + // Give up sema token. + <-c.sema + }() + if _, ok := c.store[key]; !ok { return []byte{} } diff --git a/cache/cache_test.go b/cache/cache_test.go index 4751260..c663abf 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -5,6 +5,7 @@ package cache import ( "bytes" + "fmt" "testing" "time" ) @@ -71,3 +72,45 @@ func TestCacheGet(t *testing.T) { return } } + +func TestConcurrentSets(t *testing.T) { + // Expiration time for all keys. + exp := time.Now().Add(time.Second * 120) + + // Generate some keys. + keys := make([]string, 0) + maxKeys := 1000 + for i := 0; i < maxKeys; i++ { + keys = append(keys, fmt.Sprintf("key-%d", i)) + } + + // Go routing for adding keys to cache. + addToCache := func(c *Cache, keys []string, donec chan int) { + for i := 0; i < len(keys); i++ { + c.Set(keys[i], []byte(fmt.Sprintf("val-%d", i)), exp) + } + donec <- 1 + } + + // Init. cache. + c := NewCache() + if c == nil { + t.Errorf("cache is nil") + return + } + donec := make(chan int) + + // Add keys to cache concurrently. + go addToCache(c, keys, donec) + go addToCache(c, keys, donec) + go addToCache(c, keys, donec) + completed := 0 + for completed < 3 { + <-donec + completed += 1 + } + + if len(c.store) != maxKeys { + t.Errorf("number of keys in store != %d: %v", maxKeys, c.store) + } +} -- cgit v1.2.3