diff --git a/polyculeconnect/internal/middlewares/connectordispatcher.go b/polyculeconnect/internal/middlewares/connectordispatcher.go new file mode 100644 index 0000000..a6c5408 --- /dev/null +++ b/polyculeconnect/internal/middlewares/connectordispatcher.go @@ -0,0 +1,75 @@ +package middlewares + +import ( + "bytes" + "context" + "fmt" + "html/template" + "io" + "net/http" + "path/filepath" + + "go.uber.org/zap" +) + +const ( + backendNameQueryParam = "connector_id" + backendCtxKeyName = "backendName" +) + +type BackendFromRequestMiddleware struct { + l *zap.SugaredLogger + h http.Handler + baseDir string +} + +func (m *BackendFromRequestMiddleware) serveBackendSelector(w http.ResponseWriter, r *http.Request) (int, error) { + lp := filepath.Join(m.baseDir, "templates", "login.html") + hdrTpl := filepath.Join(m.baseDir, "templates", "header.html") + footTpl := filepath.Join(m.baseDir, "templates", "footer.html") + tmpl, err := template.New("login.html").ParseFiles(hdrTpl, footTpl, lp) + if err != nil { + return http.StatusInternalServerError, fmt.Errorf("failed to init template: %w", err) + } + buf := new(bytes.Buffer) + + if err := tmpl.Execute(buf, nil); err != nil { + return http.StatusInternalServerError, fmt.Errorf("failed to execute template: %w", err) + } + _, err = io.Copy(w, buf) + if err != nil { + return http.StatusInternalServerError, fmt.Errorf("failed to write response; %w", err) + } + + return http.StatusOK, nil +} + +func (m *BackendFromRequestMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/authorize" { + m.h.ServeHTTP(w, r) + return + } + + if err := r.ParseForm(); err != nil { + // TODO: handle this better + w.WriteHeader(http.StatusBadRequest) + return + } + + backendName := r.Form.Get(backendNameQueryParam) + if backendName == "" { + statusCode, err := m.serveBackendSelector(w, r) + if err != nil { + m.l.Errorf("Failed to serve backend selector page: %s", err) + } + w.WriteHeader(statusCode) + return + } + + ctx := context.WithValue(r.Context(), backendCtxKeyName, backendName) + m.h.ServeHTTP(w, r.WithContext(ctx)) +} + +func WithBackendFromRequestMiddleware(input http.Handler) http.Handler { + return &BackendFromRequestMiddleware{h: input} +} diff --git a/polyculeconnect/internal/middlewares/test.go b/polyculeconnect/internal/middlewares/test.go deleted file mode 100644 index 13e115a..0000000 --- a/polyculeconnect/internal/middlewares/test.go +++ /dev/null @@ -1,45 +0,0 @@ -package middlewares - -import ( - "context" - "net/http" -) - -const ( - backendNameQueryParam = "connector_id" - backendCtxKeyName = "backendName" -) - -type BackendFromRequestMiddleware struct { - h http.Handler -} - -func (m *BackendFromRequestMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/authorize" { - m.h.ServeHTTP(w, r) - return - } - - if err := r.ParseForm(); err != nil { - // TODO: handle this better - w.WriteHeader(http.StatusBadRequest) - return - } - - backendName := r.Form.Get(backendNameQueryParam) - // TODO this should be explicitly handled - if backendName == "" { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("no backend id provided")) - return - } - - // TODO we should test that this backend actually exists here - - ctx := context.WithValue(r.Context(), backendCtxKeyName, backendName) - m.h.ServeHTTP(w, r.WithContext(ctx)) -} - -func WithBackendFromRequestMiddleware(input http.Handler) http.Handler { - return &BackendFromRequestMiddleware{h: input} -} diff --git a/polyculeconnect/internal/storage/storage.go b/polyculeconnect/internal/storage/storage.go index dbe4777..0da35a2 100644 --- a/polyculeconnect/internal/storage/storage.go +++ b/polyculeconnect/internal/storage/storage.go @@ -126,12 +126,11 @@ func (s *Storage) SaveAuthCode(ctx context.Context, id string, code string) erro } func (s *Storage) DeleteAuthRequest(ctx context.Context, id string) error { - return nil // don't delete it for now, it seems we might need it????? (cc dex) - // reqID, err := uuid.Parse(id) - // if err != nil { - // return fmt.Errorf("invalid id format: %w", err) - // } - // return s.LocalStorage.AuthRequestStorage().DeleteAuthRequest(ctx, reqID) + reqID, err := uuid.Parse(id) + if err != nil { + return fmt.Errorf("invalid id format: %w", err) + } + return s.LocalStorage.AuthRequestStorage().DeleteAuthRequest(ctx, reqID) } func (s *Storage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (accessTokenID string, expiration time.Time, err error) {