summaryrefslogtreecommitdiffstats
path: root/db
diff options
context:
space:
mode:
Diffstat (limited to 'db')
-rw-r--r--db/db.go36
-rw-r--r--db/db_test.go47
2 files changed, 74 insertions, 9 deletions
diff --git a/db/db.go b/db/db.go
index c13f6d1..c99bbea 100644
--- a/db/db.go
+++ b/db/db.go
@@ -21,7 +21,8 @@ var defaultDBPath string
//
// It's stored on disk as a JSON at `$HOME/.config/fern/db.json
type FernDB struct {
- mutex *sync.Mutex // For writes to `downloaded`
+ // For locking concurrent read/write access downloaded.
+ mutex *sync.RWMutex
// Key: feed-id
// Value: feed-id's entries that were downloaded
downloaded map[string][]string
@@ -54,7 +55,7 @@ func Open() (*FernDB, error) {
if err != nil {
// db does not exist yet; create an empty one.
db := new(FernDB)
- db.mutex = new(sync.Mutex)
+ db.mutex = new(sync.RWMutex)
db.downloaded = make(map[string][]string)
return db, nil
}
@@ -71,7 +72,7 @@ func Open() (*FernDB, error) {
// Unmarshal db into an object.
db := new(FernDB)
- db.mutex = new(sync.Mutex)
+ db.mutex = new(sync.RWMutex)
err = json.Unmarshal(bs, &db.downloaded)
if err != nil {
return nil, err
@@ -79,9 +80,10 @@ func Open() (*FernDB, error) {
return db, nil
}
-// Returns true if an `entry` for `feed` exists in the database; false
-// otherwise.
-func (fdb *FernDB) Exists(feed, entry string) bool {
+// Checks if entry exists in feed. Assumes the current go routine
+// already has the mutex lock. Meant for use by the Exists and Add
+// methods.
+func (fdb *FernDB) exists(feed, entry string) bool {
if _, ok := fdb.downloaded[feed]; !ok {
return false
}
@@ -91,7 +93,16 @@ func (fdb *FernDB) Exists(feed, entry string) bool {
}
}
return false
+}
+
+// Returns true if an `entry` for `feed` exists in the database; false
+// otherwise.
+func (fdb *FernDB) Exists(feed, entry string) bool {
+ // Acquire read lock.
+ fdb.mutex.RLock()
+ defer fdb.mutex.RUnlock() // Give up lock before returning.
+ return fdb.exists(feed, entry)
}
// Adds `feed` <-> `entry` to the database.
@@ -100,24 +111,31 @@ func (fdb *FernDB) Exists(feed, entry string) bool {
// that entry was downloaded and will not try downloading the entry
// again.
func (fdb *FernDB) Add(feed, entry string) {
+ // Acquire write lock.
+ fdb.mutex.Lock()
+ defer fdb.mutex.Unlock() // Give up lock before returning.
+
// Check if entry already exist for feed.
- if fdb.Exists(feed, entry) {
+ if fdb.exists(feed, entry) {
return
}
// Add entry.
- fdb.mutex.Lock()
if _, ok := fdb.downloaded[feed]; !ok {
fdb.downloaded[feed] = make([]string, 0)
}
fdb.downloaded[feed] = append(fdb.downloaded[feed], entry)
- fdb.mutex.Unlock()
+
}
// Writes FernDB to disk in the JSON format.
//
// Returns nil on success; error otherwise
func (fdb *FernDB) Write() error {
+ // Acquire write lock.
+ fdb.mutex.Lock()
+ defer fdb.mutex.Unlock() // Give up lock before returning.
+
if len(dbPath) == 0 {
return fmt.Errorf("FernDB path not set")
}
diff --git a/db/db_test.go b/db/db_test.go
index de27149..966ce85 100644
--- a/db/db_test.go
+++ b/db/db_test.go
@@ -4,6 +4,7 @@
package db
import (
+ "fmt"
"os"
"path"
"testing"
@@ -370,3 +371,49 @@ func TestWriteExistingDB(t *testing.T) {
return
}
}
+
+func TestConcurrentWrites(t *testing.T) {
+ dbPath = path.Join(os.TempDir(), "fern-db.json")
+ defer os.Remove(dbPath)
+ defer resetDBPath()
+
+ db, err := Open()
+ if err != nil {
+ t.Errorf("db open failed: %v", err)
+ return
+ }
+
+ // Randomly create a some entries.
+ numEntries := 1000
+ entries := make([]string, 0)
+ for i := 0; i < numEntries; i++ {
+ entries = append(entries, fmt.Sprintf("entry-%d", i))
+ }
+
+ // Go routine for adding entries to the db.
+ addEntries := func(db *FernDB, feed string, entries []string, donec chan int) {
+ for _, entry := range entries {
+ db.Add(feed, entry)
+ }
+ donec <- 1
+ }
+
+ // Concurrently write entries to a feed.
+ donec := make(chan int)
+ feed := "npr"
+ routines := 5
+ for i := 0; i < routines; i++ {
+ go addEntries(db, feed, entries, donec)
+ }
+ routinesDone := 0
+ for routinesDone != routines {
+ <-donec
+ routinesDone += 1
+ }
+
+ // Check if there are exactly numEntries entries.
+ if len(db.downloaded[feed]) != numEntries {
+ t.Errorf("downloaded entries != %d: %v",
+ numEntries, db.downloaded[feed])
+ }
+}