diff options
Diffstat (limited to 'db')
-rw-r--r-- | db/db.go | 66 | ||||
-rw-r--r-- | db/db_test.go | 123 |
2 files changed, 189 insertions, 0 deletions
diff --git a/db/db.go b/db/db.go new file mode 100644 index 0000000..2d6a4d5 --- /dev/null +++ b/db/db.go @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: ISC +// Copyright © 2021 siddharth <s@ricketyspace.net> + +package db + +import ( + "encoding/json" + "fmt" + "os" + "path" + + "ricketyspace.net/fern/file" +) + +var dbPath string + +type FernDB struct { + // Key: feed-id + // Value: feed-id's entries that were downloaded + downloaded map[string][]string +} + +func init() { + dbPath = "" // Reset. + + // Construct default dbPath + h, err := os.UserHomeDir() + if err != nil { + return + } + dbPath = path.Join(h, ".config", "fern", "db.json") + +} + +func Open() (*FernDB, error) { + if len(dbPath) == 0 { + return nil, fmt.Errorf("FernDB path not set") + } + + // Check if db exists. + _, err := os.Stat(dbPath) + if err != nil { + // db does not exist yet; create an empty one. + db := new(FernDB) + db.downloaded = make(map[string][]string) + return db, nil + } + + // Read db from disk. + f, err := os.Open(dbPath) + if err != nil { + return nil, err + } + bs, err := file.Read(f) + if err != nil { + return nil, err + } + + // Unmarshal db into an object. + db := new(FernDB) + err = json.Unmarshal(bs, &db.downloaded) + if err != nil { + return nil, err + } + return db, nil +} diff --git a/db/db_test.go b/db/db_test.go new file mode 100644 index 0000000..5a2a983 --- /dev/null +++ b/db/db_test.go @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: ISC +// Copyright © 2021 siddharth <s@ricketyspace.net> + +package db + +import ( + "os" + "path" + "testing" +) + +func stringsContain(haystack []string, needle string) bool { + for _, s := range haystack { + if s == needle { + return true + } + } + return false +} + +func TestOpenPathNotSet(t *testing.T) { + // Set custom path for db. + dbPath = "" + defer os.Remove(dbPath) + + _, err := Open() + if err == nil { + t.Errorf("Error: db.Open did not fail when dbPath is empty\n") + return + } + if err.Error() != "FernDB path not set" { + t.Errorf("Error: db.Open wrong error message when dbPath is empty\n") + return + } +} + +func TestOpenNewDB(t *testing.T) { + // Set custom path for db. + dbPath = path.Join(os.TempDir(), "fern-db.json") + defer os.Remove(dbPath) + + // Open empty db. + db, err := Open() + if err != nil { + t.Errorf("db.Open failed: %v", err.Error()) + return + } + + // Verify that 'downloaded' is initialized + if db.downloaded == nil { + t.Errorf("db.downloaded is nil") + return + } +} + +func TestOpenExistingDB(t *testing.T) { + // Set custom path for db. + dbPath = path.Join(os.TempDir(), "fern-db.json") + defer os.Remove(dbPath) + + // Write a sample test db to fern-db.json + testDBJSON := []byte(`{"mkbhd":["rivian","v-raptor","m1-imac"],"npr":["william-prince","joy-oladokun","lucy-ducas"],"simone":["weightless","ugly-desks","safety-hat"]}`) + dbFile, err := os.Create(dbPath) + if err != nil { + t.Errorf("Unable to create fern-db.json: %v", err.Error()) + } + n, err := dbFile.Write(testDBJSON) + if len(testDBJSON) != n { + t.Errorf("Write to fern-db.json failed: %v", err.Error()) + } + dbFile.Close() + + // Open the db. + db, err := Open() + if err != nil { + t.Errorf("db.Open failed: %v", err.Error()) + return + } + + // Validate db.downloaded. + var entries, expectedEntries []string + var ok bool + if len(db.downloaded) != 3 { + t.Errorf("db.downloaded does not contain 3 feeds") + return + } + // mkbhd + if entries, ok = db.downloaded["mkbhd"]; !ok { + t.Errorf("db.downloaded does not contain mkbhd") + return + } + expectedEntries = []string{"rivian", "v-raptor", "m1-imac"} + for _, entry := range entries { + if !stringsContain(expectedEntries, entry) { + t.Errorf("%v does not exist in db.downloaded[mkbhd]", entry) + return + } + } + // simone + if entries, ok = db.downloaded["simone"]; !ok { + t.Errorf("db.downloaded does not contain simone") + return + } + expectedEntries = []string{"weightless", "ugly-desks", "safety-hat"} + for _, entry := range entries { + if !stringsContain(expectedEntries, entry) { + t.Errorf("%v does not exist in db.downloaded[simone]", entry) + return + } + } + // npr + if entries, ok = db.downloaded["npr"]; !ok { + t.Errorf("db.downloaded does not contain npr") + return + } + expectedEntries = []string{"william-prince", "lucy-ducas", "joy-oladokun"} + for _, entry := range entries { + if !stringsContain(expectedEntries, entry) { + t.Errorf("%v does not exist in db.downloaded[npr]", entry) + return + } + } +} |