diff options
author | siddharth ravikumar <s@ricketyspace.net> | 2022-12-25 12:20:17 -0500 |
---|---|---|
committer | siddharth ravikumar <s@ricketyspace.net> | 2022-12-25 12:20:17 -0500 |
commit | 0aee714f3941173893b42c5f0880b7cb6b592548 (patch) | |
tree | fb249748164e54b42290c5747ef9d48f27758e48 | |
parent | b64cce52d9011d9c85bc68041e998f2a0c513867 (diff) |
db: fix data race
- Use a sync.RWMutex for mutex locks.
- Use sync.RWMutex.RLock when reading from the downloaded map.
- Use sync.RWMutex.Lock when writing to the downloaded map.
-rw-r--r-- | db/db.go | 36 | ||||
-rw-r--r-- | db/db_test.go | 47 |
2 files changed, 74 insertions, 9 deletions
@@ -21,7 +21,8 @@ var defaultDBPath string // // It's stored on disk as a JSON at `$HOME/.config/fern/db.json type FernDB struct { - mutex *sync.Mutex // For writes to `downloaded` + // For locking concurrent read/write access downloaded. + mutex *sync.RWMutex // Key: feed-id // Value: feed-id's entries that were downloaded downloaded map[string][]string @@ -54,7 +55,7 @@ func Open() (*FernDB, error) { if err != nil { // db does not exist yet; create an empty one. db := new(FernDB) - db.mutex = new(sync.Mutex) + db.mutex = new(sync.RWMutex) db.downloaded = make(map[string][]string) return db, nil } @@ -71,7 +72,7 @@ func Open() (*FernDB, error) { // Unmarshal db into an object. db := new(FernDB) - db.mutex = new(sync.Mutex) + db.mutex = new(sync.RWMutex) err = json.Unmarshal(bs, &db.downloaded) if err != nil { return nil, err @@ -79,9 +80,10 @@ func Open() (*FernDB, error) { return db, nil } -// Returns true if an `entry` for `feed` exists in the database; false -// otherwise. -func (fdb *FernDB) Exists(feed, entry string) bool { +// Checks if entry exists in feed. Assumes the current go routine +// already has the mutex lock. Meant for use by the Exists and Add +// methods. +func (fdb *FernDB) exists(feed, entry string) bool { if _, ok := fdb.downloaded[feed]; !ok { return false } @@ -91,7 +93,16 @@ func (fdb *FernDB) Exists(feed, entry string) bool { } } return false +} + +// Returns true if an `entry` for `feed` exists in the database; false +// otherwise. +func (fdb *FernDB) Exists(feed, entry string) bool { + // Acquire read lock. + fdb.mutex.RLock() + defer fdb.mutex.RUnlock() // Give up lock before returning. + return fdb.exists(feed, entry) } // Adds `feed` <-> `entry` to the database. @@ -100,24 +111,31 @@ func (fdb *FernDB) Exists(feed, entry string) bool { // that entry was downloaded and will not try downloading the entry // again. func (fdb *FernDB) Add(feed, entry string) { + // Acquire write lock. + fdb.mutex.Lock() + defer fdb.mutex.Unlock() // Give up lock before returning. + // Check if entry already exist for feed. - if fdb.Exists(feed, entry) { + if fdb.exists(feed, entry) { return } // Add entry. - fdb.mutex.Lock() if _, ok := fdb.downloaded[feed]; !ok { fdb.downloaded[feed] = make([]string, 0) } fdb.downloaded[feed] = append(fdb.downloaded[feed], entry) - fdb.mutex.Unlock() + } // Writes FernDB to disk in the JSON format. // // Returns nil on success; error otherwise func (fdb *FernDB) Write() error { + // Acquire write lock. + fdb.mutex.Lock() + defer fdb.mutex.Unlock() // Give up lock before returning. + if len(dbPath) == 0 { return fmt.Errorf("FernDB path not set") } diff --git a/db/db_test.go b/db/db_test.go index de27149..966ce85 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -4,6 +4,7 @@ package db import ( + "fmt" "os" "path" "testing" @@ -370,3 +371,49 @@ func TestWriteExistingDB(t *testing.T) { return } } + +func TestConcurrentWrites(t *testing.T) { + dbPath = path.Join(os.TempDir(), "fern-db.json") + defer os.Remove(dbPath) + defer resetDBPath() + + db, err := Open() + if err != nil { + t.Errorf("db open failed: %v", err) + return + } + + // Randomly create a some entries. + numEntries := 1000 + entries := make([]string, 0) + for i := 0; i < numEntries; i++ { + entries = append(entries, fmt.Sprintf("entry-%d", i)) + } + + // Go routine for adding entries to the db. + addEntries := func(db *FernDB, feed string, entries []string, donec chan int) { + for _, entry := range entries { + db.Add(feed, entry) + } + donec <- 1 + } + + // Concurrently write entries to a feed. + donec := make(chan int) + feed := "npr" + routines := 5 + for i := 0; i < routines; i++ { + go addEntries(db, feed, entries, donec) + } + routinesDone := 0 + for routinesDone != routines { + <-donec + routinesDone += 1 + } + + // Check if there are exactly numEntries entries. + if len(db.downloaded[feed]) != numEntries { + t.Errorf("downloaded entries != %d: %v", + numEntries, db.downloaded[feed]) + } +} |