package skykit
import (
"errors"
"fmt"
"reflect"
"strings"
"time"
"github.com/google/uuid"
)
// ErrNotFound is the sentinel error for "no matching record". Collection.First
// returns it wrapped, so compare with errors.Is rather than ==.
var (
	ErrNotFound = errors.New("not found")
)
// Collection is a typed handle to one table of a Database, parameterized by
// the entity type E stored in it.
//
// Field order matters: Manage constructs values of this type with positional
// composite literals.
type Collection[E Entity] struct {
	// DB is the database all operations are issued against.
	DB *Database
	// Ent is the entity value given to Manage; Search passes it to Cursor.
	Ent E
	// Type is reflect.TypeOf(Ent). New calls Type.Elem(), so it is
	// presumably a pointer type — confirm against Entity implementations.
	Type reflect.Type
	// Table is the name of the backing table.
	Table string
}
// Manage registers ent under table with db and returns a typed Collection
// for that table.
//
// It first calls db.Register(table, ent). It then stores an Entity-typed
// Collection in db.Cols[table], presumably so that code not parameterized on
// E can reach the collection. Both collections share the same db, ent,
// reflect type and table name.
func Manage[E Entity](db *Database, table string, ent E) *Collection[E] {
	db.Register(table, ent)
	typ := reflect.TypeOf(ent)
	db.Cols[table] = &Collection[Entity]{DB: db, Ent: ent, Type: typ, Table: table}
	return &Collection[E]{DB: db, Ent: ent, Type: typ, Table: table}
}
// New allocates a zero-valued entity of the collection's element type and
// binds its model to the collection's database.
//
// It allocates c.Type.Elem() via reflection and asserts the result to E. This
// panics if c.Type is not a pointer type or the allocated value is not an E.
func (c *Collection[E]) New() E {
	elemType := c.Type.Elem()
	ptr := reflect.New(elemType)
	ent := ptr.Interface().(E)
	model := ent.GetModel()
	model.SetDB(c.DB)
	return ent
}
// Count runs "SELECT COUNT(*) FROM <table>", with query appended after a
// single space when it is non-empty, and returns the scanned count. args are
// passed through to the query.
//
// NOTE(review): any failure reported by Query/Scan is not checked here.
// A failed count cannot be told apart from a genuine result, and on failure
// the caller most likely sees 0. Changing that requires a new signature.
func (c *Collection[E]) Count(query string, args ...any) (count int) {
	stmt := "SELECT COUNT(*) FROM " + c.Table
	if len(query) > 0 {
		stmt = stmt + " " + query
	}
	c.DB.Query(stmt, args...).Scan(&count)
	return count
}
// First returns the first entity matching query, as produced by Search.
//
// If query does not already contain a standalone LIMIT keyword, " LIMIT 1" is
// appended to it. The keyword match is case-insensitive and respects word
// boundaries, so identifiers or literals that merely contain the letters
// "limit" (e.g. "unlimited_plan") do not suppress the limit.
//
// Errors from Search are returned unchanged. When no row matches, the error
// wraps ErrNotFound (check with errors.Is) and names the table.
func (c *Collection[E]) First(query string, args ...any) (E, error) {
	var zero E

	// isWordByte reports whether b can be part of an SQL identifier. Bytes
	// >= 0x80 (non-ASCII) are treated as word characters to stay conservative.
	isWordByte := func(b byte) bool {
		return b == '_' || b >= 0x80 ||
			('0' <= b && b <= '9') ||
			('A' <= b && b <= 'Z') ||
			('a' <= b && b <= 'z')
	}

	const keyword = "LIMIT"
	upper := strings.ToUpper(query)
	hasLimit := false
	for from := 0; from <= len(upper)-len(keyword); {
		i := strings.Index(upper[from:], keyword)
		if i < 0 {
			break
		}
		start := from + i
		end := start + len(keyword)
		boundaryBefore := start == 0 || !isWordByte(upper[start-1])
		boundaryAfter := end == len(upper) || !isWordByte(upper[end])
		if boundaryBefore && boundaryAfter {
			hasLimit = true
			break
		}
		from = end
	}
	if !hasLimit {
		query += " LIMIT 1"
	}

	results, err := c.Search(query, args...)
	if err != nil {
		return zero, err
	}
	if len(results) == 0 {
		return zero, fmt.Errorf("%s: %w", c.Table, ErrNotFound)
	}
	return results[0], nil
}
// Get loads the record with the given id from the collection's table into a
// fresh entity obtained from New.
//
// The entity is returned even when DB.Get reports an error. Check the error
// before relying on its contents.
func (c *Collection[E]) Get(id string) (E, error) {
	fresh := c.New()
	err := c.DB.Get(c.Table, id, fresh)
	return fresh, err
}
// Insert writes ent as a new row in the collection's table and returns it.
//
// Before writing, it binds ent's model to the collection's database and
// assigns a random UUID when ID is empty. It fills CreatedAt and UpdatedAt
// only if they are still zero, so caller-supplied timestamps are preserved.
// Both defaulted timestamps come from a single time.Now() reading, so a fresh
// record has CreatedAt == UpdatedAt.
//
// ent is returned even when the database insert fails. Its ID and timestamps
// will already have been populated.
func (c *Collection[E]) Insert(ent E) (E, error) {
	model := ent.GetModel()
	model.SetDB(c.DB)
	if model.ID == "" {
		model.ID = uuid.NewString()
	}

	now := time.Now()
	if model.CreatedAt.IsZero() {
		model.CreatedAt = now
	}
	if model.UpdatedAt.IsZero() {
		model.UpdatedAt = now
	}
	return ent, c.DB.Insert(c.Table, ent)
}
// Update stamps ent's UpdatedAt with the current time, then writes ent to the
// collection's table via DB.Update.
//
// UpdatedAt is modified even if the write subsequently fails.
func (c *Collection[E]) Update(ent E) error {
	model := ent.GetModel()
	model.UpdatedAt = time.Now()
	err := c.DB.Update(c.Table, ent)
	return err
}
// Delete removes ent from the collection's table via DB.Delete.
func (c *Collection[E]) Delete(ent E) error {
	err := c.DB.Delete(c.Table, ent)
	return err
}
// Search returns every entity yielded by a Cursor over the collection's table
// for query and args. query is presumably an SQL fragment placed after the
// table name; confirm against Cursor.
//
// Each row is loaded into a fresh entity from New. On success the result is
// never nil (an empty slice when nothing matches). On error the result is nil.
func (c *Collection[E]) Search(query string, args ...any) ([]E, error) {
	results := []E{}
	// Iterate before building the return values. The original
	// `return results, Cursor(...).Iter(...)` form relied on unspecified
	// evaluation order between reading `results` and the call that appends
	// to it, and could return an empty slice.
	err := Cursor(c.DB, c.Ent, c.Table, query, args...).
		Iter(func(load func(Entity) error) error {
			ent := c.New()
			if err := load(ent); err != nil {
				return err
			}
			results = append(results, ent)
			return nil
		})
	if err != nil {
		return nil, err
	}
	return results, nil
}
// Index creates an index over the given columns of the collection's table
// via DB.Index.
func (c *Collection[E]) Index(columns ...string) error {
	err := c.DB.Index(c.Table, columns...)
	return err
}
// UniqueIndex creates a unique index over the given columns of the
// collection's table via DB.UniqueIndex.
func (c *Collection[E]) UniqueIndex(columns ...string) error {
	err := c.DB.UniqueIndex(c.Table, columns...)
	return err
}