polycule-connect/polyculeconnect/internal/db/client/client.go
Melora Hugues 8d805cefe6
Some checks failed
/ docker-build-only (push) Failing after 27s
/ go-test (push) Failing after 1m18s
Cleanup DB a bit, and start correctly handling users (#42)
2024-10-18 22:06:05 +02:00

93 lines
2.8 KiB
Go

package client
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model"
"git.faercol.me/faercol/polyculeconnect/polyculeconnect/logger"
_ "github.com/mattn/go-sqlite3"
)
var (
	// ErrNotFound is returned by GetClientByID (via clientFromRow) when no
	// client row matches the requested ID.
	ErrNotFound = errors.New("not found")
)

const (
	// clientRows is the fully qualified column list of the "client" table,
	// in the exact order clientFromRow scans them.
	clientRows = `"client"."id", "client"."secret", "client"."redirect_uris", "client"."trusted_peers", "client"."name"`
)
// ClientDB stores and looks up client applications.
type ClientDB interface {
	// AddClient persists a new client.
	AddClient(ctx context.Context, client *model.Client) error
	// GetClientByID returns the client whose ID is id. The SQL implementation
	// in this package returns ErrNotFound when no such client exists.
	GetClientByID(ctx context.Context, id string) (*model.Client, error)
}
// sqlClientDB is the database/sql-backed implementation of ClientDB.
type sqlClientDB struct {
	db *sql.DB // handle used for every query and transaction
}
// strArrayToSlice decodes a JSON array of strings (as stored in the
// "redirect_uris" / "trusted_peers" columns) into a slice.
//
// Decoding is best-effort: malformed JSON, or an array containing
// non-string elements, yields nil instead of an error. The JSON literal
// null also yields nil.
func strArrayToSlice(rawVal string) []string {
	var decoded []string
	if err := json.Unmarshal([]byte(rawVal), &decoded); err == nil {
		return decoded
	}
	return nil
}
// sliceToStrArray encodes rawVal as a JSON array of strings, the format read
// back by strArrayToSlice.
//
// A nil or empty slice is encoded as "[]". json.Marshal would otherwise emit
// the literal "null" for a nil slice. This way the stored value is always a
// JSON array, matching the fallback below.
func sliceToStrArray(rawVal []string) string {
	if len(rawVal) == 0 {
		return "[]"
	}
	res, err := json.Marshal(rawVal)
	if err != nil {
		// json.Marshal cannot fail for a []string; this is a defensive fallback.
		return "[]"
	}
	return string(res)
}
// clientFromRow scans a single row into a model.Client. The row's columns
// must be in the order listed by clientRows.
//
// It returns ErrNotFound when the query matched no row. Any other scan
// failure is wrapped with context. The JSON-encoded "redirect_uris" and
// "trusted_peers" columns are decoded with strArrayToSlice, which yields nil
// for malformed values.
func clientFromRow(row *sql.Row) (*model.Client, error) {
	var (
		client       model.Client
		rawRedirects string
		rawPeers     string
	)

	err := row.Scan(&client.ID, &client.Secret, &rawRedirects, &rawPeers, &client.Name)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("invalid format for client: %w", err)
	}

	client.ClientConfig.RedirectURIs = strArrayToSlice(rawRedirects)
	client.ClientConfig.TrustedPeers = strArrayToSlice(rawPeers)
	return &client, nil
}
// GetClientByID loads the client whose "id" column equals id.
// It returns ErrNotFound when no such row exists.
func (db *sqlClientDB) GetClientByID(ctx context.Context, id string) (*model.Client, error) {
	logger.L.Debugf("Getting client app with ID %s from DB", id)

	// clientRows is a constant, so the whole query is built at compile time.
	const selectByID = `SELECT ` + clientRows + ` FROM "client" WHERE "id" = ?`
	return clientFromRow(db.db.QueryRowContext(ctx, selectByID, id))
}
// AddClient inserts client into the "client" table inside a transaction.
//
// Redirect URIs and trusted peers are stored as JSON-encoded string arrays
// (see sliceToStrArray). An error is returned in any of these cases:
//   - client is nil,
//   - the transaction cannot be started or committed,
//   - the insert fails,
//   - the insert does not affect exactly one row.
func (db *sqlClientDB) AddClient(ctx context.Context, client *model.Client) error {
	if client == nil {
		return errors.New("cannot add a nil client")
	}
	logger.L.Debugf("Creating client %s", client.Name)

	// Positional "?" placeholders, matching GetClientByID. In SQLite, "$N" is a
	// named parameter, not a positional one.
	const query = `INSERT INTO "client" ("id", "secret", "redirect_uris", "trusted_peers", "name") VALUES (?, ?, ?, ?, ?)`

	tx, err := db.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to start transaction: %w", err)
	}
	// After a successful Commit, Rollback is a no-op returning sql.ErrTxDone,
	// so its error is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	res, err := tx.ExecContext(ctx, query,
		client.ID,
		client.Secret,
		sliceToStrArray(client.RedirectURIs()),
		sliceToStrArray(client.TrustedPeers),
		client.Name,
	)
	if err != nil {
		return fmt.Errorf("failed to insert in DB: %w", err)
	}

	nbAffected, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to check number of affected rows: %w", err)
	}
	if nbAffected != 1 {
		return fmt.Errorf("unexpected number of affected rows: %d", nbAffected)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}
// New returns a client repository that runs all of its queries against db.
func New(db *sql.DB) *sqlClientDB {
	repo := sqlClientDB{db: db}
	return &repo
}