mirror of
https://github.com/Team254/cheesy-arena-lite.git
synced 2026-03-09 13:46:44 -04:00
Add table wrapper struct for using Bolt DB instead of SQLite.
This commit is contained in:
@@ -12,7 +12,9 @@ import (
|
||||
"fmt"
|
||||
"github.com/jmoiron/modl"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"go.etcd.io/bbolt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -23,6 +25,7 @@ const backupsDir = "db/backups"
|
||||
const migrationsDir = "db/migrations"
|
||||
|
||||
var BaseDir = "." // Mutable for testing
|
||||
var recordTypes = []interface{}{}
|
||||
|
||||
type Database struct {
|
||||
Path string
|
||||
@@ -38,6 +41,8 @@ type Database struct {
|
||||
scheduleBlockMap *modl.DbMap
|
||||
awardMap *modl.DbMap
|
||||
userSessionMap *modl.DbMap
|
||||
bolt *bbolt.DB
|
||||
tables map[interface{}]*table
|
||||
}
|
||||
|
||||
// Opens the SQLite database at the given path, creating it if it doesn't exist, and runs any pending
|
||||
@@ -64,11 +69,29 @@ func OpenDatabase(filename string) (*Database, error) {
|
||||
database.db = db
|
||||
database.mapTables()
|
||||
|
||||
database.bolt, err = bbolt.Open(database.Path+".bolt", 0644, &bbolt.Options{NoSync: true, Timeout: time.Second})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register tables.
|
||||
database.tables = make(map[interface{}]*table)
|
||||
for _, recordType := range recordTypes {
|
||||
table, err := database.newTable(recordType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
database.tables[recordType] = table
|
||||
}
|
||||
|
||||
return &database, nil
|
||||
}
|
||||
|
||||
func (database *Database) Close() {
|
||||
database.db.Close()
|
||||
if err := database.bolt.Close(); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Creates a copy of the current database and saves it to the backups directory.
|
||||
|
||||
282
model/table.go
Normal file
282
model/table.go
Normal file
@@ -0,0 +1,282 @@
|
||||
// Copyright 2021 Team 254. All Rights Reserved.
|
||||
// Author: pat@patfairbank.com (Patrick Fairbank)
|
||||
//
|
||||
// Defines a "table" wrapper struct and helper methods for persisting data using Bolt.
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"go.etcd.io/bbolt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Encapsulates all persistence operations for a particular data type represented by a struct.
|
||||
type table struct {
|
||||
bolt *bbolt.DB
|
||||
recordType reflect.Type
|
||||
name string
|
||||
bucketKey []byte
|
||||
idFieldIndex *int
|
||||
}
|
||||
|
||||
// Registers a new table for a struct, given its zero value.
|
||||
func (database *Database) newTable(recordType interface{}) (*table, error) {
|
||||
recordTypeValue := reflect.ValueOf(recordType)
|
||||
if recordTypeValue.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("record type must be a struct; got %v", recordTypeValue.Kind())
|
||||
}
|
||||
|
||||
var table table
|
||||
table.bolt = database.bolt
|
||||
table.recordType = reflect.TypeOf(recordType)
|
||||
table.name = table.recordType.Name()
|
||||
table.bucketKey = []byte(table.name)
|
||||
|
||||
// Determine which field in the struct is tagged as the ID and cache its index.
|
||||
idFound := false
|
||||
for i := 0; i < recordTypeValue.Type().NumField(); i++ {
|
||||
field := recordTypeValue.Type().Field(i)
|
||||
tag := field.Tag.Get("db")
|
||||
if tag == "id" {
|
||||
if field.Type.Kind() != reflect.Int64 {
|
||||
return nil,
|
||||
fmt.Errorf(
|
||||
"field in struct %s tagged with 'id' must be an int64; got %v", table.name, field.Type.Kind(),
|
||||
)
|
||||
}
|
||||
table.idFieldIndex = new(int)
|
||||
*table.idFieldIndex = i
|
||||
idFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !idFound {
|
||||
return nil, fmt.Errorf("struct %s has no field tagged as the id", table.name)
|
||||
}
|
||||
|
||||
// Create the Bolt bucket corresponding to the struct.
|
||||
err := table.bolt.Update(func(tx *bbolt.Tx) error {
|
||||
_, err := tx.CreateBucketIfNotExists(table.bucketKey)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &table, nil
|
||||
}
|
||||
|
||||
// Populates the given double pointer to a record with the data from the record with the given ID, or nil if it doesn't
|
||||
// exist.
|
||||
func (table *table) getById(id int64, record interface{}) error {
|
||||
if err := table.validateType(record, reflect.Ptr, reflect.Ptr, reflect.Struct); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return table.bolt.View(func(tx *bbolt.Tx) error {
|
||||
bucket, err := table.getBucket(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if recordJson := bucket.Get(idToKey(id)); recordJson != nil {
|
||||
return json.Unmarshal(recordJson, record)
|
||||
}
|
||||
|
||||
// If the record does not exist, set the record pointer to nil.
|
||||
recordPointerValue := reflect.ValueOf(record).Elem()
|
||||
recordPointerValue.Set(reflect.Zero(recordPointerValue.Type()))
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Populates the given slice passed by pointer with the data from every record in the table, ordered by ID.
|
||||
func (table *table) getAll(recordSlice interface{}) error {
|
||||
if err := table.validateType(recordSlice, reflect.Ptr, reflect.Slice, reflect.Struct); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return table.bolt.View(func(tx *bbolt.Tx) error {
|
||||
bucket, err := table.getBucket(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordSliceValue := reflect.ValueOf(recordSlice).Elem()
|
||||
recordSliceValue.Set(reflect.MakeSlice(recordSliceValue.Type(), 0, 0))
|
||||
return bucket.ForEach(func(key, value []byte) error {
|
||||
record := reflect.New(table.recordType)
|
||||
err := json.Unmarshal(value, record.Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
recordSliceValue.Set(reflect.Append(recordSliceValue, record.Elem()))
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Persists the given record as a new row in the table.
|
||||
func (table *table) create(record interface{}) error {
|
||||
if err := table.validateType(record, reflect.Ptr, reflect.Struct); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate that the record has its ID set to zero since it will be given an auto-generated one.
|
||||
value := reflect.ValueOf(record).Elem()
|
||||
id := value.Field(*table.idFieldIndex).Int()
|
||||
if id != 0 {
|
||||
return fmt.Errorf("can't create %s with non-zero ID: %d", table.name, id)
|
||||
}
|
||||
|
||||
return table.bolt.Update(func(tx *bbolt.Tx) error {
|
||||
bucket, err := table.getBucket(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate a new ID for the record.
|
||||
newSequence, err := bucket.NextSequence()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
id = int64(newSequence)
|
||||
value.Field(*table.idFieldIndex).SetInt(id)
|
||||
|
||||
// Ensure that a record having the same ID does not already exist in the table.
|
||||
key := idToKey(id)
|
||||
oldRecord := bucket.Get(key)
|
||||
if oldRecord != nil {
|
||||
return fmt.Errorf("%s with ID %d already exists: %s", table.name, id, string(oldRecord))
|
||||
}
|
||||
|
||||
recordJson, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bucket.Put(key, recordJson)
|
||||
})
|
||||
}
|
||||
|
||||
// Persists the given record as an update to the existing row in the table. Returns an error if the record does not
|
||||
// already exist.
|
||||
func (table *table) update(record interface{}) error {
|
||||
if err := table.validateType(record, reflect.Ptr, reflect.Struct); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate that the record has a non-zero ID.
|
||||
value := reflect.ValueOf(record).Elem()
|
||||
id := value.Field(*table.idFieldIndex).Int()
|
||||
if id == 0 {
|
||||
return fmt.Errorf("can't update %s with zero ID", table.name)
|
||||
}
|
||||
|
||||
return table.bolt.Update(func(tx *bbolt.Tx) error {
|
||||
bucket, err := table.getBucket(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure that a record having the same ID exists in the table.
|
||||
key := idToKey(id)
|
||||
oldRecord := bucket.Get(key)
|
||||
if oldRecord == nil {
|
||||
return fmt.Errorf("can't update non-existent %s with ID %d", table.name, id)
|
||||
}
|
||||
|
||||
recordJson, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bucket.Put(key, recordJson)
|
||||
})
|
||||
}
|
||||
|
||||
// Deletes the record having the given ID from the table. Returns an error if the record does not exist.
|
||||
func (table *table) delete(id int64) error {
|
||||
return table.bolt.Update(func(tx *bbolt.Tx) error {
|
||||
bucket, err := table.getBucket(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure that a record having the same ID exists in the table.
|
||||
key := idToKey(id)
|
||||
oldRecord := bucket.Get(key)
|
||||
if oldRecord == nil {
|
||||
return fmt.Errorf("can't delete non-existent %s with ID %d", table.name, id)
|
||||
}
|
||||
|
||||
return bucket.Delete(key)
|
||||
})
|
||||
}
|
||||
|
||||
// Deletes all records from the table.
|
||||
func (table *table) truncate() error {
|
||||
return table.bolt.Update(func(tx *bbolt.Tx) error {
|
||||
_, err := table.getBucket(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Carry out the truncation by way of deleting the whole bucket and then recreate it.
|
||||
err = tx.DeleteBucket(table.bucketKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.CreateBucket(table.bucketKey)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// Obtains the Bolt bucket belonging to the table.
|
||||
func (table *table) getBucket(tx *bbolt.Tx) (*bbolt.Bucket, error) {
|
||||
bucket := tx.Bucket(table.bucketKey)
|
||||
if bucket == nil {
|
||||
return nil, fmt.Errorf("unknown table %s", table.name)
|
||||
}
|
||||
return bucket, nil
|
||||
}
|
||||
|
||||
// Validates that the given record is of the expected derived type (e.g. pointer, slice, etc.), that the base type is
|
||||
// the same as that stored in the table, and that the table is configured correctly.
|
||||
func (table *table) validateType(record interface{}, kinds ...reflect.Kind) error {
|
||||
// Check the hierarchy of kinds against the expected list until reaching the base record type.
|
||||
recordType := reflect.ValueOf(record).Type()
|
||||
expectedKind := ""
|
||||
actualKind := ""
|
||||
for i, kind := range kinds {
|
||||
if i > 0 {
|
||||
expectedKind += " -> "
|
||||
actualKind += " -> "
|
||||
}
|
||||
expectedKind += kind.String()
|
||||
actualKind += recordType.Kind().String()
|
||||
if recordType.Kind() != kind {
|
||||
return fmt.Errorf("input must be a %s; got a %s", expectedKind, actualKind)
|
||||
}
|
||||
if i < len(kinds)-1 {
|
||||
recordType = recordType.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
if recordType != table.recordType {
|
||||
return fmt.Errorf("given record of type %s does not match expected type for table %s", recordType, table.name)
|
||||
}
|
||||
|
||||
if table.idFieldIndex == nil {
|
||||
return fmt.Errorf("struct %s has no field tagged as the id", table.name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serializes the given integer ID to a byte array containing its Base-10 string representation.
|
||||
func idToKey(id int64) []byte {
|
||||
return []byte(strconv.FormatInt(id, 10))
|
||||
}
|
||||
208
model/table_test.go
Normal file
208
model/table_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
// Copyright 2021 Team 254. All Rights Reserved.
|
||||
// Author: pat@patfairbank.com (Patrick Fairbank)
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type validRecord struct {
|
||||
Id int64 `db:"id"`
|
||||
IntData int
|
||||
StringData string
|
||||
}
|
||||
|
||||
func TestTableSingleCrud(t *testing.T) {
|
||||
db := setupTestDb(t)
|
||||
defer db.Close()
|
||||
|
||||
table, err := db.newTable(validRecord{})
|
||||
if !assert.Nil(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test initial create and then read back.
|
||||
record := validRecord{IntData: 254, StringData: "The Cheesy Poofs"}
|
||||
if assert.Nil(t, table.create(&record)) {
|
||||
assert.Equal(t, int64(1), record.Id)
|
||||
}
|
||||
var record2 *validRecord
|
||||
if assert.Nil(t, table.getById(record.Id, &record2)) {
|
||||
assert.Equal(t, record, *record2)
|
||||
}
|
||||
|
||||
// Test update and then read back.
|
||||
record.IntData = 252
|
||||
record.StringData = "Teh Chezy Pofs"
|
||||
assert.Nil(t, table.update(&record))
|
||||
if assert.Nil(t, table.getById(record.Id, &record2)) {
|
||||
assert.Equal(t, record, *record2)
|
||||
}
|
||||
|
||||
// Test delete.
|
||||
assert.Nil(t, table.delete(record.Id))
|
||||
if assert.Nil(t, table.getById(record.Id, &record2)) {
|
||||
assert.Nil(t, record2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTableMultipleCrud(t *testing.T) {
|
||||
db := setupTestDb(t)
|
||||
defer db.Close()
|
||||
|
||||
table, err := db.newTable(validRecord{})
|
||||
if !assert.Nil(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
// Insert a few test records.
|
||||
record1 := validRecord{IntData: 1, StringData: "One"}
|
||||
record2 := validRecord{IntData: 2, StringData: "Two"}
|
||||
record3 := validRecord{IntData: 3, StringData: "Three"}
|
||||
assert.Nil(t, table.create(&record1))
|
||||
assert.Nil(t, table.create(&record2))
|
||||
assert.Nil(t, table.create(&record3))
|
||||
|
||||
// Read all records.
|
||||
var records []validRecord
|
||||
assert.Nil(t, table.getAll(&records))
|
||||
if assert.Equal(t, 3, len(records)) {
|
||||
assert.Equal(t, record1, records[0])
|
||||
assert.Equal(t, record2, records[1])
|
||||
assert.Equal(t, record3, records[2])
|
||||
}
|
||||
|
||||
// Truncate the table and verify that the records no longer exist.
|
||||
assert.Nil(t, table.truncate())
|
||||
assert.Nil(t, table.getAll(&records))
|
||||
assert.Equal(t, 0, len(records))
|
||||
var record4 *validRecord
|
||||
if assert.Nil(t, table.getById(record1.Id, &record4)) {
|
||||
assert.Nil(t, record4)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTableErrors(t *testing.T) {
|
||||
db := setupTestDb(t)
|
||||
defer db.Close()
|
||||
|
||||
// Pass a non-struct as the record type.
|
||||
table, err := db.newTable(123)
|
||||
assert.Nil(t, table)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "record type must be a struct; got int", err.Error())
|
||||
}
|
||||
|
||||
// Pass a struct that doesn't have an ID field.
|
||||
type recordWithNoId struct {
|
||||
StringData string
|
||||
}
|
||||
table, err = db.newTable(recordWithNoId{})
|
||||
assert.Nil(t, table)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "struct recordWithNoId has no field tagged as the id", err.Error())
|
||||
}
|
||||
|
||||
// Pass a struct that has a field with the wrong type tagged as the ID.
|
||||
type recordWithWrongIdType struct {
|
||||
Id bool `db:"id"`
|
||||
}
|
||||
table, err = db.newTable(recordWithWrongIdType{})
|
||||
assert.Nil(t, table)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(
|
||||
t, "field in struct recordWithWrongIdType tagged with 'id' must be an int64; got bool", err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTableCrudErrors(t *testing.T) {
|
||||
db := setupTestDb(t)
|
||||
defer db.Close()
|
||||
|
||||
table, err := db.newTable(validRecord{})
|
||||
if !assert.Nil(t, err) {
|
||||
return
|
||||
}
|
||||
type differentRecord struct {
|
||||
StringData string
|
||||
}
|
||||
|
||||
// Pass an object of the wrong type when getting a single record.
|
||||
var record validRecord
|
||||
err = table.getById(record.Id, record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "input must be a ptr; got a struct", err.Error())
|
||||
}
|
||||
err = table.getById(record.Id, &record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "input must be a ptr -> ptr; got a ptr -> struct", err.Error())
|
||||
}
|
||||
var recordTriplePointer ***validRecord
|
||||
err = table.getById(record.Id, recordTriplePointer)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "input must be a ptr -> ptr -> struct; got a ptr -> ptr -> ptr", err.Error())
|
||||
}
|
||||
var differentRecordPointer *differentRecord
|
||||
err = table.getById(record.Id, &differentRecordPointer)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(
|
||||
t,
|
||||
"given record of type model.differentRecord does not match expected type for table validRecord",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
// Pass an object of the wrong type when getting all records.
|
||||
var records []validRecord
|
||||
err = table.getAll(records)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "input must be a ptr; got a slice", err.Error())
|
||||
}
|
||||
|
||||
// Pass an object of the wrong type when creating or updating a record.
|
||||
err = table.create(record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "input must be a ptr; got a struct", err.Error())
|
||||
}
|
||||
err = table.update(record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "input must be a ptr; got a struct", err.Error())
|
||||
}
|
||||
|
||||
// Create a record with a non-zero ID.
|
||||
record.Id = 12345
|
||||
err = table.create(&record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "can't create validRecord with non-zero ID: 12345", err.Error())
|
||||
}
|
||||
|
||||
// Update a record with an ID of zero.
|
||||
record.Id = 0
|
||||
err = table.update(&record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "can't update validRecord with zero ID", err.Error())
|
||||
}
|
||||
|
||||
// Update a nonexistent record.
|
||||
record.Id = 12345
|
||||
err = table.update(&record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "can't update non-existent validRecord with ID 12345", err.Error())
|
||||
}
|
||||
|
||||
// Delete a nonexistent record.
|
||||
err = table.delete(12345)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "can't delete non-existent validRecord with ID 12345", err.Error())
|
||||
}
|
||||
|
||||
// Update a record with an incorrectly constructed table object.
|
||||
table.idFieldIndex = nil
|
||||
err = table.update(&record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "struct validRecord has no field tagged as the id", err.Error())
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ func SetupTestDb(t *testing.T, uniqueName string) *Database {
|
||||
BaseDir = ".."
|
||||
dbPath := filepath.Join(BaseDir, fmt.Sprintf("%s_test.db", uniqueName))
|
||||
os.Remove(dbPath)
|
||||
os.Remove(dbPath + ".bolt")
|
||||
database, err := OpenDatabase(dbPath)
|
||||
assert.Nil(t, err)
|
||||
return database
|
||||
|
||||
Reference in New Issue
Block a user