From c1ace352e8ba09ec3c00317e72de6b871a8964ec Mon Sep 17 00:00:00 2001 From: siddharth Date: Sun, 28 Nov 2021 18:32:18 -0500 Subject: db: add FernDB.Write Writes database to disk. --- db/db.go | 25 ++++++++++++ db/db_test.go | 122 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+) diff --git a/db/db.go b/db/db.go index bbe5b36..4ee511b 100644 --- a/db/db.go +++ b/db/db.go @@ -96,3 +96,28 @@ func (fdb *FernDB) Add(feed, entry string) { fdb.downloaded[feed] = append(fdb.downloaded[feed], entry) fdb.mutex.Unlock() } + +func (fdb *FernDB) Write() error { + if len(dbPath) == 0 { + return fmt.Errorf("FernDB path not set") + } + + f, err := os.OpenFile(dbPath, os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + return err + } + defer f.Close() + + // Marshal database into json. + bs, err := json.Marshal(fdb.downloaded) + if err != nil { + return err + } + + // Write to disk. + _, err = f.Write(bs) + if err != nil { + return err + } + return nil +} diff --git a/db/db_test.go b/db/db_test.go index 8fc9fed..de27149 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -248,3 +248,125 @@ func TestAdd(t *testing.T) { return } } + +func TestWriteNewDB(t *testing.T) { + // Set custom path for db. + dbPath = path.Join(os.TempDir(), "fern-db.json") + defer os.Remove(dbPath) + + // Open the db. + db, err := Open() + if err != nil { + t.Errorf("db.Open failed: %v", err.Error()) + return + } + + // Populate db with test data and write to db to disk. + db.Add("npr", "william-prince") + db.Add("npr", "julian-baker") + db.Add("mkbhd", "v-raptor") + db.Write() + + // Read db refreshly from disk and verify the db contents. + db, err = Open() + if err != nil { + t.Errorf("db.Open failed: %v", err.Error()) + return + } + if len(db.downloaded["npr"]) != 2 { + t.Errorf("db.Add failed: expected 2 entries for 'npr'") + return + } + if !db.Exists("npr", "william-prince") { + t.Errorf("db.Add failed: expected %s in 'npr' feed", + "william-prince") + return + } + if !db.Exists("npr", "julian-baker") { + t.Errorf("db.Add failed: expected %s in 'npr' feed", + "julian-baker") + return + } + if len(db.downloaded["mkbhd"]) != 1 { + t.Errorf("db.Add failed: expected 1 entry for 'npr'") + return + } + if !db.Exists("mkbhd", "v-raptor") { + t.Errorf("db.Add failed: expected %s in 'mkbhd' feed", + "v-raptor") + return + } +} + +func TestWriteExistingDB(t *testing.T) { + // Set custom path for db. + dbPath = path.Join(os.TempDir(), "fern-db.json") + defer os.Remove(dbPath) + + // Write a sample test db to fern-db.json + testDBJSON := []byte(`{"npr":["kurt-vile","joy-oladokun"]}`) + dbFile, err := os.Create(dbPath) + defer dbFile.Close() + if err != nil { + t.Errorf("Unable to create fern-db.json: %v", err.Error()) + return + } + n, err := dbFile.Write(testDBJSON) + if len(testDBJSON) != n { + t.Errorf("Write to fern-db.json failed: %v", err.Error()) + return + } + + // Open the db. + db, err := Open() + if err != nil { + t.Errorf("db.Open failed: %v", err.Error()) + return + } + + // Populate db with test data and write to db to disk. + db.Add("npr", "william-prince") + db.Add("npr", "julian-baker") + db.Add("mkbhd", "v-raptor") + db.Write() + + // Read db refreshly from disk and verify the db contents. + db, err = Open() + if err != nil { + t.Errorf("db.Open failed: %v", err.Error()) + return + } + if len(db.downloaded["npr"]) != 4 { + t.Errorf("db.Add failed: expected 2 entries for 'npr'") + return + } + if !db.Exists("npr", "kurt-vile") { + t.Errorf("db.Add failed: expected %s in 'npr' feed", + "kurt-vile") + return + } + if !db.Exists("npr", "joy-oladokun") { + t.Errorf("db.Add failed: expected %s in 'npr' feed", + "joy-oladokun") + return + } + if !db.Exists("npr", "william-prince") { + t.Errorf("db.Add failed: expected %s in 'npr' feed", + "william-prince") + return + } + if !db.Exists("npr", "julian-baker") { + t.Errorf("db.Add failed: expected %s in 'npr' feed", + "julian-baker") + return + } + if len(db.downloaded["mkbhd"]) != 1 { + t.Errorf("db.Add failed: expected 1 entry for 'npr'") + return + } + if !db.Exists("mkbhd", "v-raptor") { + t.Errorf("db.Add failed: expected %s in 'mkbhd' feed", + "v-raptor") + return + } +} -- cgit v1.2.3