summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorsiddharth ravikumar <s@ricketyspace.net>2022-12-26 16:40:40 -0500
committersiddharth ravikumar <s@ricketyspace.net>2022-12-26 16:40:40 -0500
commit69e59822f09d1d958fb190f59191b73441f59ecd (patch)
treea9134dbd1399cd95aa7d07e1b30dc2ae4ebbaf6e
parenta19eb2afbd4c87f7361942eedd7bdc80f3e225cf (diff)
cache: make it concurrency safe
-rw-r--r--cache/cache.go15
-rw-r--r--cache/cache_test.go43
2 files changed, 58 insertions, 0 deletions
diff --git a/cache/cache.go b/cache/cache.go
index 8f29ee4..8d5e8f8 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -14,12 +14,14 @@ type item struct {
// A key-value cache store.
type Cache struct {
+ sema chan int // Semaphore for read/write access to cache.
store map[string]item
}
// Returns a new empty cache store.
func NewCache() *Cache {
c := new(Cache)
+ c.sema = make(chan int, 1)
c.store = make(map[string]item)
return c
}
@@ -30,6 +32,12 @@ func NewCache() *Cache {
// Cache.Get will return an empty string once `expires` is past the
// current time.
func (c *Cache) Set(key string, value []byte, expires time.Time) {
+ // Get sema token before accessing the cache.
+ c.sema <- 1
+ defer func() {
+ // Give up sema token.
+ <-c.sema
+ }()
c.store[key] = item{
value: value,
expires: expires,
@@ -41,6 +49,13 @@ func (c *Cache) Set(key string, value []byte, expires time.Time) {
// An empty []byte will be returned when if the key does not exist or
// if the item corresponding to the key has expired.
func (c *Cache) Get(key string) []byte {
+ // Get sema token before accessing the cache.
+ c.sema <- 1
+ defer func() {
+ // Give up sema token.
+ <-c.sema
+ }()
+
if _, ok := c.store[key]; !ok {
return []byte{}
}
diff --git a/cache/cache_test.go b/cache/cache_test.go
index 4751260..c663abf 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -5,6 +5,7 @@ package cache
import (
"bytes"
+ "fmt"
"testing"
"time"
)
@@ -71,3 +72,45 @@ func TestCacheGet(t *testing.T) {
return
}
}
+
+func TestConcurrentSets(t *testing.T) {
+ // Expiration time for all keys.
+ exp := time.Now().Add(time.Second * 120)
+
+ // Generate some keys.
+ keys := make([]string, 0)
+ maxKeys := 1000
+ for i := 0; i < maxKeys; i++ {
+ keys = append(keys, fmt.Sprintf("key-%d", i))
+ }
+
+ // Go routing for adding keys to cache.
+ addToCache := func(c *Cache, keys []string, donec chan int) {
+ for i := 0; i < len(keys); i++ {
+ c.Set(keys[i], []byte(fmt.Sprintf("val-%d", i)), exp)
+ }
+ donec <- 1
+ }
+
+ // Init. cache.
+ c := NewCache()
+ if c == nil {
+ t.Errorf("cache is nil")
+ return
+ }
+ donec := make(chan int)
+
+ // Add keys to cache concurrently.
+ go addToCache(c, keys, donec)
+ go addToCache(c, keys, donec)
+ go addToCache(c, keys, donec)
+ completed := 0
+ for completed < 3 {
+ <-donec
+ completed += 1
+ }
+
+ if len(c.store) != maxKeys {
+ t.Errorf("number of keys in store != %d: %v", maxKeys, c.store)
+ }
+}