diff --git a/go.mod b/go.mod index d523d6f..fe377fd 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,10 @@ module github.com/Team254/cheesy-arena-lite -go 1.16 +go 1.18 require ( github.com/dchest/uniuri v0.0.0-20200228104902-7aecb25e1fe5 github.com/goburrow/modbus v0.1.0 - github.com/goburrow/serial v0.1.0 // indirect github.com/google/uuid v1.2.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 @@ -15,3 +14,11 @@ require ( go.etcd.io/bbolt v1.3.5 golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 ) + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/goburrow/serial v0.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) diff --git a/model/alliance_team.go b/model/alliance_team.go index d10bf4e..64379b2 100644 --- a/model/alliance_team.go +++ b/model/alliance_team.go @@ -19,8 +19,8 @@ func (database *Database) CreateAllianceTeam(allianceTeam *AllianceTeam) error { } func (database *Database) GetTeamsByAlliance(allianceId int) ([]AllianceTeam, error) { - var allianceTeams []AllianceTeam - if err := database.allianceTeamTable.getAll(&allianceTeams); err != nil { + allianceTeams, err := database.allianceTeamTable.getAll() + if err != nil { return nil, err } sort.Slice(allianceTeams, func(i, j int) bool { @@ -49,8 +49,8 @@ func (database *Database) TruncateAllianceTeams() error { } func (database *Database) GetAllAlliances() ([][]AllianceTeam, error) { - var allianceTeams []AllianceTeam - if err := database.allianceTeamTable.getAll(&allianceTeams); err != nil { + allianceTeams, err := database.allianceTeamTable.getAll() + if err != nil { return nil, err } sort.Slice(allianceTeams, func(i, j int) bool { diff --git a/model/award.go b/model/award.go index fd31a2a..b404121 100644 --- a/model/award.go +++ b/model/award.go @@ -28,9 +28,7 @@ func (database *Database) CreateAward(award *Award) error { } func (database *Database) GetAwardById(id int) (*Award, error) { - var award *Award - err := database.awardTable.getById(id, &award) - return award, err + return database.awardTable.getById(id) } func (database *Database) UpdateAward(award *Award) error { @@ -46,8 +44,8 @@ func (database *Database) TruncateAwards() error { } func (database *Database) GetAllAwards() ([]Award, error) { - var awards []Award - if err := database.awardTable.getAll(&awards); err != nil { + awards, err := database.awardTable.getAll() + if err != nil { return nil, err } sort.Slice(awards, func(i, j int) bool { diff --git a/model/database.go b/model/database.go index b813ab6..38818d9 100644 --- a/model/database.go +++ b/model/database.go @@ -23,17 +23,17 @@ var BaseDir = "." // Mutable for testing type Database struct { Path string bolt *bbolt.DB - allianceTeamTable *table - awardTable *table - eventSettingsTable *table - lowerThirdTable *table - matchTable *table - matchResultTable *table - rankingTable *table - scheduleBlockTable *table - sponsorSlideTable *table - teamTable *table - userSessionTable *table + allianceTeamTable *table[AllianceTeam] + awardTable *table[Award] + eventSettingsTable *table[EventSettings] + lowerThirdTable *table[LowerThird] + matchTable *table[Match] + matchResultTable *table[MatchResult] + rankingTable *table[game.Ranking] + scheduleBlockTable *table[ScheduleBlock] + sponsorSlideTable *table[SponsorSlide] + teamTable *table[Team] + userSessionTable *table[UserSession] } // Opens the Bolt database at the given path, creating it if it doesn't exist. @@ -46,37 +46,37 @@ func OpenDatabase(filename string) (*Database, error) { } // Register tables. - if database.allianceTeamTable, err = database.newTable(AllianceTeam{}); err != nil { + if database.allianceTeamTable, err = newTable[AllianceTeam](&database); err != nil { return nil, err } - if database.awardTable, err = database.newTable(Award{}); err != nil { + if database.awardTable, err = newTable[Award](&database); err != nil { return nil, err } - if database.eventSettingsTable, err = database.newTable(EventSettings{}); err != nil { + if database.eventSettingsTable, err = newTable[EventSettings](&database); err != nil { return nil, err } - if database.lowerThirdTable, err = database.newTable(LowerThird{}); err != nil { + if database.lowerThirdTable, err = newTable[LowerThird](&database); err != nil { return nil, err } - if database.matchTable, err = database.newTable(Match{}); err != nil { + if database.matchTable, err = newTable[Match](&database); err != nil { return nil, err } - if database.matchResultTable, err = database.newTable(MatchResult{}); err != nil { + if database.matchResultTable, err = newTable[MatchResult](&database); err != nil { return nil, err } - if database.rankingTable, err = database.newTable(game.Ranking{}); err != nil { + if database.rankingTable, err = newTable[game.Ranking](&database); err != nil { return nil, err } - if database.scheduleBlockTable, err = database.newTable(ScheduleBlock{}); err != nil { + if database.scheduleBlockTable, err = newTable[ScheduleBlock](&database); err != nil { return nil, err } - if database.sponsorSlideTable, err = database.newTable(SponsorSlide{}); err != nil { + if database.sponsorSlideTable, err = newTable[SponsorSlide](&database); err != nil { return nil, err } - if database.teamTable, err = database.newTable(Team{}); err != nil { + if database.teamTable, err = newTable[Team](&database); err != nil { return nil, err } - if database.userSessionTable, err = database.newTable(UserSession{}); err != nil { + if database.userSessionTable, err = newTable[UserSession](&database); err != nil { return nil, err } diff --git a/model/event_settings.go b/model/event_settings.go index 62e40de..fd82514 100755 --- a/model/event_settings.go +++ b/model/event_settings.go @@ -41,8 +41,8 @@ type EventSettings struct { } func (database *Database) GetEventSettings() (*EventSettings, error) { - var allEventSettings []EventSettings - if err := database.eventSettingsTable.getAll(&allEventSettings); err != nil { + allEventSettings, err := database.eventSettingsTable.getAll() + if err != nil { return nil, err } if len(allEventSettings) == 1 { diff --git a/model/lower_third.go b/model/lower_third.go index 0cb345e..295ca12 100644 --- a/model/lower_third.go +++ b/model/lower_third.go @@ -22,9 +22,7 @@ func (database *Database) CreateLowerThird(lowerThird *LowerThird) error { } func (database *Database) GetLowerThirdById(id int) (*LowerThird, error) { - var lowerThird *LowerThird - err := database.lowerThirdTable.getById(id, &lowerThird) - return lowerThird, err + return database.lowerThirdTable.getById(id) } func (database *Database) UpdateLowerThird(lowerThird *LowerThird) error { @@ -40,8 +38,8 @@ func (database *Database) TruncateLowerThirds() error { } func (database *Database) GetAllLowerThirds() ([]LowerThird, error) { - var lowerThirds []LowerThird - if err := database.lowerThirdTable.getAll(&lowerThirds); err != nil { + lowerThirds, err := database.lowerThirdTable.getAll() + if err != nil { return nil, err } sort.Slice(lowerThirds, func(i, j int) bool { diff --git a/model/match.go b/model/match.go index 2efdb20..3350506 100644 --- a/model/match.go +++ b/model/match.go @@ -55,9 +55,7 @@ func (database *Database) CreateMatch(match *Match) error { } func (database *Database) GetMatchById(id int) (*Match, error) { - var match *Match - err := database.matchTable.getById(id, &match) - return match, err + return database.matchTable.getById(id) } func (database *Database) UpdateMatch(match *Match) error { @@ -73,8 +71,8 @@ func (database *Database) TruncateMatches() error { } func (database *Database) GetMatchByName(matchType string, displayName string) (*Match, error) { - var matches []Match - if err := database.matchTable.getAll(&matches); err != nil { + matches, err := database.matchTable.getAll() + if err != nil { return nil, err } @@ -102,8 +100,8 @@ func (database *Database) GetMatchesByElimRoundGroup(round int, group int) ([]Ma } func (database *Database) GetMatchesByType(matchType string) ([]Match, error) { - var matches []Match - if err := database.matchTable.getAll(&matches); err != nil { + matches, err := database.matchTable.getAll() + if err != nil { return nil, err } diff --git a/model/match_result.go b/model/match_result.go index c29713b..274f3a9 100755 --- a/model/match_result.go +++ b/model/match_result.go @@ -31,8 +31,8 @@ func (database *Database) CreateMatchResult(matchResult *MatchResult) error { } func (database *Database) GetMatchResultForMatch(matchId int) (*MatchResult, error) { - var matchResults []MatchResult - if err := database.matchResultTable.getAll(&matchResults); err != nil { + matchResults, err := database.matchResultTable.getAll() + if err != nil { return nil, err } diff --git a/model/ranking.go b/model/ranking.go index e88c55b..14146b3 100644 --- a/model/ranking.go +++ b/model/ranking.go @@ -15,9 +15,7 @@ func (database *Database) CreateRanking(ranking *game.Ranking) error { } func (database *Database) GetRankingForTeam(teamId int) (*game.Ranking, error) { - var ranking *game.Ranking - err := database.rankingTable.getById(teamId, &ranking) - return ranking, err + return database.rankingTable.getById(teamId) } func (database *Database) UpdateRanking(ranking *game.Ranking) error { @@ -33,8 +31,8 @@ func (database *Database) TruncateRankings() error { } func (database *Database) GetAllRankings() (game.Rankings, error) { - var rankings []game.Ranking - if err := database.rankingTable.getAll(&rankings); err != nil { + rankings, err := database.rankingTable.getAll() + if err != nil { return nil, err } sort.Slice(rankings, func(i, j int) bool { diff --git a/model/schedule_block.go b/model/schedule_block.go index 6277e67..70cef2b 100644 --- a/model/schedule_block.go +++ b/model/schedule_block.go @@ -23,8 +23,8 @@ func (database *Database) CreateScheduleBlock(block *ScheduleBlock) error { } func (database *Database) GetScheduleBlocksByMatchType(matchType string) ([]ScheduleBlock, error) { - var scheduleBlocks []ScheduleBlock - if err := database.scheduleBlockTable.getAll(&scheduleBlocks); err != nil { + scheduleBlocks, err := database.scheduleBlockTable.getAll() + if err != nil { return nil, err } diff --git a/model/sponsor_slide.go b/model/sponsor_slide.go index 92b58ba..9d7c7d3 100644 --- a/model/sponsor_slide.go +++ b/model/sponsor_slide.go @@ -22,9 +22,7 @@ func (database *Database) CreateSponsorSlide(sponsorSlide *SponsorSlide) error { } func (database *Database) GetSponsorSlideById(id int) (*SponsorSlide, error) { - var sponsorSlide *SponsorSlide - err := database.sponsorSlideTable.getById(id, &sponsorSlide) - return sponsorSlide, err + return database.sponsorSlideTable.getById(id) } func (database *Database) UpdateSponsorSlide(sponsorSlide *SponsorSlide) error { @@ -40,8 +38,8 @@ func (database *Database) TruncateSponsorSlides() error { } func (database *Database) GetAllSponsorSlides() ([]SponsorSlide, error) { - var sponsorSlides []SponsorSlide - if err := database.sponsorSlideTable.getAll(&sponsorSlides); err != nil { + sponsorSlides, err := database.sponsorSlideTable.getAll() + if err != nil { return nil, err } sort.Slice(sponsorSlides, func(i, j int) bool { diff --git a/model/table.go b/model/table.go index 6753a5d..e9dde6a 100644 --- a/model/table.go +++ b/model/table.go @@ -15,7 +15,7 @@ import ( ) // Encapsulates all persistence operations for a particular data type represented by a struct. -type table struct { +type table[R any] struct { bolt *bbolt.DB recordType reflect.Type name string @@ -24,14 +24,15 @@ type table struct { manualId bool } -// Registers a new table for a struct, given its zero value. -func (database *Database) newTable(recordType interface{}) (*table, error) { +// Registers a new table for a struct. +func newTable[R any](database *Database) (*table[R], error) { + var recordType R recordTypeValue := reflect.ValueOf(recordType) if recordTypeValue.Kind() != reflect.Struct { return nil, fmt.Errorf("record type must be a struct; got %v", recordTypeValue.Kind()) } - var table table + var table table[R] table.bolt = database.bolt table.recordType = reflect.TypeOf(recordType) table.name = table.recordType.Name() @@ -75,14 +76,10 @@ func (database *Database) newTable(recordType interface{}) (*table, error) { return &table, nil } -// Populates the given double pointer to a record with the data from the record with the given ID, or nil if it doesn't -// exist. -func (table *table) getById(id int, record interface{}) error { - if err := table.validateType(record, reflect.Ptr, reflect.Ptr, reflect.Struct); err != nil { - return err - } - - return table.bolt.View(func(tx *bbolt.Tx) error { +// Returns the record with the given ID, or nil if it doesn't exist. +func (table *table[R]) getById(id int) (*R, error) { + record := new(R) + err := table.bolt.View(func(tx *bbolt.Tx) error { bucket, err := table.getBucket(tx) if err != nil { return err @@ -93,45 +90,36 @@ func (table *table) getById(id int, record interface{}) error { } // If the record does not exist, set the record pointer to nil. - recordPointerValue := reflect.ValueOf(record).Elem() - recordPointerValue.Set(reflect.Zero(recordPointerValue.Type())) - + record = nil return nil }) + return record, err } -// Populates the given slice passed by pointer with the data from every record in the table, ordered by ID. -func (table *table) getAll(recordSlice interface{}) error { - if err := table.validateType(recordSlice, reflect.Ptr, reflect.Slice, reflect.Struct); err != nil { - return err - } - - return table.bolt.View(func(tx *bbolt.Tx) error { +// Returns a slice containing every record in the table, ordered by string representation of ID. +func (table *table[R]) getAll() ([]R, error) { + records := []R{} + err := table.bolt.View(func(tx *bbolt.Tx) error { bucket, err := table.getBucket(tx) if err != nil { return err } - recordSliceValue := reflect.ValueOf(recordSlice).Elem() - recordSliceValue.Set(reflect.MakeSlice(recordSliceValue.Type(), 0, 0)) return bucket.ForEach(func(key, value []byte) error { - record := reflect.New(table.recordType) - err := json.Unmarshal(value, record.Interface()) + var record R + err := json.Unmarshal(value, &record) if err != nil { return err } - recordSliceValue.Set(reflect.Append(recordSliceValue, record.Elem())) + records = append(records, record) return nil }) }) + return records, err } // Persists the given record as a new row in the table. -func (table *table) create(record interface{}) error { - if err := table.validateType(record, reflect.Ptr, reflect.Struct); err != nil { - return err - } - +func (table *table[R]) create(record *R) error { // Validate that the record has its ID set to zero or not as expected, depending on whether it is configured for // autogenerated IDs. value := reflect.ValueOf(record).Elem() @@ -177,11 +165,7 @@ func (table *table) create(record interface{}) error { // Persists the given record as an update to the existing row in the table. Returns an error if the record does not // already exist. -func (table *table) update(record interface{}) error { - if err := table.validateType(record, reflect.Ptr, reflect.Struct); err != nil { - return err - } - +func (table *table[R]) update(record *R) error { // Validate that the record has a non-zero ID. value := reflect.ValueOf(record).Elem() id := int(value.Field(*table.idFieldIndex).Int()) @@ -211,7 +195,7 @@ func (table *table) update(record interface{}) error { } // Deletes the record having the given ID from the table. Returns an error if the record does not exist. -func (table *table) delete(id int) error { +func (table *table[R]) delete(id int) error { return table.bolt.Update(func(tx *bbolt.Tx) error { bucket, err := table.getBucket(tx) if err != nil { @@ -230,7 +214,7 @@ func (table *table) delete(id int) error { } // Deletes all records from the table. -func (table *table) truncate() error { +func (table *table[R]) truncate() error { return table.bolt.Update(func(tx *bbolt.Tx) error { _, err := table.getBucket(tx) if err != nil { @@ -248,7 +232,7 @@ func (table *table) truncate() error { } // Obtains the Bolt bucket belonging to the table. -func (table *table) getBucket(tx *bbolt.Tx) (*bbolt.Bucket, error) { +func (table *table[R]) getBucket(tx *bbolt.Tx) (*bbolt.Bucket, error) { bucket := tx.Bucket(table.bucketKey) if bucket == nil { return nil, fmt.Errorf("unknown table %s", table.name) @@ -256,39 +240,6 @@ func (table *table) getBucket(tx *bbolt.Tx) (*bbolt.Bucket, error) { return bucket, nil } -// Validates that the given record is of the expected derived type (e.g. pointer, slice, etc.), that the base type is -// the same as that stored in the table, and that the table is configured correctly. -func (table *table) validateType(record interface{}, kinds ...reflect.Kind) error { - // Check the hierarchy of kinds against the expected list until reaching the base record type. - recordType := reflect.ValueOf(record).Type() - expectedKind := "" - actualKind := "" - for i, kind := range kinds { - if i > 0 { - expectedKind += " -> " - actualKind += " -> " - } - expectedKind += kind.String() - actualKind += recordType.Kind().String() - if recordType.Kind() != kind { - return fmt.Errorf("input must be a %s; got a %s", expectedKind, actualKind) - } - if i < len(kinds)-1 { - recordType = recordType.Elem() - } - } - - if recordType != table.recordType { - return fmt.Errorf("given record of type %s does not match expected type for table %s", recordType, table.name) - } - - if table.idFieldIndex == nil { - return fmt.Errorf("struct %s has no field tagged as the id", table.name) - } - - return nil -} - // Serializes the given integer ID to a byte array containing its Base-10 string representation. func idToKey(id int) []byte { return []byte(strconv.Itoa(id)) diff --git a/model/table_test.go b/model/table_test.go index 9b468ca..a22e403 100644 --- a/model/table_test.go +++ b/model/table_test.go @@ -23,7 +23,7 @@ func TestTableSingleCrud(t *testing.T) { db := setupTestDb(t) defer db.Close() - table, err := db.newTable(validRecord{}) + table, err := newTable[validRecord](db) if !assert.Nil(t, err) { return } @@ -33,31 +33,30 @@ func TestTableSingleCrud(t *testing.T) { if assert.Nil(t, table.create(&record)) { assert.Equal(t, 1, record.Id) } - var record2 *validRecord - if assert.Nil(t, table.getById(record.Id, &record2)) { - assert.Equal(t, record, *record2) - } + record2, err := table.getById(record.Id) + assert.Equal(t, record, *record2) + assert.Nil(t, err) // Test update and then read back. record.IntData = 252 record.StringData = "Teh Chezy Pofs" assert.Nil(t, table.update(&record)) - if assert.Nil(t, table.getById(record.Id, &record2)) { - assert.Equal(t, record, *record2) - } + record2, err = table.getById(record.Id) + assert.Equal(t, record, *record2) + assert.Nil(t, err) // Test delete. assert.Nil(t, table.delete(record.Id)) - if assert.Nil(t, table.getById(record.Id, &record2)) { - assert.Nil(t, record2) - } + record2, err = table.getById(record.Id) + assert.Nil(t, record2) + assert.Nil(t, err) } func TestTableMultipleCrud(t *testing.T) { db := setupTestDb(t) defer db.Close() - table, err := db.newTable(validRecord{}) + table, err := newTable[validRecord](db) if !assert.Nil(t, err) { return } @@ -71,8 +70,8 @@ func TestTableMultipleCrud(t *testing.T) { assert.Nil(t, table.create(&record3)) // Read all records. - var records []validRecord - assert.Nil(t, table.getAll(&records)) + records, err := table.getAll() + assert.Nil(t, err) if assert.Equal(t, 3, len(records)) { assert.Equal(t, record1, records[0]) assert.Equal(t, record2, records[1]) @@ -81,19 +80,19 @@ func TestTableMultipleCrud(t *testing.T) { // Truncate the table and verify that the records no longer exist. assert.Nil(t, table.truncate()) - assert.Nil(t, table.getAll(&records)) + records, err = table.getAll() assert.Equal(t, 0, len(records)) - var record4 *validRecord - if assert.Nil(t, table.getById(record1.Id, &record4)) { - assert.Nil(t, record4) - } + assert.Nil(t, err) + record4, err := table.getById(record1.Id) + assert.Nil(t, record4) + assert.Nil(t, err) } func TestTableWithManualId(t *testing.T) { db := setupTestDb(t) defer db.Close() - table, err := db.newTable(manualIdRecord{}) + table, err := newTable[manualIdRecord](db) if !assert.Nil(t, err) { return } @@ -103,23 +102,22 @@ func TestTableWithManualId(t *testing.T) { if assert.Nil(t, table.create(&record)) { assert.Equal(t, 254, record.Id) } - var record2 *manualIdRecord - if assert.Nil(t, table.getById(record.Id, &record2)) { - assert.Equal(t, record, *record2) - } + record2, err := table.getById(record.Id) + assert.Equal(t, record, *record2) + assert.Nil(t, err) // Test update and then read back. record.StringData = "Teh Chezy Pofs" assert.Nil(t, table.update(&record)) - if assert.Nil(t, table.getById(record.Id, &record2)) { - assert.Equal(t, record, *record2) - } + record2, err = table.getById(record.Id) + assert.Equal(t, record, *record2) + assert.Nil(t, err) // Test delete. assert.Nil(t, table.delete(record.Id)) - if assert.Nil(t, table.getById(record.Id, &record2)) { - assert.Nil(t, record2) - } + record2, err = table.getById(record.Id) + assert.Nil(t, record2) + assert.Nil(t, err) // Test creating a record with a zero ID. record.Id = 0 @@ -136,7 +134,7 @@ func TestNewTableErrors(t *testing.T) { defer db.Close() // Pass a non-struct as the record type. - table, err := db.newTable(123) + table, err := newTable[int](db) assert.Nil(t, table) if assert.NotNil(t, err) { assert.Equal(t, "record type must be a struct; got int", err.Error()) @@ -146,8 +144,8 @@ func TestNewTableErrors(t *testing.T) { type recordWithNoId struct { StringData string } - table, err = db.newTable(recordWithNoId{}) - assert.Nil(t, table) + table2, err := newTable[recordWithNoId](db) + assert.Nil(t, table2) if assert.NotNil(t, err) { assert.Equal(t, "struct recordWithNoId has no field tagged as the id", err.Error()) } @@ -156,8 +154,8 @@ func TestNewTableErrors(t *testing.T) { type recordWithWrongIdType struct { Id bool `db:"id"` } - table, err = db.newTable(recordWithWrongIdType{}) - assert.Nil(t, table) + table3, err := newTable[recordWithWrongIdType](db) + assert.Nil(t, table3) if assert.NotNil(t, err) { assert.Equal( t, "field in struct recordWithWrongIdType tagged with 'id' must be an int; got bool", err.Error(), @@ -169,57 +167,13 @@ func TestTableCrudErrors(t *testing.T) { db := setupTestDb(t) defer db.Close() - table, err := db.newTable(validRecord{}) + table, err := newTable[validRecord](db) if !assert.Nil(t, err) { return } - type differentRecord struct { - StringData string - } - - // Pass an object of the wrong type when getting a single record. - var record validRecord - err = table.getById(record.Id, record) - if assert.NotNil(t, err) { - assert.Equal(t, "input must be a ptr; got a struct", err.Error()) - } - err = table.getById(record.Id, &record) - if assert.NotNil(t, err) { - assert.Equal(t, "input must be a ptr -> ptr; got a ptr -> struct", err.Error()) - } - var recordTriplePointer ***validRecord - err = table.getById(record.Id, recordTriplePointer) - if assert.NotNil(t, err) { - assert.Equal(t, "input must be a ptr -> ptr -> struct; got a ptr -> ptr -> ptr", err.Error()) - } - var differentRecordPointer *differentRecord - err = table.getById(record.Id, &differentRecordPointer) - if assert.NotNil(t, err) { - assert.Equal( - t, - "given record of type model.differentRecord does not match expected type for table validRecord", - err.Error(), - ) - } - - // Pass an object of the wrong type when getting all records. - var records []validRecord - err = table.getAll(records) - if assert.NotNil(t, err) { - assert.Equal(t, "input must be a ptr; got a slice", err.Error()) - } - - // Pass an object of the wrong type when creating or updating a record. - err = table.create(record) - if assert.NotNil(t, err) { - assert.Equal(t, "input must be a ptr; got a struct", err.Error()) - } - err = table.update(record) - if assert.NotNil(t, err) { - assert.Equal(t, "input must be a ptr; got a struct", err.Error()) - } // Create a record with a non-zero ID. + var record validRecord record.Id = 12345 err = table.create(&record) if assert.NotNil(t, err) { @@ -249,11 +203,4 @@ func TestTableCrudErrors(t *testing.T) { if assert.NotNil(t, err) { assert.Equal(t, "can't delete non-existent validRecord with ID 12345", err.Error()) } - - // Update a record with an incorrectly constructed table object. - table.idFieldIndex = nil - err = table.update(&record) - if assert.NotNil(t, err) { - assert.Equal(t, "struct validRecord has no field tagged as the id", err.Error()) - } } diff --git a/model/team.go b/model/team.go index c78384d..9c96fea 100644 --- a/model/team.go +++ b/model/team.go @@ -27,9 +27,7 @@ func (database *Database) CreateTeam(team *Team) error { } func (database *Database) GetTeamById(id int) (*Team, error) { - var team *Team - err := database.teamTable.getById(id, &team) - return team, err + return database.teamTable.getById(id) } func (database *Database) UpdateTeam(team *Team) error { @@ -45,8 +43,8 @@ func (database *Database) TruncateTeams() error { } func (database *Database) GetAllTeams() ([]Team, error) { - var teams []Team - if err := database.teamTable.getAll(&teams); err != nil { + teams, err := database.teamTable.getAll() + if err != nil { return nil, err } sort.Slice(teams, func(i, j int) bool { diff --git a/model/user_session.go b/model/user_session.go index c835eab..7ec8c0f 100644 --- a/model/user_session.go +++ b/model/user_session.go @@ -19,8 +19,8 @@ func (database *Database) CreateUserSession(session *UserSession) error { } func (database *Database) GetUserSessionByToken(token string) (*UserSession, error) { - var userSessions []UserSession - if err := database.userSessionTable.getAll(&userSessions); err != nil { + userSessions, err := database.userSessionTable.getAll() + if err != nil { return nil, err }