polycule-connect/polyculeconnect/internal/db/client/client.go

136 lines
3.9 KiB
Go
Raw Permalink Normal View History

package client
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
2024-10-06 09:28:26 +00:00
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
_ "github.com/mattn/go-sqlite3"
)
// ErrNotFound is returned when a lookup does not match any stored client.
var ErrNotFound = errors.New("not found")

// clientRows is the fully-qualified column list selected from the "client"
// table. Its order must match the Scan targets in clientFromRow.
const clientRows = `"client"."id", ` +
	`"client"."secret", ` +
	`"client"."redirect_uris", ` +
	`"client"."trusted_peers", ` +
	`"client"."name"`
// ClientDB abstracts persistent storage of client applications.
type ClientDB interface {
	// AddClient persists a new client.
	AddClient(ctx context.Context, client *model.Client) error
	// GetClientByID returns the client with the given ID. The SQL
	// implementation in this package returns ErrNotFound when no row matches.
	GetClientByID(ctx context.Context, id string) (*model.Client, error)
	// GetAllClients returns every stored client.
	GetAllClients(ctx context.Context) ([]*model.Client, error)
	// DeleteClient removes the client with the given ID.
	DeleteClient(ctx context.Context, id string) error
}
// sqlClientDB is the database/sql-backed implementation of ClientDB.
type sqlClientDB struct {
	// db is the handle passed to New; every query in this type goes through it.
	db *sql.DB
}
// strArrayToSlice decodes a JSON array of strings as stored in the DB.
// Any decoding failure (empty input, malformed JSON, non-string elements)
// yields nil rather than an error, so corrupt values read back as "no entries".
func strArrayToSlice(rawVal string) []string {
	var decoded []string
	if json.Unmarshal([]byte(rawVal), &decoded) == nil {
		return decoded
	}
	return nil
}
// sliceToStrArray encodes rawVal as a JSON array of strings for storage in the DB.
//
// A nil slice is encoded as "[]" rather than JSON "null", so the stored value
// is always a JSON array. This matches the fallback returned when marshalling
// fails. strArrayToSlice decodes both forms.
func sliceToStrArray(rawVal []string) string {
	if rawVal == nil {
		// json.Marshal encodes a nil slice as "null"; normalize to an empty array.
		rawVal = []string{}
	}
	res, err := json.Marshal(rawVal)
	if err != nil {
		// Not expected for a []string; kept as a defensive fallback.
		return "[]"
	}
	return string(res)
}
// scannable is the single method clientFromRow needs from a query result.
// Both *sql.Row and *sql.Rows provide it, so one decoder serves single-row
// and multi-row queries alike.
type scannable interface {
	Scan(targets ...any) error
}
// clientFromRow decodes one row selected with clientRows into a model.Client.
//
// The Scan targets must stay in the same order as the columns in clientRows.
// sql.ErrNoRows is mapped to ErrNotFound; every other Scan error is wrapped.
// For a *sql.Row, Scan also reports errors deferred from QueryRowContext, so
// query failures end up under the "invalid format" wrapper as well.
// The JSON-encoded redirect URIs and trusted peers are decoded with
// strArrayToSlice, which yields nil for unparsable values.
func clientFromRow(row scannable) (*model.Client, error) {
	var (
		clt          model.Client
		rawRedirects string
		rawPeers     string
	)
	err := row.Scan(&clt.ID, &clt.Secret, &rawRedirects, &rawPeers, &clt.Name)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("invalid format for client: %w", err)
	}
	clt.ClientConfig.RedirectURIs = strArrayToSlice(rawRedirects)
	clt.ClientConfig.TrustedPeers = strArrayToSlice(rawPeers)
	return &clt, nil
}
func (db *sqlClientDB) GetClientByID(ctx context.Context, id string) (*model.Client, error) {
2024-10-06 09:28:26 +00:00
logger.L.Debugf("Getting client app with ID %s from DB", id)
query := fmt.Sprintf(`SELECT %s FROM "client" WHERE "id" = ?`, clientRows)
row := db.db.QueryRowContext(ctx, query, id)
return clientFromRow(row)
}
// GetAllClients returns every client stored in the "client" table.
//
// An empty table yields a nil slice and a nil error. The first row that fails
// to decode aborts the whole read.
func (db *sqlClientDB) GetAllClients(ctx context.Context) ([]*model.Client, error) {
	rows, err := db.db.QueryContext(ctx, fmt.Sprintf(`SELECT %s FROM "client"`, clientRows))
	if err != nil {
		return nil, fmt.Errorf("failed to query clients from DB: %w", err)
	}
	// Always release the result set. Previously, an early return on a scan
	// error left rows open and its connection checked out of the pool.
	// The Close error is ignored; read failures are reported via rows.Err below.
	defer func() { _ = rows.Close() }()

	var res []*model.Client
	for rows.Next() {
		clt, err := clientFromRow(rows)
		if err != nil {
			return nil, fmt.Errorf("failed to scan row: %w", err)
		}
		res = append(res, clt)
	}
	// rows.Next returns false on both exhaustion and error; distinguish here.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to read all rows: %w", err)
	}
	return res, nil
}
// AddClient inserts client into the "client" table inside a transaction.
//
// The redirect URIs and trusted peers are stored JSON-encoded via
// sliceToStrArray. The insert is committed only if it affects exactly one row.
func (db *sqlClientDB) AddClient(ctx context.Context, client *model.Client) error {
	logger.L.Debugf("Creating client %s", client.Name)
	// Uses "?" placeholders, like every other query in this file.
	query := `INSERT INTO "client" ("id", "secret", "redirect_uris", "trusted_peers", "name") VALUES (?, ?, ?, ?, ?)`

	tx, err := db.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to start transaction: %w", err)
	}
	// Undo the insert on any early return. After a successful Commit, Rollback
	// only returns sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	// NOTE(review): redirect URIs are read through a method but trusted peers
	// through a field; confirm this asymmetry is intended.
	result, err := tx.ExecContext(ctx, query,
		client.ID,
		client.Secret,
		sliceToStrArray(client.RedirectURIs()),
		sliceToStrArray(client.TrustedPeers),
		client.Name,
	)
	if err != nil {
		return fmt.Errorf("failed to insert in DB: %w", err)
	}
	nbAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to check number of affected rows: %w", err)
	}
	if nbAffected != 1 {
		return fmt.Errorf("unexpected number of affected rows: %d", nbAffected)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}
// DeleteClient removes the client with the given ID inside a transaction.
//
// The number of affected rows is not checked, so deleting an unknown ID is not
// reported as an error.
func (db *sqlClientDB) DeleteClient(ctx context.Context, id string) error {
	tx, err := db.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to start transaction: %w", err)
	}
	// Previously, a failed Exec returned without ending the transaction,
	// leaving it open and its connection checked out of the pool.
	// After a successful Commit, Rollback only returns sql.ErrTxDone (ignored).
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, `DELETE FROM "client" WHERE "id" = ?`, id); err != nil {
		return fmt.Errorf("failed to exec query: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}
// New returns a SQL-backed client store that runs its queries against db.
// The store keeps a reference to db; it does not copy or reconfigure it.
func New(db *sql.DB) *sqlClientDB {
	store := new(sqlClientDB)
	store.db = db
	return store
}