diff --git a/polyculeconnect/cmd/serve/serve.go b/polyculeconnect/cmd/serve/serve.go index 091effb..946c3cf 100644 --- a/polyculeconnect/cmd/serve/serve.go +++ b/polyculeconnect/cmd/serve/serve.go @@ -68,8 +68,9 @@ func serve() { st := storage.Storage{LocalStorage: userDB, InitializedBackends: backends, Key: &signingKey} opConf := op.Config{ - CryptoKey: key, - CodeMethodS256: false, + CryptoKey: key, + CodeMethodS256: false, + GrantTypeRefreshToken: true, } slogger := slog.New(zapslog.NewHandler(logger.L.Desugar().Core(), nil)) // slogger := diff --git a/polyculeconnect/internal/db/base.go b/polyculeconnect/internal/db/base.go index 54c79e2..0a61e13 100644 --- a/polyculeconnect/internal/db/base.go +++ b/polyculeconnect/internal/db/base.go @@ -9,6 +9,7 @@ import ( "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/authrequest" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/backend" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/client" + "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/token" "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/db/user" ) @@ -19,6 +20,7 @@ type Storage interface { AuthRequestStorage() authrequest.AuthRequestDB AuthCodeStorage() authcode.AuthCodeDB UserStorage() user.UserDB + TokenStorage() token.TokenDB } type sqlStorage struct { @@ -49,6 +51,10 @@ func (s *sqlStorage) UserStorage() user.UserDB { return user.New(s.db) } +func (s *sqlStorage) TokenStorage() token.TokenDB { + return token.New(s.db) +} + func New(conf config.AppConfig) (Storage, error) { db, err := sql.Open("sqlite3", conf.StorageConfig.File) if err != nil { diff --git a/polyculeconnect/internal/db/token/token.go b/polyculeconnect/internal/db/token/token.go new file mode 100644 index 0000000..0ba9142 --- /dev/null +++ b/polyculeconnect/internal/db/token/token.go @@ -0,0 +1,68 @@ +package token + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "git.faercol.me/faercol/polyculeconnect/polyculeconnect/internal/model" + "github.com/google/uuid" +) + +func strArrayToSlice(rawVal string) []string { + var res []string + if err := json.Unmarshal([]byte(rawVal), &res); err != nil { + return nil + } + return res +} + +func sliceToStrArray(rawVal []string) string { + res, err := json.Marshal(rawVal) + if err != nil { + return "[]" + } + return string(res) +} + +type TokenDB interface { + AddRefreshToken(ctx context.Context, refreshToken *model.RefreshToken) error + GetRefreshTokenByID(ctx context.Context, id uuid.UUID) (*model.RefreshToken, error) +} + +type sqlTokenDB struct { + db *sql.DB +} + +func (db *sqlTokenDB) AddRefreshToken(ctx context.Context, refreshToken *model.RefreshToken) error { + tx, err := db.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to start transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + if _, err := tx.ExecContext(ctx, `INSERT INTO "refresh_token" ("id", "client_id", "user_id", "scopes", "auth_time") VALUES ($1, $2, $3, $4, $5)`, refreshToken.ID, refreshToken.ClientID, refreshToken.UserID, sliceToStrArray(refreshToken.Scopes), refreshToken.AuthTime); err != nil { + return fmt.Errorf("failed to exec query: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + return nil +} + +func (db *sqlTokenDB) GetRefreshTokenByID(ctx context.Context, id uuid.UUID) (*model.RefreshToken, error) { + row := db.db.QueryRowContext(ctx, `SELECT "id", "client_id", "user_id", "scopes", "auth_time" FROM "refresh_token" WHERE "id" = ?`, id) + var res model.RefreshToken + var strScopes string + if err := row.Scan(&res.ID, &res.ClientID, &res.UserID, &strScopes, &res.AuthTime); err != nil { + return nil, fmt.Errorf("failed to query DB: %w", err) + } + res.Scopes = strArrayToSlice(strScopes) + return &res, nil +} + +func New(db *sql.DB) TokenDB { + return &sqlTokenDB{db: db} +} diff --git a/polyculeconnect/internal/model/token.go b/polyculeconnect/internal/model/token.go index afde439..9ba5688 100644 --- a/polyculeconnect/internal/model/token.go +++ b/polyculeconnect/internal/model/token.go @@ -17,5 +17,52 @@ type Token struct { type RefreshToken struct { ID uuid.UUID + ClientID string + UserID string + Scopes []string AuthTime time.Time } + +func (t RefreshToken) Request() *RefreshTokenRequest { + return &RefreshTokenRequest{ + userID: t.UserID, + clientID: t.ClientID, + scopes: t.Scopes, + authTime: t.AuthTime, + } +} + +type RefreshTokenRequest struct { + clientID string + authTime time.Time + userID string + scopes []string +} + +func (r RefreshTokenRequest) GetAMR() []string { + return []string{} +} + +func (r RefreshTokenRequest) GetAudience() []string { + return []string{} +} + +func (r RefreshTokenRequest) GetAuthTime() time.Time { + return r.authTime +} + +func (r RefreshTokenRequest) GetClientID() string { + return r.clientID +} + +func (r RefreshTokenRequest) GetScopes() []string { + return r.scopes +} + +func (r RefreshTokenRequest) GetSubject() string { + return r.userID +} + +func (r *RefreshTokenRequest) SetCurrentScopes(scopes []string) { + r.scopes = scopes +} diff --git a/polyculeconnect/internal/storage/storage.go b/polyculeconnect/internal/storage/storage.go index fd32649..dbe4777 100644 --- a/polyculeconnect/internal/storage/storage.go +++ b/polyculeconnect/internal/storage/storage.go @@ -2,6 +2,7 @@ package storage import ( "context" + "database/sql" "errors" "fmt" "time" @@ -135,14 +136,20 @@ func (s *Storage) DeleteAuthRequest(ctx context.Context, id string) error { func (s *Storage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (accessTokenID string, expiration time.Time, err error) { accessTokenUUID := uuid.New() + var authTime time.Time - // we are expecting our own request model - authRequest, ok := req.(*model.AuthRequest) - if !ok { + switch typedReq := req.(type) { + case *model.AuthRequest: + logger.L.Debug("Creating access token for new authentication") + authTime = typedReq.AuthTime + case *model.RefreshTokenRequest: + logger.L.Debug("Handling refresh token request") + authTime = typedReq.GetAuthTime() + default: + logger.L.Errorf("Unexpected type for request %v", err) return "", time.Time{}, errors.New("failed to parse auth request") } - authTime := authRequest.AuthTime.UTC() expiration = authTime.Add(5 * time.Minute) // token := model.Token{ @@ -160,14 +167,23 @@ func (s *Storage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (a func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error) { accessTokenUUID := uuid.New() refreshTokenUUID := uuid.New() + var authTime time.Time + var clientID string - // we are expecting our own request model - authRequest, ok := request.(*model.AuthRequest) - if !ok { + switch typedReq := request.(type) { + case *model.AuthRequest: + logger.L.Debug("Creating access token for new authentication") + clientID = typedReq.ClientID + authTime = typedReq.AuthTime + case *model.RefreshTokenRequest: + logger.L.Debug("Handling refresh token request") + clientID = typedReq.GetClientID() + authTime = typedReq.GetAuthTime() + default: + logger.L.Errorf("Unexpected type for request %v", err) return "", "", time.Time{}, errors.New("failed to parse auth request") } - authTime := authRequest.AuthTime.UTC() expiration = authTime.Add(5 * time.Minute) // token := model.Token{ @@ -178,12 +194,34 @@ func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.T // Audiences: request.GetAudience(), // Scopes: request.GetScopes(), // } + refreshToken := model.RefreshToken{ + ID: refreshTokenUUID, + ClientID: clientID, + UserID: request.GetSubject(), + Scopes: request.GetScopes(), + AuthTime: authTime, + } + if err := s.LocalStorage.TokenStorage().AddRefreshToken(ctx, &refreshToken); err != nil { + return "", "", time.Time{}, fmt.Errorf("failed to insert token in DB: %w", err) + } return accessTokenUUID.String(), refreshTokenUUID.String(), expiration, nil } func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (op.RefreshTokenRequest, error) { - return nil, ErrNotImplemented("TokenRequestByRefreshToken") + parsedID, err := uuid.Parse(refreshTokenID) + if err != nil { + return nil, fmt.Errorf("invalid format for refresh token id: %w", err) + } + + refreshToken, err := s.LocalStorage.TokenStorage().GetRefreshTokenByID(ctx, parsedID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, op.ErrInvalidRefreshToken + } + return nil, fmt.Errorf("failed to get refresh token: %w", err) + } + return refreshToken.Request(), nil } func (s *Storage) TerminateSession(ctx context.Context, userID string, clientID string) error { diff --git a/polyculeconnect/migrations/1_tokens.down.sql b/polyculeconnect/migrations/1_tokens.down.sql new file mode 100644 index 0000000..3e12ff9 --- /dev/null +++ b/polyculeconnect/migrations/1_tokens.down.sql @@ -0,0 +1 @@ +DROP TABLE refresh_token; diff --git a/polyculeconnect/migrations/1_tokens.up.sql b/polyculeconnect/migrations/1_tokens.up.sql new file mode 100644 index 0000000..818e48d --- /dev/null +++ b/polyculeconnect/migrations/1_tokens.up.sql @@ -0,0 +1,9 @@ +CREATE TABLE refresh_token ( + id TEXT NOT NULL PRIMARY KEY, + client_id TEXT NOT NULL, + user_id TEXT NOT NULL, + scopes blob NOT NULL, -- list of strings, json-encoded + auth_time timestamp NOT NULL, + FOREIGN KEY(client_id) REFERENCES client(id), + FOREIGN KEY(user_id) REFERENCES user(id) +); diff --git a/polyculeconnect/polyculeconnect.db b/polyculeconnect/polyculeconnect.db index 9430bab..41e982e 100644 Binary files a/polyculeconnect/polyculeconnect.db and b/polyculeconnect/polyculeconnect.db differ