// Copyright 2021 Team 254. All Rights Reserved. // Author: pat@patfairbank.com (Patrick Fairbank) // // Defines a "table" wrapper struct and helper methods for persisting data using Bolt. package model import ( "encoding/json" "fmt" "go.etcd.io/bbolt" "reflect" "strconv" "strings" ) // Encapsulates all persistence operations for a particular data type represented by a struct. type table[R any] struct { bolt *bbolt.DB recordType reflect.Type name string bucketKey []byte idFieldIndex *int manualId bool } // Registers a new table for a struct. func newTable[R any](database *Database) (*table[R], error) { var recordType R recordTypeValue := reflect.ValueOf(recordType) if recordTypeValue.Kind() != reflect.Struct { return nil, fmt.Errorf("record type must be a struct; got %v", recordTypeValue.Kind()) } var table table[R] table.bolt = database.bolt table.recordType = reflect.TypeOf(recordType) table.name = table.recordType.Name() table.bucketKey = []byte(table.name) // Determine which field in the struct is tagged as the ID and cache its index. idFound := false for i := 0; i < recordTypeValue.Type().NumField(); i++ { field := recordTypeValue.Type().Field(i) tags := map[string]struct{}{} for _, tag := range strings.Split(field.Tag.Get("db"), ",") { tags[tag] = struct{}{} } if _, ok := tags["id"]; ok { if field.Type.Kind() != reflect.Int { return nil, fmt.Errorf( "field in struct %s tagged with 'id' must be an int; got %v", table.name, field.Type.Kind(), ) } table.idFieldIndex = new(int) *table.idFieldIndex = i idFound = true _, table.manualId = tags["manual"] break } } if !idFound { return nil, fmt.Errorf("struct %s has no field tagged as the id", table.name) } // Create the Bolt bucket corresponding to the struct. err := table.bolt.Update(func(tx *bbolt.Tx) error { _, err := tx.CreateBucketIfNotExists(table.bucketKey) return err }) if err != nil { return nil, err } return &table, nil } // Returns the record with the given ID, or nil if it doesn't exist. func (table *table[R]) getById(id int) (*R, error) { record := new(R) err := table.bolt.View(func(tx *bbolt.Tx) error { bucket, err := table.getBucket(tx) if err != nil { return err } if recordJson := bucket.Get(idToKey(id)); recordJson != nil { return json.Unmarshal(recordJson, record) } // If the record does not exist, set the record pointer to nil. record = nil return nil }) return record, err } // Returns a slice containing every record in the table, ordered by string representation of ID. func (table *table[R]) getAll() ([]R, error) { records := []R{} err := table.bolt.View(func(tx *bbolt.Tx) error { bucket, err := table.getBucket(tx) if err != nil { return err } return bucket.ForEach(func(key, value []byte) error { var record R err := json.Unmarshal(value, &record) if err != nil { return err } records = append(records, record) return nil }) }) return records, err } // Persists the given record as a new row in the table. func (table *table[R]) create(record *R) error { // Validate that the record has its ID set to zero or not as expected, depending on whether it is configured for // autogenerated IDs. value := reflect.ValueOf(record).Elem() id := int(value.Field(*table.idFieldIndex).Int()) if table.manualId && id == 0 { return fmt.Errorf("can't create %s with zero ID since table is configured for manual IDs", table.name) } else if !table.manualId && id != 0 { return fmt.Errorf( "can't create %s with non-zero ID since table is configured for autogenerated IDs: %d", table.name, id, ) } return table.bolt.Update(func(tx *bbolt.Tx) error { bucket, err := table.getBucket(tx) if err != nil { return err } if !table.manualId { // Generate a new ID for the record. newSequence, err := bucket.NextSequence() if err != nil { return err } id = int(newSequence) value.Field(*table.idFieldIndex).SetInt(int64(id)) } // Ensure that a record having the same ID does not already exist in the table. key := idToKey(id) oldRecord := bucket.Get(key) if oldRecord != nil { return fmt.Errorf("%s with ID %d already exists: %s", table.name, id, string(oldRecord)) } recordJson, err := json.Marshal(record) if err != nil { return err } return bucket.Put(key, recordJson) }) } // Persists the given record as an update to the existing row in the table. Returns an error if the record does not // already exist. func (table *table[R]) update(record *R) error { // Validate that the record has a non-zero ID. value := reflect.ValueOf(record).Elem() id := int(value.Field(*table.idFieldIndex).Int()) if id == 0 { return fmt.Errorf("can't update %s with zero ID", table.name) } return table.bolt.Update(func(tx *bbolt.Tx) error { bucket, err := table.getBucket(tx) if err != nil { return err } // Ensure that a record having the same ID exists in the table. key := idToKey(id) oldRecord := bucket.Get(key) if oldRecord == nil { return fmt.Errorf("can't update non-existent %s with ID %d", table.name, id) } recordJson, err := json.Marshal(record) if err != nil { return err } return bucket.Put(key, recordJson) }) } // Deletes the record having the given ID from the table. Returns an error if the record does not exist. func (table *table[R]) delete(id int) error { return table.bolt.Update(func(tx *bbolt.Tx) error { bucket, err := table.getBucket(tx) if err != nil { return err } // Ensure that a record having the same ID exists in the table. key := idToKey(id) oldRecord := bucket.Get(key) if oldRecord == nil { return fmt.Errorf("can't delete non-existent %s with ID %d", table.name, id) } return bucket.Delete(key) }) } // Deletes all records from the table. func (table *table[R]) truncate() error { return table.bolt.Update(func(tx *bbolt.Tx) error { _, err := table.getBucket(tx) if err != nil { return err } // Carry out the truncation by way of deleting the whole bucket and then recreate it. err = tx.DeleteBucket(table.bucketKey) if err != nil { return err } _, err = tx.CreateBucket(table.bucketKey) return err }) } // Obtains the Bolt bucket belonging to the table. func (table *table[R]) getBucket(tx *bbolt.Tx) (*bbolt.Bucket, error) { bucket := tx.Bucket(table.bucketKey) if bucket == nil { return nil, fmt.Errorf("unknown table %s", table.name) } return bucket, nil } // Serializes the given integer ID to a byte array containing its Base-10 string representation. func idToKey(id int) []byte { return []byte(strconv.Itoa(id)) }