extract webfinger and oidc code
1 file changed, 87 insertions(+), 0 deletions(-)
changed files
A domain/identity/webfinger/service.go
@@ -0,0 +1,87 @@ +package webfinger + +import ( + "encoding/json" + "net/http" + + ihttp "alin.ovh/homestead/shared/http" + "alin.ovh/x/log" + + "github.com/benpate/digit" +) + +type ResourceProvider interface { + GetResource() string + GetIdentityResource() digit.Resource +} + +type Service struct { + log *log.Logger + providers []ResourceProvider + corsOrigin string +} + +type Option func(*Service) + +func WithCORSOrigin(origin string) Option { + return func(s *Service) { + s.corsOrigin = origin + } +} + +func New(logger *log.Logger, providers []ResourceProvider, opts ...Option) *Service { + service := &Service{ + log: logger, + providers: providers, + corsOrigin: "*", // Default to allow all origins + } + + for _, opt := range opts { + opt(service) + } + + return service +} + +func (s *Service) RegisterHandlers(mux *http.ServeMux) { + mux.HandleFunc("/.well-known/webfinger", s.handleWebFinger) +} + +func (s *Service) Handler() http.HandlerFunc { + return s.handleWebFinger +} + +func (s *Service) HandleFunc(w http.ResponseWriter, r *http.Request) *ihttp.Error { + resource := r.URL.Query().Get("resource") + if resource == "" { + return ihttp.BadRequest("Missing resource parameter", nil) + } + + for _, provider := range s.providers { + if resource == provider.GetResource() { + w.Header().Add("Content-Type", "application/jrd+json") + + if s.corsOrigin != "" { + w.Header().Add("Access-Control-Allow-Origin", s.corsOrigin) + } + + if err := json.NewEncoder(w).Encode(provider.GetIdentityResource()); err != nil { + return ihttp.InternalServerError("Failed to encode webfinger response", err) + } + + return nil + } + } + + return ihttp.NotFound("Resource not found") +} + +func (s *Service) handleWebFinger(w http.ResponseWriter, r *http.Request) { + if err := s.HandleFunc(w, r); err != nil { + status := err.Code + if status == 0 { + status = http.StatusInternalServerError + } + http.Error(w, err.Error(), status) + } +}