all repos — homestead @ f0db45107699eb1294f6d64cbf2ddd48783f7cc0

Code for my website

domain/content/fetcher/fetcher.go (view raw)

package fetcher

import (
	"context"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"math"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
	"time"

	"alin.ovh/homestead/shared/config"
	"alin.ovh/homestead/shared/events"
	"alin.ovh/x/log"
	"github.com/Southclaws/fault"
	"github.com/Southclaws/fault/fmsg"
	"github.com/google/renameio/v2"
)

// Fetcher downloads the artefacts of a site run into a numbered directory
// below its root and announces, on the channel returned by Subscribe, which
// directory should be used.
type Fetcher struct {
	// options is the configuration handed to New.
	options *Options
	// log receives the fetcher's debug and warning messages.
	log     *log.Logger
	// updater is options.Listener; it is asked for the latest run ID and for
	// the stream of updates.
	updater events.Listener
	// current is the run ID the fetcher regards as active. It is read from the
	// target of the "current" symlink and overwritten by getArtefacts.
	current uint64
	// root is the os.Root opened on options.Root; directory and file creation
	// for a run goes through it.
	root    *os.Root
}

// Options configures a Fetcher.
type Options struct {
	// Root is the directory that holds one sub-directory per run ID and the
	// "current" symlink.
	Root         string
	// RedisEnabled selects the start-up behaviour of Subscribe: when true the
	// latest run is looked up through the listener (and fetched if newer)
	// before anything is served, and updates carrying a zero run ID are
	// ignored; when false the version named by the "current" symlink is used
	// and a zero run ID re-announces it.
	// NOTE(review): presumably mirrors whether the listener is Redis-backed —
	// confirm against the caller.
	RedisEnabled bool
	// FetchURL is the base URL; the run ID and the file name are appended to
	// it as path segments to build each download URL.
	FetchURL     config.URL
	// Listener supplies the latest run ID and the update subscription.
	Listener     events.Listener
}

var (
	// files lists the plain files downloaded into every run directory.
	files           = []string{"config.toml"}
	// archive is the name of the tarball that is downloaded, passed through
	// bunzip and extracted into the run directory.
	archive         = "site.tar.bz2"
	// numericFilename matches names made up solely of ASCII digits, i.e. run
	// directories. The pattern is anchored on both ends: an unanchored
	// "[0-9]+" also accepts names such as "backup1" or the temporary
	// ".current123456" directory renameio creates while swapping the symlink,
	// which then fail to parse as a run ID.
	numericFilename = regexp.MustCompile(`^[0-9]+$`)
	// timeout bounds each download: it covers the request and the transfer of
	// the response body.
	timeout         = 10 * time.Second
)

// New opens options.Root as an os.Root and returns a Fetcher that logs to log
// and listens for updates on options.Listener. It fails when the root
// directory cannot be opened.
func New(log *log.Logger, options *Options) (*Fetcher, error) {
	dir, openErr := os.OpenRoot(options.Root)
	if openErr != nil {
		return nil, fault.Wrap(openErr, fmsg.With("could not open root"))
	}

	fetcher := new(Fetcher)
	fetcher.options = options
	fetcher.log = log
	fetcher.updater = options.Listener
	fetcher.root = dir

	return fetcher, nil
}

// CleanOldRevisions deletes run directories that are more than one revision
// older than the current one.
//
// Every entry in the root whose name matches numericFilename is parsed as a
// run ID and removed when that ID is below current-1; the current run and
// anything numbered current-1 or higher is kept. The "current" symlink and
// entries that do not match are left untouched.
//
// It returns an error when the root cannot be listed, when a matching name
// cannot be parsed as an unsigned 64-bit integer, or when a removal fails.
func (f *Fetcher) CleanOldRevisions() error {
	contents, err := f.root.FS().(fs.ReadDirFS).ReadDir(".")
	if err != nil {
		return fault.Wrap(err, fmsg.With("could not read root directory"))
	}

	// f.current is unsigned, so evaluating f.current-1 while it is still 0
	// (no "current" link yet, or the first fetch failed) wraps around to the
	// largest uint64 and would mark every run directory as stale. Clamp the
	// threshold at zero so nothing is removed in that situation.
	var oldestKept uint64
	if f.current > 0 {
		oldestKept = f.current - 1
	}

	for _, file := range contents {
		name := file.Name()
		if name == "current" || !numericFilename.MatchString(name) {
			continue
		}

		v, err := strconv.ParseUint(name, 10, 64)
		if err != nil {
			return fault.Wrap(
				err,
				fmsg.With(fmt.Sprintf("could not parse numeric filename %s", name)),
			)
		}
		if v >= oldestKept {
			continue
		}

		if err := os.RemoveAll(filepath.Join(f.options.Root, name)); err != nil {
			return fault.Wrap(err, fmsg.With("could not remove folder"))
		}
	}

	return nil
}

// Subscribe validates the root directory, works out which run directory to
// start from and returns a channel on which that directory, and the directory
// of every later update, is delivered.
//
// A failure to read the current version is only logged. With RedisEnabled the
// starting directory comes from initialiseStorage (whose error is returned);
// otherwise it is the directory of the current version. Connecting to the
// update listener happens in a background goroutine that retries with
// exponential backoff until it succeeds.
func (f *Fetcher) Subscribe() (<-chan string, error) {
	if err := f.checkFolder(); err != nil {
		return nil, err
	}

	version, versionErr := f.getCurrentVersion()
	f.current = version
	if versionErr != nil {
		f.log.Warn("could not get current version", "error", versionErr)
	}

	var root string
	if f.options.RedisEnabled {
		runID, err := f.initialiseStorage()
		if err != nil {
			return nil, err
		}
		root = f.path(runID)
	} else {
		root = f.path(f.current)
	}

	// Buffered by one so the initial root can be queued before a reader exists.
	ch := make(chan string, 1)
	go func() {
		// attempt counts the failed tries so far, starting at one for the log
		// line; the delay is derived from the number of earlier failures.
		for attempt := uint(1); ; attempt++ {
			connectErr := f.connect(root, ch)
			if connectErr == nil {
				return
			}

			delay := expBackoff(attempt - 1)
			f.log.Warn(
				"could not connect to update listener",
				"error",
				connectErr,
				"attempt",
				attempt,
				"next_try",
				delay,
			)

			time.Sleep(delay)
		}
	}()

	return ch, nil
}

// getArtefacts downloads everything belonging to run into a directory named
// after the run ID: it creates the directory, fetches each entry of files,
// fetches and unpacks the archive, records run as the current version and
// finally points the "current" symlink at the new directory.
//
// The first failing step aborts the sequence and its error is returned.
func (f *Fetcher) getArtefacts(run uint64) error {
	dir := strconv.FormatUint(run, 10)
	f.log.Debug("getting artefacts", "run_id", dir)

	if err := f.root.MkdirAll(dir, 0o750); err != nil {
		return fault.Wrap(err, fmsg.With("could not create directory"))
	}

	for i := range files {
		if err := f.getFile(dir, files[i]); err != nil {
			return fault.Wrap(err, fmsg.With("could not fetch file"))
		}
	}

	if err := f.getArchive(dir, archive); err != nil {
		return fault.Wrap(err, fmsg.With("could not fetch archive"))
	}

	// The in-memory version is updated before the symlink is swapped.
	f.current = run
	link := filepath.Join(f.root.Name(), "current")
	if err := renameio.Symlink(dir, link); err != nil {
		return fault.Wrap(err, fmsg.With("could not create/update symlink"))
	}

	return nil
}

// checkFolder verifies that the root directory only contains entries the
// fetcher expects: the "current" symlink and names matching numericFilename.
// Any other entry is reported, by name, in the returned error.
func (f *Fetcher) checkFolder() error {
	entries, err := f.root.FS().(fs.ReadDirFS).ReadDir(".")
	if err != nil {
		return fault.Wrap(err, fmsg.With("could not read root directory"))
	}

	var unexpected []string
	for _, entry := range entries {
		name := entry.Name()
		if name == "current" || numericFilename.MatchString(name) {
			continue
		}
		unexpected = append(unexpected, name)
	}

	if len(unexpected) == 0 {
		return nil
	}

	return fault.Wrap(
		fault.Newf("unexpected files in root directory: %s", strings.Join(unexpected, ", ")),
	)
}

// cleanOldRevisionsAsync runs CleanOldRevisions in its own goroutine and logs
// a warning if it fails; the caller is never blocked and never sees the error.
func (f *Fetcher) cleanOldRevisionsAsync() {
	go func() {
		err := f.CleanOldRevisions()
		if err == nil {
			return
		}
		f.log.Warn("error cleaning up old revisions", "error", err)
	}()
}

// makeURL builds the download URL for basename within run runID by appending
// both as path segments to the configured FetchURL.
func (f *Fetcher) makeURL(runID, basename string) string {
	location := f.options.FetchURL.JoinPath(runID, basename)

	return location.String()
}

// getFile downloads basename for run runID and stores it as runID/basename
// below the fetcher's root.
//
// The download is bounded by timeout. Any response other than 200 OK is
// treated as a failure, so an error page is never saved in place of the real
// file. The destination is only opened once the response has been accepted,
// and it is truncated first, so a shorter download cannot leave bytes of an
// earlier, longer copy behind. The file is synced before returning.
//
// It returns an error when the request cannot be built or sent, the status is
// not 200, or the file cannot be opened, written or synced.
func (f *Fetcher) getFile(runID, basename string) error {
	filename := filepath.Join(runID, basename)
	url := f.makeURL(runID, basename)

	f.log.Debug("getting file", "filename", filename, "url", url)

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return fault.Wrap(err, fmsg.With("could not create request"))
	}
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return fault.Wrap(err, fmsg.With("could not issue request"))
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		return fault.Wrap(
			fault.Newf("unexpected response status %q for %s", res.Status, url),
		)
	}

	// Opened after the status check so a failed request leaves any existing
	// copy of the file untouched.
	file, err := f.root.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
	if err != nil {
		return fault.Wrap(err, fmsg.With("could not open file"))
	}
	defer file.Close()

	if _, err = io.Copy(file, res.Body); err != nil {
		return fault.Wrap(err, fmsg.With("could not write file"))
	}

	if err = file.Sync(); err != nil {
		return fault.Wrap(err, fmsg.With("could not sync file"))
	}

	return nil
}

// getArchive downloads the archive basename for run runID and unpacks it into
// the run's directory.
//
// The download is bounded by timeout and streamed: the response body is passed
// through bunzip straight into extract, which writes below a root opened on
// the run directory. Any response other than 200 OK is rejected before
// extraction starts, so an error page is never fed to the decompressor.
//
// It returns an error when the request cannot be built or sent, the status is
// not 200, the run directory cannot be opened, or extraction fails.
func (f *Fetcher) getArchive(runID, basename string) error {
	url := f.makeURL(runID, basename)

	f.log.Debug("getting file", "url", url)

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return fault.Wrap(err, fmsg.With("could not create request"))
	}
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return fault.Wrap(err, fmsg.With("could not issue request"))
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		return fault.Wrap(
			fault.Newf("unexpected response status %q for %s", res.Status, url),
		)
	}

	subRoot, err := f.root.OpenRoot(runID)
	if err != nil {
		return fault.Wrap(err, fmsg.With("could not open root"))
	}
	// Release the directory handle once extraction is done instead of leaving
	// it to the garbage collector; closing an os.Root twice is harmless.
	// NOTE(review): assumes extract does not keep using subRoot after it
	// returns — confirm against its implementation.
	defer subRoot.Close()

	return extract(bunzip(res.Body), subRoot, f.log)
}

// getCurrentVersion reads the "current" symlink in the root directory and
// interprets its target as a decimal run ID.
//
// On success the ID is stored in f.current and returned. On failure it returns
// 0 and an error, and f.current is left unchanged: either the link is missing,
// it could not be read for another reason (for instance it is not a symlink or
// access is denied), or its target is not an unsigned 64-bit integer.
func (f *Fetcher) getCurrentVersion() (uint64, error) {
	target, err := os.Readlink(filepath.Join(f.options.Root, "current"))
	if err != nil {
		if errors.Is(err, fs.ErrNotExist) {
			return 0, fault.Wrap(err, fmsg.With("could not stat current link"))
		}

		// Every other Readlink failure must be reported as well; carrying on
		// with an empty target would only surface as a misleading parse error.
		return 0, fault.Wrap(err, fmsg.With("could not read current link"))
	}

	version, err := strconv.ParseUint(target, 10, 64)
	if err != nil {
		return 0, fault.Wrap(
			err,
			fmsg.With(fmt.Sprintf("unexpected symlink target (current -> %s)", target)),
		)
	}
	f.current = version

	return version, nil
}

// initialiseStorage asks the listener for the latest run ID and, when that is
// newer than the current version, downloads its artefacts.
//
// It returns the run ID whose directory should be used: the latest one after a
// fetch, otherwise the current one. A failure to obtain the latest ID is only
// logged and the value the listener returned is used as is. Whatever the
// outcome, a background clean-up of old revisions is started on return.
func (f *Fetcher) initialiseStorage() (uint64, error) {
	latest, lookupErr := f.updater.GetLatestRunID()
	if lookupErr != nil {
		f.log.Warn("could not get latest run ID, using fallback", "error", lookupErr)
	}

	f.log.Debug("versions", "current", f.current, "latest", latest)
	defer f.cleanOldRevisionsAsync()

	// Nothing newer is available: keep serving what is already on disk.
	if latest <= f.current {
		return f.current, nil
	}

	if fetchErr := f.getArtefacts(latest); fetchErr != nil {
		return latest, fault.Wrap(fetchErr, fmsg.With("could not fetch artefacts"))
	}

	return latest, nil
}

// connect subscribes to the update listener and starts a goroutine that first
// sends root on ch and then one directory path per usable update.
//
// For an update with a non-zero run ID the artefacts are fetched and the new
// run directory is sent; a failed fetch is logged and skipped. An update with
// a zero run ID is logged and ignored when RedisEnabled is set, otherwise the
// directory of the current version is sent again. Only a failure to subscribe
// is returned as an error.
func (f *Fetcher) connect(root string, ch chan string) error {
	updates, err := f.updater.Subscribe()
	if err != nil {
		return fault.Wrap(err, fmsg.With("could not subscribe to updates"))
	}

	go func() {
		ch <- root

		for update := range updates {
			switch {
			case update.RunID != 0:
				if fetchErr := f.getArtefacts(update.RunID); fetchErr != nil {
					f.log.Warn("could not get artefacts for version", "run_id", update.RunID, "error", fetchErr)

					continue
				}

				ch <- f.path(update.RunID)
			case f.options.RedisEnabled:
				f.log.Warn("got zero runID")
			default:
				ch <- f.path(f.current)
			}
		}
	}()

	return nil
}

// path returns the directory of run runID: the configured root joined with the
// run ID in decimal.
func (f *Fetcher) path(runID uint64) string {
	dir := strconv.FormatUint(runID, 10)

	return filepath.Join(f.options.Root, dir)
}

// expBackoff returns the delay to wait after the given number of earlier
// failed attempts: 2^attempt seconds (1s, 2s, 4s, ...).
//
// The exponent is clamped at 33, the largest power of two whose value in
// seconds still fits in a time.Duration (int64 nanoseconds). Without the clamp
// the multiplication overflows from attempt 34 onwards and yields a negative
// or otherwise meaningless delay; results for attempts up to 33 are unchanged.
func expBackoff(attempt uint) time.Duration {
	const maxExponent = 33

	exponent := min(attempt, maxExponent)

	return time.Duration(math.Exp2(float64(exponent))) * time.Second
}