tssrv: init
1 file changed, 194 insertions(+), 0 deletions(-)
changed files
A internal/tssrv/tailscale.go
@@ -0,0 +1,194 @@ +package tssrv + +import ( + "context" + "log" + "net/http" + "net/url" + "strings" + "time" + + "tailscale.com/client/local" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/util/dnsname" +) + +type Options struct { + Development bool `conf:"-"` + AllowUnknownUsers bool `conf:"flag:allow-unknown-users"` + Hostname string + Dir string + + ReadHeaderTimeout time.Duration `conf:"default:5s"` + WriteTimeout time.Duration `conf:"default:10s"` + IdleTimeout time.Duration `conf:"default:1m"` +} + +type TailscaleServer struct { + options *Options + srv *tsnet.Server + localClient *local.Client + Server *http.Server +} + +type User struct { + Login string + IsAdmin bool +} + +type capabilities struct { + Admin bool `json:"admin"` +} + +const peerCapName = "tailscale.com/cap/golink" + +func New(options *Options) (*TailscaleServer, error) { + srv := &tsnet.Server{ + Dir: options.Dir, + Hostname: options.Hostname, + RunWebClient: true, + } + if options.Dir == "" { + options.Dir = options.Hostname + } + + if err := srv.Start(); err != nil { + return nil, err + } + + localClient, err := srv.LocalClient() + if err != nil { + return nil, err + } + + for { + upCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + status, err := srv.Up(upCtx) + if err == nil && status != nil { + break + } + } + + return &TailscaleServer{ + options: options, + srv: srv, + localClient: localClient, + }, nil +} + +func (ts *TailscaleServer) Serve(httpHandler http.Handler) error { + statusCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + status, err := ts.localClient.Status(statusCtx) + if err != nil { + return err + } + enableTLS := status.Self.HasCap(tailcfg.CapabilityHTTPS) && len(ts.srv.CertDomains()) > 0 + fqdn := strings.TrimSuffix(status.Self.DNSName, ".") + + if enableTLS { + httpsHandler := HSTS(httpHandler) + httpHandler = redirectHandler(fqdn) + + httpsListener, err := ts.srv.ListenTLS("tcp", ":443") + if err != nil { + return err + } + log.Printf("Listening on :443") + go func() { + log.Printf("Serving https://%s/ ...", fqdn) + httpsServer := &http.Server{ + Handler: httpsHandler, + ReadHeaderTimeout: ts.options.ReadHeaderTimeout, + WriteTimeout: ts.options.WriteTimeout, + IdleTimeout: ts.options.IdleTimeout, + } + if err := httpsServer.Serve(httpsListener); err != nil { + log.Fatalf("error serving https: %v", err) + } + }() + } + + httpListener, err := ts.srv.Listen("tcp", ":80") + log.Printf("Listening on :80") + if err != nil { + return err + } + log.Printf("Serving http://%s/ ...", ts.options.Hostname) + httpServer := &http.Server{ + Handler: httpHandler, + ReadHeaderTimeout: ts.options.ReadHeaderTimeout, + WriteTimeout: ts.options.WriteTimeout, + IdleTimeout: ts.options.IdleTimeout, + } + if err := httpServer.Serve(httpListener); err != nil { + return err + } + + return nil +} + +// redirectHandler returns the http.Handler for serving all plaintext HTTP +// requests. It redirects all requests to the HTTPs version of the same URL. +func redirectHandler(hostname string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u := &url.URL{ + Scheme: "https", + Host: hostname, + Path: r.URL.Path, + RawQuery: r.URL.RawQuery, + } + http.Redirect(w, r, u.String(), http.StatusFound) + }) +} + +// HSTS wraps the provided handler and sets Strict-Transport-Security header on +// responses. It inspects the Host header to ensure we do not specify HSTS +// response on non fully qualified domain name origins. +func HSTS(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, found := r.Header["Host"] + if found { + host := host[0] + fqdn, err := dnsname.ToFQDN(host) + if err == nil { + segCount := fqdn.NumLabels() + if segCount > 1 { + w.Header().Set("Strict-Transport-Security", "max-age=31536000") + } + } + } + h.ServeHTTP(w, r) + }) +} + +// CurrentUser returns the Tailscale user associated with the request. +// In most cases, this will be the user that owns the device that made the request. +// For tagged devices, the value "tagged-devices" is returned. +// If the user can't be determined (such as requests coming through a subnet router), +// an error is returned unless the -allow-unknown-users flag is set. +func (ts *TailscaleServer) CurrentUser(r *http.Request) (User, error) { + if ts.options.Development { + return User{Login: "foo@example.com"}, nil + } + whois, err := ts.localClient.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + if ts.options.AllowUnknownUsers { + // Don't report the error if we are allowing unknown Users. + return User{}, nil + } + + return User{}, err + } + login := whois.UserProfile.LoginName + caps, _ := tailcfg.UnmarshalCapJSON[capabilities](whois.CapMap, peerCapName) + for _, cap := range caps { + if cap.Admin { + return User{Login: login, IsAdmin: true}, nil + } + } + + return User{Login: login}, nil +}