mirror of
https://github.com/Team254/cheesy-arena-lite.git
synced 2026-03-09 21:56:50 -04:00
Add table wrapper struct for using Bolt DB instead of SQLite.
This commit is contained in:
@@ -21,6 +21,7 @@ func SetupTestArena(t *testing.T, uniqueName string) *Arena {
|
|||||||
model.BaseDir = ".."
|
model.BaseDir = ".."
|
||||||
dbPath := filepath.Join(model.BaseDir, fmt.Sprintf("%s_test.db", uniqueName))
|
dbPath := filepath.Join(model.BaseDir, fmt.Sprintf("%s_test.db", uniqueName))
|
||||||
os.Remove(dbPath)
|
os.Remove(dbPath)
|
||||||
|
os.Remove(dbPath + ".bolt")
|
||||||
arena, err := NewArena(dbPath)
|
arena, err := NewArena(dbPath)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
return arena
|
return arena
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -19,5 +19,6 @@ require (
|
|||||||
github.com/mitchellh/mapstructure v1.4.1
|
github.com/mitchellh/mapstructure v1.4.1
|
||||||
github.com/stretchr/testify v1.7.0
|
github.com/stretchr/testify v1.7.0
|
||||||
github.com/ziutek/mymysql v1.5.4 // indirect
|
github.com/ziutek/mymysql v1.5.4 // indirect
|
||||||
|
go.etcd.io/bbolt v1.3.5 // indirect
|
||||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
|
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
|
||||||
)
|
)
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -44,6 +44,8 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc
|
|||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs=
|
github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs=
|
||||||
github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0=
|
github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0=
|
||||||
|
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
|
||||||
|
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g=
|
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g=
|
||||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||||
@@ -52,6 +54,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
|
|||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4=
|
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4=
|
||||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 h1:LfCXLvNmTYH9kEmVgqbnsWfruoXZIrh4YBgqVHtDvw0=
|
||||||
|
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221 h1:/ZHdbVpdR/jk3g30/d4yUL0JU9kksj8+F/bnQUVLGDM=
|
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221 h1:/ZHdbVpdR/jk3g30/d4yUL0JU9kksj8+F/bnQUVLGDM=
|
||||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/jmoiron/modl"
|
"github.com/jmoiron/modl"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"go.etcd.io/bbolt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -23,6 +25,7 @@ const backupsDir = "db/backups"
|
|||||||
const migrationsDir = "db/migrations"
|
const migrationsDir = "db/migrations"
|
||||||
|
|
||||||
var BaseDir = "." // Mutable for testing
|
var BaseDir = "." // Mutable for testing
|
||||||
|
var recordTypes = []interface{}{}
|
||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
Path string
|
Path string
|
||||||
@@ -38,6 +41,8 @@ type Database struct {
|
|||||||
scheduleBlockMap *modl.DbMap
|
scheduleBlockMap *modl.DbMap
|
||||||
awardMap *modl.DbMap
|
awardMap *modl.DbMap
|
||||||
userSessionMap *modl.DbMap
|
userSessionMap *modl.DbMap
|
||||||
|
bolt *bbolt.DB
|
||||||
|
tables map[interface{}]*table
|
||||||
}
|
}
|
||||||
|
|
||||||
// Opens the SQLite database at the given path, creating it if it doesn't exist, and runs any pending
|
// Opens the SQLite database at the given path, creating it if it doesn't exist, and runs any pending
|
||||||
@@ -64,11 +69,29 @@ func OpenDatabase(filename string) (*Database, error) {
|
|||||||
database.db = db
|
database.db = db
|
||||||
database.mapTables()
|
database.mapTables()
|
||||||
|
|
||||||
|
database.bolt, err = bbolt.Open(database.Path+".bolt", 0644, &bbolt.Options{NoSync: true, Timeout: time.Second})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register tables.
|
||||||
|
database.tables = make(map[interface{}]*table)
|
||||||
|
for _, recordType := range recordTypes {
|
||||||
|
table, err := database.newTable(recordType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
database.tables[recordType] = table
|
||||||
|
}
|
||||||
|
|
||||||
return &database, nil
|
return &database, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (database *Database) Close() {
|
func (database *Database) Close() {
|
||||||
database.db.Close()
|
database.db.Close()
|
||||||
|
if err := database.bolt.Close(); err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a copy of the current database and saves it to the backups directory.
|
// Creates a copy of the current database and saves it to the backups directory.
|
||||||
|
|||||||
282
model/table.go
Normal file
282
model/table.go
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
// Copyright 2021 Team 254. All Rights Reserved.
|
||||||
|
// Author: pat@patfairbank.com (Patrick Fairbank)
|
||||||
|
//
|
||||||
|
// Defines a "table" wrapper struct and helper methods for persisting data using Bolt.
|
||||||
|
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"go.etcd.io/bbolt"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Encapsulates all persistence operations for a particular data type represented by a struct.
|
||||||
|
type table struct {
|
||||||
|
bolt *bbolt.DB
|
||||||
|
recordType reflect.Type
|
||||||
|
name string
|
||||||
|
bucketKey []byte
|
||||||
|
idFieldIndex *int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers a new table for a struct, given its zero value.
|
||||||
|
func (database *Database) newTable(recordType interface{}) (*table, error) {
|
||||||
|
recordTypeValue := reflect.ValueOf(recordType)
|
||||||
|
if recordTypeValue.Kind() != reflect.Struct {
|
||||||
|
return nil, fmt.Errorf("record type must be a struct; got %v", recordTypeValue.Kind())
|
||||||
|
}
|
||||||
|
|
||||||
|
var table table
|
||||||
|
table.bolt = database.bolt
|
||||||
|
table.recordType = reflect.TypeOf(recordType)
|
||||||
|
table.name = table.recordType.Name()
|
||||||
|
table.bucketKey = []byte(table.name)
|
||||||
|
|
||||||
|
// Determine which field in the struct is tagged as the ID and cache its index.
|
||||||
|
idFound := false
|
||||||
|
for i := 0; i < recordTypeValue.Type().NumField(); i++ {
|
||||||
|
field := recordTypeValue.Type().Field(i)
|
||||||
|
tag := field.Tag.Get("db")
|
||||||
|
if tag == "id" {
|
||||||
|
if field.Type.Kind() != reflect.Int64 {
|
||||||
|
return nil,
|
||||||
|
fmt.Errorf(
|
||||||
|
"field in struct %s tagged with 'id' must be an int64; got %v", table.name, field.Type.Kind(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
table.idFieldIndex = new(int)
|
||||||
|
*table.idFieldIndex = i
|
||||||
|
idFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !idFound {
|
||||||
|
return nil, fmt.Errorf("struct %s has no field tagged as the id", table.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the Bolt bucket corresponding to the struct.
|
||||||
|
err := table.bolt.Update(func(tx *bbolt.Tx) error {
|
||||||
|
_, err := tx.CreateBucketIfNotExists(table.bucketKey)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &table, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populates the given double pointer to a record with the data from the record with the given ID, or nil if it doesn't
|
||||||
|
// exist.
|
||||||
|
func (table *table) getById(id int64, record interface{}) error {
|
||||||
|
if err := table.validateType(record, reflect.Ptr, reflect.Ptr, reflect.Struct); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return table.bolt.View(func(tx *bbolt.Tx) error {
|
||||||
|
bucket, err := table.getBucket(tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if recordJson := bucket.Get(idToKey(id)); recordJson != nil {
|
||||||
|
return json.Unmarshal(recordJson, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the record does not exist, set the record pointer to nil.
|
||||||
|
recordPointerValue := reflect.ValueOf(record).Elem()
|
||||||
|
recordPointerValue.Set(reflect.Zero(recordPointerValue.Type()))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populates the given slice passed by pointer with the data from every record in the table, ordered by ID.
|
||||||
|
func (table *table) getAll(recordSlice interface{}) error {
|
||||||
|
if err := table.validateType(recordSlice, reflect.Ptr, reflect.Slice, reflect.Struct); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return table.bolt.View(func(tx *bbolt.Tx) error {
|
||||||
|
bucket, err := table.getBucket(tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
recordSliceValue := reflect.ValueOf(recordSlice).Elem()
|
||||||
|
recordSliceValue.Set(reflect.MakeSlice(recordSliceValue.Type(), 0, 0))
|
||||||
|
return bucket.ForEach(func(key, value []byte) error {
|
||||||
|
record := reflect.New(table.recordType)
|
||||||
|
err := json.Unmarshal(value, record.Interface())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
recordSliceValue.Set(reflect.Append(recordSliceValue, record.Elem()))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persists the given record as a new row in the table.
|
||||||
|
func (table *table) create(record interface{}) error {
|
||||||
|
if err := table.validateType(record, reflect.Ptr, reflect.Struct); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that the record has its ID set to zero since it will be given an auto-generated one.
|
||||||
|
value := reflect.ValueOf(record).Elem()
|
||||||
|
id := value.Field(*table.idFieldIndex).Int()
|
||||||
|
if id != 0 {
|
||||||
|
return fmt.Errorf("can't create %s with non-zero ID: %d", table.name, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return table.bolt.Update(func(tx *bbolt.Tx) error {
|
||||||
|
bucket, err := table.getBucket(tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new ID for the record.
|
||||||
|
newSequence, err := bucket.NextSequence()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
id = int64(newSequence)
|
||||||
|
value.Field(*table.idFieldIndex).SetInt(id)
|
||||||
|
|
||||||
|
// Ensure that a record having the same ID does not already exist in the table.
|
||||||
|
key := idToKey(id)
|
||||||
|
oldRecord := bucket.Get(key)
|
||||||
|
if oldRecord != nil {
|
||||||
|
return fmt.Errorf("%s with ID %d already exists: %s", table.name, id, string(oldRecord))
|
||||||
|
}
|
||||||
|
|
||||||
|
recordJson, err := json.Marshal(record)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return bucket.Put(key, recordJson)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persists the given record as an update to the existing row in the table. Returns an error if the record does not
|
||||||
|
// already exist.
|
||||||
|
func (table *table) update(record interface{}) error {
|
||||||
|
if err := table.validateType(record, reflect.Ptr, reflect.Struct); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that the record has a non-zero ID.
|
||||||
|
value := reflect.ValueOf(record).Elem()
|
||||||
|
id := value.Field(*table.idFieldIndex).Int()
|
||||||
|
if id == 0 {
|
||||||
|
return fmt.Errorf("can't update %s with zero ID", table.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return table.bolt.Update(func(tx *bbolt.Tx) error {
|
||||||
|
bucket, err := table.getBucket(tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that a record having the same ID exists in the table.
|
||||||
|
key := idToKey(id)
|
||||||
|
oldRecord := bucket.Get(key)
|
||||||
|
if oldRecord == nil {
|
||||||
|
return fmt.Errorf("can't update non-existent %s with ID %d", table.name, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
recordJson, err := json.Marshal(record)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return bucket.Put(key, recordJson)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deletes the record having the given ID from the table. Returns an error if the record does not exist.
|
||||||
|
func (table *table) delete(id int64) error {
|
||||||
|
return table.bolt.Update(func(tx *bbolt.Tx) error {
|
||||||
|
bucket, err := table.getBucket(tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that a record having the same ID exists in the table.
|
||||||
|
key := idToKey(id)
|
||||||
|
oldRecord := bucket.Get(key)
|
||||||
|
if oldRecord == nil {
|
||||||
|
return fmt.Errorf("can't delete non-existent %s with ID %d", table.name, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bucket.Delete(key)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deletes all records from the table.
|
||||||
|
func (table *table) truncate() error {
|
||||||
|
return table.bolt.Update(func(tx *bbolt.Tx) error {
|
||||||
|
_, err := table.getBucket(tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Carry out the truncation by way of deleting the whole bucket and then recreate it.
|
||||||
|
err = tx.DeleteBucket(table.bucketKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.CreateBucket(table.bucketKey)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtains the Bolt bucket belonging to the table.
|
||||||
|
func (table *table) getBucket(tx *bbolt.Tx) (*bbolt.Bucket, error) {
|
||||||
|
bucket := tx.Bucket(table.bucketKey)
|
||||||
|
if bucket == nil {
|
||||||
|
return nil, fmt.Errorf("unknown table %s", table.name)
|
||||||
|
}
|
||||||
|
return bucket, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates that the given record is of the expected derived type (e.g. pointer, slice, etc.), that the base type is
|
||||||
|
// the same as that stored in the table, and that the table is configured correctly.
|
||||||
|
func (table *table) validateType(record interface{}, kinds ...reflect.Kind) error {
|
||||||
|
// Check the hierarchy of kinds against the expected list until reaching the base record type.
|
||||||
|
recordType := reflect.ValueOf(record).Type()
|
||||||
|
expectedKind := ""
|
||||||
|
actualKind := ""
|
||||||
|
for i, kind := range kinds {
|
||||||
|
if i > 0 {
|
||||||
|
expectedKind += " -> "
|
||||||
|
actualKind += " -> "
|
||||||
|
}
|
||||||
|
expectedKind += kind.String()
|
||||||
|
actualKind += recordType.Kind().String()
|
||||||
|
if recordType.Kind() != kind {
|
||||||
|
return fmt.Errorf("input must be a %s; got a %s", expectedKind, actualKind)
|
||||||
|
}
|
||||||
|
if i < len(kinds)-1 {
|
||||||
|
recordType = recordType.Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if recordType != table.recordType {
|
||||||
|
return fmt.Errorf("given record of type %s does not match expected type for table %s", recordType, table.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if table.idFieldIndex == nil {
|
||||||
|
return fmt.Errorf("struct %s has no field tagged as the id", table.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serializes the given integer ID to a byte array containing its Base-10 string representation.
|
||||||
|
func idToKey(id int64) []byte {
|
||||||
|
return []byte(strconv.FormatInt(id, 10))
|
||||||
|
}
|
||||||
208
model/table_test.go
Normal file
208
model/table_test.go
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
// Copyright 2021 Team 254. All Rights Reserved.
|
||||||
|
// Author: pat@patfairbank.com (Patrick Fairbank)
|
||||||
|
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type validRecord struct {
|
||||||
|
Id int64 `db:"id"`
|
||||||
|
IntData int
|
||||||
|
StringData string
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTableSingleCrud(t *testing.T) {
|
||||||
|
db := setupTestDb(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
table, err := db.newTable(validRecord{})
|
||||||
|
if !assert.Nil(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test initial create and then read back.
|
||||||
|
record := validRecord{IntData: 254, StringData: "The Cheesy Poofs"}
|
||||||
|
if assert.Nil(t, table.create(&record)) {
|
||||||
|
assert.Equal(t, int64(1), record.Id)
|
||||||
|
}
|
||||||
|
var record2 *validRecord
|
||||||
|
if assert.Nil(t, table.getById(record.Id, &record2)) {
|
||||||
|
assert.Equal(t, record, *record2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test update and then read back.
|
||||||
|
record.IntData = 252
|
||||||
|
record.StringData = "Teh Chezy Pofs"
|
||||||
|
assert.Nil(t, table.update(&record))
|
||||||
|
if assert.Nil(t, table.getById(record.Id, &record2)) {
|
||||||
|
assert.Equal(t, record, *record2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test delete.
|
||||||
|
assert.Nil(t, table.delete(record.Id))
|
||||||
|
if assert.Nil(t, table.getById(record.Id, &record2)) {
|
||||||
|
assert.Nil(t, record2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTableMultipleCrud(t *testing.T) {
|
||||||
|
db := setupTestDb(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
table, err := db.newTable(validRecord{})
|
||||||
|
if !assert.Nil(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert a few test records.
|
||||||
|
record1 := validRecord{IntData: 1, StringData: "One"}
|
||||||
|
record2 := validRecord{IntData: 2, StringData: "Two"}
|
||||||
|
record3 := validRecord{IntData: 3, StringData: "Three"}
|
||||||
|
assert.Nil(t, table.create(&record1))
|
||||||
|
assert.Nil(t, table.create(&record2))
|
||||||
|
assert.Nil(t, table.create(&record3))
|
||||||
|
|
||||||
|
// Read all records.
|
||||||
|
var records []validRecord
|
||||||
|
assert.Nil(t, table.getAll(&records))
|
||||||
|
if assert.Equal(t, 3, len(records)) {
|
||||||
|
assert.Equal(t, record1, records[0])
|
||||||
|
assert.Equal(t, record2, records[1])
|
||||||
|
assert.Equal(t, record3, records[2])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Truncate the table and verify that the records no longer exist.
|
||||||
|
assert.Nil(t, table.truncate())
|
||||||
|
assert.Nil(t, table.getAll(&records))
|
||||||
|
assert.Equal(t, 0, len(records))
|
||||||
|
var record4 *validRecord
|
||||||
|
if assert.Nil(t, table.getById(record1.Id, &record4)) {
|
||||||
|
assert.Nil(t, record4)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewTableErrors(t *testing.T) {
|
||||||
|
db := setupTestDb(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
// Pass a non-struct as the record type.
|
||||||
|
table, err := db.newTable(123)
|
||||||
|
assert.Nil(t, table)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "record type must be a struct; got int", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass a struct that doesn't have an ID field.
|
||||||
|
type recordWithNoId struct {
|
||||||
|
StringData string
|
||||||
|
}
|
||||||
|
table, err = db.newTable(recordWithNoId{})
|
||||||
|
assert.Nil(t, table)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "struct recordWithNoId has no field tagged as the id", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass a struct that has a field with the wrong type tagged as the ID.
|
||||||
|
type recordWithWrongIdType struct {
|
||||||
|
Id bool `db:"id"`
|
||||||
|
}
|
||||||
|
table, err = db.newTable(recordWithWrongIdType{})
|
||||||
|
assert.Nil(t, table)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(
|
||||||
|
t, "field in struct recordWithWrongIdType tagged with 'id' must be an int64; got bool", err.Error(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTableCrudErrors(t *testing.T) {
|
||||||
|
db := setupTestDb(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
table, err := db.newTable(validRecord{})
|
||||||
|
if !assert.Nil(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
type differentRecord struct {
|
||||||
|
StringData string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass an object of the wrong type when getting a single record.
|
||||||
|
var record validRecord
|
||||||
|
err = table.getById(record.Id, record)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "input must be a ptr; got a struct", err.Error())
|
||||||
|
}
|
||||||
|
err = table.getById(record.Id, &record)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "input must be a ptr -> ptr; got a ptr -> struct", err.Error())
|
||||||
|
}
|
||||||
|
var recordTriplePointer ***validRecord
|
||||||
|
err = table.getById(record.Id, recordTriplePointer)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "input must be a ptr -> ptr -> struct; got a ptr -> ptr -> ptr", err.Error())
|
||||||
|
}
|
||||||
|
var differentRecordPointer *differentRecord
|
||||||
|
err = table.getById(record.Id, &differentRecordPointer)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(
|
||||||
|
t,
|
||||||
|
"given record of type model.differentRecord does not match expected type for table validRecord",
|
||||||
|
err.Error(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass an object of the wrong type when getting all records.
|
||||||
|
var records []validRecord
|
||||||
|
err = table.getAll(records)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "input must be a ptr; got a slice", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass an object of the wrong type when creating or updating a record.
|
||||||
|
err = table.create(record)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "input must be a ptr; got a struct", err.Error())
|
||||||
|
}
|
||||||
|
err = table.update(record)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "input must be a ptr; got a struct", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a record with a non-zero ID.
|
||||||
|
record.Id = 12345
|
||||||
|
err = table.create(&record)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "can't create validRecord with non-zero ID: 12345", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update a record with an ID of zero.
|
||||||
|
record.Id = 0
|
||||||
|
err = table.update(&record)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "can't update validRecord with zero ID", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update a nonexistent record.
|
||||||
|
record.Id = 12345
|
||||||
|
err = table.update(&record)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "can't update non-existent validRecord with ID 12345", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete a nonexistent record.
|
||||||
|
err = table.delete(12345)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "can't delete non-existent validRecord with ID 12345", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update a record with an incorrectly constructed table object.
|
||||||
|
table.idFieldIndex = nil
|
||||||
|
err = table.update(&record)
|
||||||
|
if assert.NotNil(t, err) {
|
||||||
|
assert.Equal(t, "struct validRecord has no field tagged as the id", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,6 +18,7 @@ func SetupTestDb(t *testing.T, uniqueName string) *Database {
|
|||||||
BaseDir = ".."
|
BaseDir = ".."
|
||||||
dbPath := filepath.Join(BaseDir, fmt.Sprintf("%s_test.db", uniqueName))
|
dbPath := filepath.Join(BaseDir, fmt.Sprintf("%s_test.db", uniqueName))
|
||||||
os.Remove(dbPath)
|
os.Remove(dbPath)
|
||||||
|
os.Remove(dbPath + ".bolt")
|
||||||
database, err := OpenDatabase(dbPath)
|
database, err := OpenDatabase(dbPath)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
return database
|
return database
|
||||||
|
|||||||
Reference in New Issue
Block a user