mirror of
https://github.com/Team254/cheesy-arena-lite.git
synced 2026-03-10 06:06:47 -04:00
Use Go generics to reduce complexity of DB code.
This commit is contained in:
11
go.mod
11
go.mod
@@ -1,11 +1,10 @@
|
||||
module github.com/Team254/cheesy-arena-lite
|
||||
|
||||
go 1.16
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/dchest/uniuri v0.0.0-20200228104902-7aecb25e1fe5
|
||||
github.com/goburrow/modbus v0.1.0
|
||||
github.com/goburrow/serial v0.1.0 // indirect
|
||||
github.com/google/uuid v1.2.0
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
@@ -15,3 +14,11 @@ require (
|
||||
go.etcd.io/bbolt v1.3.5
|
||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.0 // indirect
|
||||
github.com/goburrow/serial v0.1.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
|
||||
)
|
||||
|
||||
@@ -19,8 +19,8 @@ func (database *Database) CreateAllianceTeam(allianceTeam *AllianceTeam) error {
|
||||
}
|
||||
|
||||
func (database *Database) GetTeamsByAlliance(allianceId int) ([]AllianceTeam, error) {
|
||||
var allianceTeams []AllianceTeam
|
||||
if err := database.allianceTeamTable.getAll(&allianceTeams); err != nil {
|
||||
allianceTeams, err := database.allianceTeamTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Slice(allianceTeams, func(i, j int) bool {
|
||||
@@ -49,8 +49,8 @@ func (database *Database) TruncateAllianceTeams() error {
|
||||
}
|
||||
|
||||
func (database *Database) GetAllAlliances() ([][]AllianceTeam, error) {
|
||||
var allianceTeams []AllianceTeam
|
||||
if err := database.allianceTeamTable.getAll(&allianceTeams); err != nil {
|
||||
allianceTeams, err := database.allianceTeamTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Slice(allianceTeams, func(i, j int) bool {
|
||||
|
||||
@@ -28,9 +28,7 @@ func (database *Database) CreateAward(award *Award) error {
|
||||
}
|
||||
|
||||
func (database *Database) GetAwardById(id int) (*Award, error) {
|
||||
var award *Award
|
||||
err := database.awardTable.getById(id, &award)
|
||||
return award, err
|
||||
return database.awardTable.getById(id)
|
||||
}
|
||||
|
||||
func (database *Database) UpdateAward(award *Award) error {
|
||||
@@ -46,8 +44,8 @@ func (database *Database) TruncateAwards() error {
|
||||
}
|
||||
|
||||
func (database *Database) GetAllAwards() ([]Award, error) {
|
||||
var awards []Award
|
||||
if err := database.awardTable.getAll(&awards); err != nil {
|
||||
awards, err := database.awardTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Slice(awards, func(i, j int) bool {
|
||||
|
||||
@@ -23,17 +23,17 @@ var BaseDir = "." // Mutable for testing
|
||||
type Database struct {
|
||||
Path string
|
||||
bolt *bbolt.DB
|
||||
allianceTeamTable *table
|
||||
awardTable *table
|
||||
eventSettingsTable *table
|
||||
lowerThirdTable *table
|
||||
matchTable *table
|
||||
matchResultTable *table
|
||||
rankingTable *table
|
||||
scheduleBlockTable *table
|
||||
sponsorSlideTable *table
|
||||
teamTable *table
|
||||
userSessionTable *table
|
||||
allianceTeamTable *table[AllianceTeam]
|
||||
awardTable *table[Award]
|
||||
eventSettingsTable *table[EventSettings]
|
||||
lowerThirdTable *table[LowerThird]
|
||||
matchTable *table[Match]
|
||||
matchResultTable *table[MatchResult]
|
||||
rankingTable *table[game.Ranking]
|
||||
scheduleBlockTable *table[ScheduleBlock]
|
||||
sponsorSlideTable *table[SponsorSlide]
|
||||
teamTable *table[Team]
|
||||
userSessionTable *table[UserSession]
|
||||
}
|
||||
|
||||
// Opens the Bolt database at the given path, creating it if it doesn't exist.
|
||||
@@ -46,37 +46,37 @@ func OpenDatabase(filename string) (*Database, error) {
|
||||
}
|
||||
|
||||
// Register tables.
|
||||
if database.allianceTeamTable, err = database.newTable(AllianceTeam{}); err != nil {
|
||||
if database.allianceTeamTable, err = newTable[AllianceTeam](&database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if database.awardTable, err = database.newTable(Award{}); err != nil {
|
||||
if database.awardTable, err = newTable[Award](&database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if database.eventSettingsTable, err = database.newTable(EventSettings{}); err != nil {
|
||||
if database.eventSettingsTable, err = newTable[EventSettings](&database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if database.lowerThirdTable, err = database.newTable(LowerThird{}); err != nil {
|
||||
if database.lowerThirdTable, err = newTable[LowerThird](&database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if database.matchTable, err = database.newTable(Match{}); err != nil {
|
||||
if database.matchTable, err = newTable[Match](&database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if database.matchResultTable, err = database.newTable(MatchResult{}); err != nil {
|
||||
if database.matchResultTable, err = newTable[MatchResult](&database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if database.rankingTable, err = database.newTable(game.Ranking{}); err != nil {
|
||||
if database.rankingTable, err = newTable[game.Ranking](&database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if database.scheduleBlockTable, err = database.newTable(ScheduleBlock{}); err != nil {
|
||||
if database.scheduleBlockTable, err = newTable[ScheduleBlock](&database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if database.sponsorSlideTable, err = database.newTable(SponsorSlide{}); err != nil {
|
||||
if database.sponsorSlideTable, err = newTable[SponsorSlide](&database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if database.teamTable, err = database.newTable(Team{}); err != nil {
|
||||
if database.teamTable, err = newTable[Team](&database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if database.userSessionTable, err = database.newTable(UserSession{}); err != nil {
|
||||
if database.userSessionTable, err = newTable[UserSession](&database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -41,8 +41,8 @@ type EventSettings struct {
|
||||
}
|
||||
|
||||
func (database *Database) GetEventSettings() (*EventSettings, error) {
|
||||
var allEventSettings []EventSettings
|
||||
if err := database.eventSettingsTable.getAll(&allEventSettings); err != nil {
|
||||
allEventSettings, err := database.eventSettingsTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(allEventSettings) == 1 {
|
||||
|
||||
@@ -22,9 +22,7 @@ func (database *Database) CreateLowerThird(lowerThird *LowerThird) error {
|
||||
}
|
||||
|
||||
func (database *Database) GetLowerThirdById(id int) (*LowerThird, error) {
|
||||
var lowerThird *LowerThird
|
||||
err := database.lowerThirdTable.getById(id, &lowerThird)
|
||||
return lowerThird, err
|
||||
return database.lowerThirdTable.getById(id)
|
||||
}
|
||||
|
||||
func (database *Database) UpdateLowerThird(lowerThird *LowerThird) error {
|
||||
@@ -40,8 +38,8 @@ func (database *Database) TruncateLowerThirds() error {
|
||||
}
|
||||
|
||||
func (database *Database) GetAllLowerThirds() ([]LowerThird, error) {
|
||||
var lowerThirds []LowerThird
|
||||
if err := database.lowerThirdTable.getAll(&lowerThirds); err != nil {
|
||||
lowerThirds, err := database.lowerThirdTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Slice(lowerThirds, func(i, j int) bool {
|
||||
|
||||
@@ -55,9 +55,7 @@ func (database *Database) CreateMatch(match *Match) error {
|
||||
}
|
||||
|
||||
func (database *Database) GetMatchById(id int) (*Match, error) {
|
||||
var match *Match
|
||||
err := database.matchTable.getById(id, &match)
|
||||
return match, err
|
||||
return database.matchTable.getById(id)
|
||||
}
|
||||
|
||||
func (database *Database) UpdateMatch(match *Match) error {
|
||||
@@ -73,8 +71,8 @@ func (database *Database) TruncateMatches() error {
|
||||
}
|
||||
|
||||
func (database *Database) GetMatchByName(matchType string, displayName string) (*Match, error) {
|
||||
var matches []Match
|
||||
if err := database.matchTable.getAll(&matches); err != nil {
|
||||
matches, err := database.matchTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -102,8 +100,8 @@ func (database *Database) GetMatchesByElimRoundGroup(round int, group int) ([]Ma
|
||||
}
|
||||
|
||||
func (database *Database) GetMatchesByType(matchType string) ([]Match, error) {
|
||||
var matches []Match
|
||||
if err := database.matchTable.getAll(&matches); err != nil {
|
||||
matches, err := database.matchTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -31,8 +31,8 @@ func (database *Database) CreateMatchResult(matchResult *MatchResult) error {
|
||||
}
|
||||
|
||||
func (database *Database) GetMatchResultForMatch(matchId int) (*MatchResult, error) {
|
||||
var matchResults []MatchResult
|
||||
if err := database.matchResultTable.getAll(&matchResults); err != nil {
|
||||
matchResults, err := database.matchResultTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -15,9 +15,7 @@ func (database *Database) CreateRanking(ranking *game.Ranking) error {
|
||||
}
|
||||
|
||||
func (database *Database) GetRankingForTeam(teamId int) (*game.Ranking, error) {
|
||||
var ranking *game.Ranking
|
||||
err := database.rankingTable.getById(teamId, &ranking)
|
||||
return ranking, err
|
||||
return database.rankingTable.getById(teamId)
|
||||
}
|
||||
|
||||
func (database *Database) UpdateRanking(ranking *game.Ranking) error {
|
||||
@@ -33,8 +31,8 @@ func (database *Database) TruncateRankings() error {
|
||||
}
|
||||
|
||||
func (database *Database) GetAllRankings() (game.Rankings, error) {
|
||||
var rankings []game.Ranking
|
||||
if err := database.rankingTable.getAll(&rankings); err != nil {
|
||||
rankings, err := database.rankingTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Slice(rankings, func(i, j int) bool {
|
||||
|
||||
@@ -23,8 +23,8 @@ func (database *Database) CreateScheduleBlock(block *ScheduleBlock) error {
|
||||
}
|
||||
|
||||
func (database *Database) GetScheduleBlocksByMatchType(matchType string) ([]ScheduleBlock, error) {
|
||||
var scheduleBlocks []ScheduleBlock
|
||||
if err := database.scheduleBlockTable.getAll(&scheduleBlocks); err != nil {
|
||||
scheduleBlocks, err := database.scheduleBlockTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -22,9 +22,7 @@ func (database *Database) CreateSponsorSlide(sponsorSlide *SponsorSlide) error {
|
||||
}
|
||||
|
||||
func (database *Database) GetSponsorSlideById(id int) (*SponsorSlide, error) {
|
||||
var sponsorSlide *SponsorSlide
|
||||
err := database.sponsorSlideTable.getById(id, &sponsorSlide)
|
||||
return sponsorSlide, err
|
||||
return database.sponsorSlideTable.getById(id)
|
||||
}
|
||||
|
||||
func (database *Database) UpdateSponsorSlide(sponsorSlide *SponsorSlide) error {
|
||||
@@ -40,8 +38,8 @@ func (database *Database) TruncateSponsorSlides() error {
|
||||
}
|
||||
|
||||
func (database *Database) GetAllSponsorSlides() ([]SponsorSlide, error) {
|
||||
var sponsorSlides []SponsorSlide
|
||||
if err := database.sponsorSlideTable.getAll(&sponsorSlides); err != nil {
|
||||
sponsorSlides, err := database.sponsorSlideTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Slice(sponsorSlides, func(i, j int) bool {
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
// Encapsulates all persistence operations for a particular data type represented by a struct.
|
||||
type table struct {
|
||||
type table[R any] struct {
|
||||
bolt *bbolt.DB
|
||||
recordType reflect.Type
|
||||
name string
|
||||
@@ -24,14 +24,15 @@ type table struct {
|
||||
manualId bool
|
||||
}
|
||||
|
||||
// Registers a new table for a struct, given its zero value.
|
||||
func (database *Database) newTable(recordType interface{}) (*table, error) {
|
||||
// Registers a new table for a struct.
|
||||
func newTable[R any](database *Database) (*table[R], error) {
|
||||
var recordType R
|
||||
recordTypeValue := reflect.ValueOf(recordType)
|
||||
if recordTypeValue.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("record type must be a struct; got %v", recordTypeValue.Kind())
|
||||
}
|
||||
|
||||
var table table
|
||||
var table table[R]
|
||||
table.bolt = database.bolt
|
||||
table.recordType = reflect.TypeOf(recordType)
|
||||
table.name = table.recordType.Name()
|
||||
@@ -75,14 +76,10 @@ func (database *Database) newTable(recordType interface{}) (*table, error) {
|
||||
return &table, nil
|
||||
}
|
||||
|
||||
// Populates the given double pointer to a record with the data from the record with the given ID, or nil if it doesn't
|
||||
// exist.
|
||||
func (table *table) getById(id int, record interface{}) error {
|
||||
if err := table.validateType(record, reflect.Ptr, reflect.Ptr, reflect.Struct); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return table.bolt.View(func(tx *bbolt.Tx) error {
|
||||
// Returns the record with the given ID, or nil if it doesn't exist.
|
||||
func (table *table[R]) getById(id int) (*R, error) {
|
||||
record := new(R)
|
||||
err := table.bolt.View(func(tx *bbolt.Tx) error {
|
||||
bucket, err := table.getBucket(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -93,45 +90,36 @@ func (table *table) getById(id int, record interface{}) error {
|
||||
}
|
||||
|
||||
// If the record does not exist, set the record pointer to nil.
|
||||
recordPointerValue := reflect.ValueOf(record).Elem()
|
||||
recordPointerValue.Set(reflect.Zero(recordPointerValue.Type()))
|
||||
|
||||
record = nil
|
||||
return nil
|
||||
})
|
||||
return record, err
|
||||
}
|
||||
|
||||
// Populates the given slice passed by pointer with the data from every record in the table, ordered by ID.
|
||||
func (table *table) getAll(recordSlice interface{}) error {
|
||||
if err := table.validateType(recordSlice, reflect.Ptr, reflect.Slice, reflect.Struct); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return table.bolt.View(func(tx *bbolt.Tx) error {
|
||||
// Returns a slice containing every record in the table, ordered by string representation of ID.
|
||||
func (table *table[R]) getAll() ([]R, error) {
|
||||
records := []R{}
|
||||
err := table.bolt.View(func(tx *bbolt.Tx) error {
|
||||
bucket, err := table.getBucket(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordSliceValue := reflect.ValueOf(recordSlice).Elem()
|
||||
recordSliceValue.Set(reflect.MakeSlice(recordSliceValue.Type(), 0, 0))
|
||||
return bucket.ForEach(func(key, value []byte) error {
|
||||
record := reflect.New(table.recordType)
|
||||
err := json.Unmarshal(value, record.Interface())
|
||||
var record R
|
||||
err := json.Unmarshal(value, &record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
recordSliceValue.Set(reflect.Append(recordSliceValue, record.Elem()))
|
||||
records = append(records, record)
|
||||
return nil
|
||||
})
|
||||
})
|
||||
return records, err
|
||||
}
|
||||
|
||||
// Persists the given record as a new row in the table.
|
||||
func (table *table) create(record interface{}) error {
|
||||
if err := table.validateType(record, reflect.Ptr, reflect.Struct); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (table *table[R]) create(record *R) error {
|
||||
// Validate that the record has its ID set to zero or not as expected, depending on whether it is configured for
|
||||
// autogenerated IDs.
|
||||
value := reflect.ValueOf(record).Elem()
|
||||
@@ -177,11 +165,7 @@ func (table *table) create(record interface{}) error {
|
||||
|
||||
// Persists the given record as an update to the existing row in the table. Returns an error if the record does not
|
||||
// already exist.
|
||||
func (table *table) update(record interface{}) error {
|
||||
if err := table.validateType(record, reflect.Ptr, reflect.Struct); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (table *table[R]) update(record *R) error {
|
||||
// Validate that the record has a non-zero ID.
|
||||
value := reflect.ValueOf(record).Elem()
|
||||
id := int(value.Field(*table.idFieldIndex).Int())
|
||||
@@ -211,7 +195,7 @@ func (table *table) update(record interface{}) error {
|
||||
}
|
||||
|
||||
// Deletes the record having the given ID from the table. Returns an error if the record does not exist.
|
||||
func (table *table) delete(id int) error {
|
||||
func (table *table[R]) delete(id int) error {
|
||||
return table.bolt.Update(func(tx *bbolt.Tx) error {
|
||||
bucket, err := table.getBucket(tx)
|
||||
if err != nil {
|
||||
@@ -230,7 +214,7 @@ func (table *table) delete(id int) error {
|
||||
}
|
||||
|
||||
// Deletes all records from the table.
|
||||
func (table *table) truncate() error {
|
||||
func (table *table[R]) truncate() error {
|
||||
return table.bolt.Update(func(tx *bbolt.Tx) error {
|
||||
_, err := table.getBucket(tx)
|
||||
if err != nil {
|
||||
@@ -248,7 +232,7 @@ func (table *table) truncate() error {
|
||||
}
|
||||
|
||||
// Obtains the Bolt bucket belonging to the table.
|
||||
func (table *table) getBucket(tx *bbolt.Tx) (*bbolt.Bucket, error) {
|
||||
func (table *table[R]) getBucket(tx *bbolt.Tx) (*bbolt.Bucket, error) {
|
||||
bucket := tx.Bucket(table.bucketKey)
|
||||
if bucket == nil {
|
||||
return nil, fmt.Errorf("unknown table %s", table.name)
|
||||
@@ -256,39 +240,6 @@ func (table *table) getBucket(tx *bbolt.Tx) (*bbolt.Bucket, error) {
|
||||
return bucket, nil
|
||||
}
|
||||
|
||||
// Validates that the given record is of the expected derived type (e.g. pointer, slice, etc.), that the base type is
|
||||
// the same as that stored in the table, and that the table is configured correctly.
|
||||
func (table *table) validateType(record interface{}, kinds ...reflect.Kind) error {
|
||||
// Check the hierarchy of kinds against the expected list until reaching the base record type.
|
||||
recordType := reflect.ValueOf(record).Type()
|
||||
expectedKind := ""
|
||||
actualKind := ""
|
||||
for i, kind := range kinds {
|
||||
if i > 0 {
|
||||
expectedKind += " -> "
|
||||
actualKind += " -> "
|
||||
}
|
||||
expectedKind += kind.String()
|
||||
actualKind += recordType.Kind().String()
|
||||
if recordType.Kind() != kind {
|
||||
return fmt.Errorf("input must be a %s; got a %s", expectedKind, actualKind)
|
||||
}
|
||||
if i < len(kinds)-1 {
|
||||
recordType = recordType.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
if recordType != table.recordType {
|
||||
return fmt.Errorf("given record of type %s does not match expected type for table %s", recordType, table.name)
|
||||
}
|
||||
|
||||
if table.idFieldIndex == nil {
|
||||
return fmt.Errorf("struct %s has no field tagged as the id", table.name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serializes the given integer ID to a byte array containing its Base-10 string representation.
|
||||
func idToKey(id int) []byte {
|
||||
return []byte(strconv.Itoa(id))
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestTableSingleCrud(t *testing.T) {
|
||||
db := setupTestDb(t)
|
||||
defer db.Close()
|
||||
|
||||
table, err := db.newTable(validRecord{})
|
||||
table, err := newTable[validRecord](db)
|
||||
if !assert.Nil(t, err) {
|
||||
return
|
||||
}
|
||||
@@ -33,31 +33,30 @@ func TestTableSingleCrud(t *testing.T) {
|
||||
if assert.Nil(t, table.create(&record)) {
|
||||
assert.Equal(t, 1, record.Id)
|
||||
}
|
||||
var record2 *validRecord
|
||||
if assert.Nil(t, table.getById(record.Id, &record2)) {
|
||||
assert.Equal(t, record, *record2)
|
||||
}
|
||||
record2, err := table.getById(record.Id)
|
||||
assert.Equal(t, record, *record2)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Test update and then read back.
|
||||
record.IntData = 252
|
||||
record.StringData = "Teh Chezy Pofs"
|
||||
assert.Nil(t, table.update(&record))
|
||||
if assert.Nil(t, table.getById(record.Id, &record2)) {
|
||||
assert.Equal(t, record, *record2)
|
||||
}
|
||||
record2, err = table.getById(record.Id)
|
||||
assert.Equal(t, record, *record2)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Test delete.
|
||||
assert.Nil(t, table.delete(record.Id))
|
||||
if assert.Nil(t, table.getById(record.Id, &record2)) {
|
||||
assert.Nil(t, record2)
|
||||
}
|
||||
record2, err = table.getById(record.Id)
|
||||
assert.Nil(t, record2)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestTableMultipleCrud(t *testing.T) {
|
||||
db := setupTestDb(t)
|
||||
defer db.Close()
|
||||
|
||||
table, err := db.newTable(validRecord{})
|
||||
table, err := newTable[validRecord](db)
|
||||
if !assert.Nil(t, err) {
|
||||
return
|
||||
}
|
||||
@@ -71,8 +70,8 @@ func TestTableMultipleCrud(t *testing.T) {
|
||||
assert.Nil(t, table.create(&record3))
|
||||
|
||||
// Read all records.
|
||||
var records []validRecord
|
||||
assert.Nil(t, table.getAll(&records))
|
||||
records, err := table.getAll()
|
||||
assert.Nil(t, err)
|
||||
if assert.Equal(t, 3, len(records)) {
|
||||
assert.Equal(t, record1, records[0])
|
||||
assert.Equal(t, record2, records[1])
|
||||
@@ -81,19 +80,19 @@ func TestTableMultipleCrud(t *testing.T) {
|
||||
|
||||
// Truncate the table and verify that the records no longer exist.
|
||||
assert.Nil(t, table.truncate())
|
||||
assert.Nil(t, table.getAll(&records))
|
||||
records, err = table.getAll()
|
||||
assert.Equal(t, 0, len(records))
|
||||
var record4 *validRecord
|
||||
if assert.Nil(t, table.getById(record1.Id, &record4)) {
|
||||
assert.Nil(t, record4)
|
||||
}
|
||||
assert.Nil(t, err)
|
||||
record4, err := table.getById(record1.Id)
|
||||
assert.Nil(t, record4)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestTableWithManualId(t *testing.T) {
|
||||
db := setupTestDb(t)
|
||||
defer db.Close()
|
||||
|
||||
table, err := db.newTable(manualIdRecord{})
|
||||
table, err := newTable[manualIdRecord](db)
|
||||
if !assert.Nil(t, err) {
|
||||
return
|
||||
}
|
||||
@@ -103,23 +102,22 @@ func TestTableWithManualId(t *testing.T) {
|
||||
if assert.Nil(t, table.create(&record)) {
|
||||
assert.Equal(t, 254, record.Id)
|
||||
}
|
||||
var record2 *manualIdRecord
|
||||
if assert.Nil(t, table.getById(record.Id, &record2)) {
|
||||
assert.Equal(t, record, *record2)
|
||||
}
|
||||
record2, err := table.getById(record.Id)
|
||||
assert.Equal(t, record, *record2)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Test update and then read back.
|
||||
record.StringData = "Teh Chezy Pofs"
|
||||
assert.Nil(t, table.update(&record))
|
||||
if assert.Nil(t, table.getById(record.Id, &record2)) {
|
||||
assert.Equal(t, record, *record2)
|
||||
}
|
||||
record2, err = table.getById(record.Id)
|
||||
assert.Equal(t, record, *record2)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Test delete.
|
||||
assert.Nil(t, table.delete(record.Id))
|
||||
if assert.Nil(t, table.getById(record.Id, &record2)) {
|
||||
assert.Nil(t, record2)
|
||||
}
|
||||
record2, err = table.getById(record.Id)
|
||||
assert.Nil(t, record2)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Test creating a record with a zero ID.
|
||||
record.Id = 0
|
||||
@@ -136,7 +134,7 @@ func TestNewTableErrors(t *testing.T) {
|
||||
defer db.Close()
|
||||
|
||||
// Pass a non-struct as the record type.
|
||||
table, err := db.newTable(123)
|
||||
table, err := newTable[int](db)
|
||||
assert.Nil(t, table)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "record type must be a struct; got int", err.Error())
|
||||
@@ -146,8 +144,8 @@ func TestNewTableErrors(t *testing.T) {
|
||||
type recordWithNoId struct {
|
||||
StringData string
|
||||
}
|
||||
table, err = db.newTable(recordWithNoId{})
|
||||
assert.Nil(t, table)
|
||||
table2, err := newTable[recordWithNoId](db)
|
||||
assert.Nil(t, table2)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "struct recordWithNoId has no field tagged as the id", err.Error())
|
||||
}
|
||||
@@ -156,8 +154,8 @@ func TestNewTableErrors(t *testing.T) {
|
||||
type recordWithWrongIdType struct {
|
||||
Id bool `db:"id"`
|
||||
}
|
||||
table, err = db.newTable(recordWithWrongIdType{})
|
||||
assert.Nil(t, table)
|
||||
table3, err := newTable[recordWithWrongIdType](db)
|
||||
assert.Nil(t, table3)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(
|
||||
t, "field in struct recordWithWrongIdType tagged with 'id' must be an int; got bool", err.Error(),
|
||||
@@ -169,57 +167,13 @@ func TestTableCrudErrors(t *testing.T) {
|
||||
db := setupTestDb(t)
|
||||
defer db.Close()
|
||||
|
||||
table, err := db.newTable(validRecord{})
|
||||
table, err := newTable[validRecord](db)
|
||||
if !assert.Nil(t, err) {
|
||||
return
|
||||
}
|
||||
type differentRecord struct {
|
||||
StringData string
|
||||
}
|
||||
|
||||
// Pass an object of the wrong type when getting a single record.
|
||||
var record validRecord
|
||||
err = table.getById(record.Id, record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "input must be a ptr; got a struct", err.Error())
|
||||
}
|
||||
err = table.getById(record.Id, &record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "input must be a ptr -> ptr; got a ptr -> struct", err.Error())
|
||||
}
|
||||
var recordTriplePointer ***validRecord
|
||||
err = table.getById(record.Id, recordTriplePointer)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "input must be a ptr -> ptr -> struct; got a ptr -> ptr -> ptr", err.Error())
|
||||
}
|
||||
var differentRecordPointer *differentRecord
|
||||
err = table.getById(record.Id, &differentRecordPointer)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(
|
||||
t,
|
||||
"given record of type model.differentRecord does not match expected type for table validRecord",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
// Pass an object of the wrong type when getting all records.
|
||||
var records []validRecord
|
||||
err = table.getAll(records)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "input must be a ptr; got a slice", err.Error())
|
||||
}
|
||||
|
||||
// Pass an object of the wrong type when creating or updating a record.
|
||||
err = table.create(record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "input must be a ptr; got a struct", err.Error())
|
||||
}
|
||||
err = table.update(record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "input must be a ptr; got a struct", err.Error())
|
||||
}
|
||||
|
||||
// Create a record with a non-zero ID.
|
||||
var record validRecord
|
||||
record.Id = 12345
|
||||
err = table.create(&record)
|
||||
if assert.NotNil(t, err) {
|
||||
@@ -249,11 +203,4 @@ func TestTableCrudErrors(t *testing.T) {
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "can't delete non-existent validRecord with ID 12345", err.Error())
|
||||
}
|
||||
|
||||
// Update a record with an incorrectly constructed table object.
|
||||
table.idFieldIndex = nil
|
||||
err = table.update(&record)
|
||||
if assert.NotNil(t, err) {
|
||||
assert.Equal(t, "struct validRecord has no field tagged as the id", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,9 +27,7 @@ func (database *Database) CreateTeam(team *Team) error {
|
||||
}
|
||||
|
||||
func (database *Database) GetTeamById(id int) (*Team, error) {
|
||||
var team *Team
|
||||
err := database.teamTable.getById(id, &team)
|
||||
return team, err
|
||||
return database.teamTable.getById(id)
|
||||
}
|
||||
|
||||
func (database *Database) UpdateTeam(team *Team) error {
|
||||
@@ -45,8 +43,8 @@ func (database *Database) TruncateTeams() error {
|
||||
}
|
||||
|
||||
func (database *Database) GetAllTeams() ([]Team, error) {
|
||||
var teams []Team
|
||||
if err := database.teamTable.getAll(&teams); err != nil {
|
||||
teams, err := database.teamTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Slice(teams, func(i, j int) bool {
|
||||
|
||||
@@ -19,8 +19,8 @@ func (database *Database) CreateUserSession(session *UserSession) error {
|
||||
}
|
||||
|
||||
func (database *Database) GetUserSessionByToken(token string) (*UserSession, error) {
|
||||
var userSessions []UserSession
|
||||
if err := database.userSessionTable.getAll(&userSessions); err != nil {
|
||||
userSessions, err := database.userSessionTable.getAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user