all repos — homestead @ 96b51eedabca67efc1a2a80478f395b353da5570

Code for my website

shared/storage/sqlite/writer.go (view raw)

package sqlite

import (
	"context"
	"database/sql"
	_ "embed"
	"encoding/json"
	"fmt"
	"hash/fnv"
	"io"
	"mime"
	"net/http"
	"path/filepath"
	"time"

	"alin.ovh/homestead/domain/content"
	"alin.ovh/homestead/shared/buffer"
	"alin.ovh/homestead/shared/storage"
	"alin.ovh/homestead/shared/storage/sqlite/db"
	"alin.ovh/x/log"
	"github.com/andybalholm/brotli"
	"github.com/klauspost/compress/gzip"
	"github.com/klauspost/compress/zstd"

	"github.com/Southclaws/fault"
	"github.com/Southclaws/fault/fmsg"
	_ "modernc.org/sqlite" // import registers db/SQL driver
)

// Writer stores rendered site files in a SQLite database through the query
// layer in package db. Build one with NewWriter.
type Writer struct {
	// options holds behaviour switches such as compression.
	options *Options
	// log receives debug and warning messages about stored files.
	log *log.Logger
	// queries is the database access layer used by the store* helpers.
	queries *db.Queries
}

// Options configures a Writer.
type Options struct {
	// Compress makes WriteFile additionally store a gzip, brotli and zstd
	// encoding of every file alongside the uncompressed ("identity") body.
	Compress bool
}

var (
	encodings = []string{"gzip", "br", "zstd"}
	//go:embed schema.sql
	schema string
)

// OpenDB returns a handle to the SQLite database at dbPath using the "sqlite"
// driver. The DSN opens the file in mode=rwc (read/write, create if missing)
// and sets two `_pragma` parameters: foreign_keys(1) and a 16 MiB mmap_size.
//
// sql.Open may only validate its arguments without connecting, so problems
// with the file itself can surface on first use of the returned *sql.DB.
func OpenDB(dbPath string) (*sql.DB, error) {
	const (
		openMode = "rwc"
		mmapSize = 16 << 20 // 16 MiB
	)

	dsn := fmt.Sprintf(
		"file:%s?mode=%s&_pragma=foreign_keys(1)&_pragma=mmap_size(%d)",
		dbPath, openMode, mmapSize,
	)

	conn, err := sql.Open("sqlite", dsn)
	if err != nil {
		return nil, fault.Wrap(err)
	}

	return conn, nil
}

// NewWriter executes the embedded schema against conn, allowing up to one
// second, and returns a Writer that stores files through it.
//
// A nil opts is treated as the zero Options (compression disabled), so the
// returned Writer never dereferences a nil *Options in WriteFile.
func NewWriter(conn *sql.DB, logger *log.Logger, opts *Options) (*Writer, error) {
	if opts == nil {
		opts = &Options{}
	}

	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()

	if _, err := conn.ExecContext(ctx, schema); err != nil {
		return nil, fault.Wrap(err, fmsg.With("creating tables"))
	}

	return &Writer{
		queries: db.New(conn),
		log:     logger,
		options: opts,
	}, nil
}

// Mkdirp is a no-op that always returns nil. Files are stored as database
// rows keyed by URL path, so there are no directories to create.
func (s *Writer) Mkdirp(_ string) error {
	return nil
}

// NewFileFromPost builds the storage.File metadata for post:
//   - post.URL becomes the served Path;
//   - FSPath is derived from post.URL with pathNameToFileName;
//   - post.Date becomes the modification time.
//
// The returned file has an empty, non-nil Encodings map.
func (s *Writer) NewFileFromPost(post *content.Post) *storage.File {
	url := post.URL

	return &storage.File{
		Path:         url,
		FSPath:       pathNameToFileName(url),
		Title:        post.Title,
		LastModified: post.Date,
		Encodings:    make(map[string]*buffer.Buffer),
	}
}

// WritePost stores content, the rendered body of post, under the metadata
// produced by NewFileFromPost.
func (s *Writer) WritePost(post *content.Post, content *buffer.Buffer) error {
	s.log.Debug("storing post", "title", post.Title)
	file := s.NewFileFromPost(post)

	return s.WriteFile(file, content)
}

// Write stores content under pathname with the given title and the current
// time as its modification time.
//
// The served Path is cleanPathName(pathname). The raw pathname is kept as
// FSPath, from which WriteFile derives the content type.
func (s *Writer) Write(pathname string, title string, content *buffer.Buffer) error {
	now := time.Now()
	file := &storage.File{
		Path:         cleanPathName(pathname),
		FSPath:       pathname,
		Title:        title,
		LastModified: now,
		Encodings:    map[string]*buffer.Buffer{},
	}

	return s.WriteFile(file, content)
}

// WriteFile persists file and its body, content, in the database.
//
// The rows are written in this order:
//  1. the URL;
//  2. the file metadata;
//  3. the body under the "identity" encoding;
//  4. when Options.Compress is set, one compressed copy per entry in
//     encodings.
//
// file is updated in place:
//   - Encodings["identity"] is set to content.
//   - An empty ContentType is guessed from FSPath.
//   - An empty Etag is computed from the body.
//   - CalculateStyleHash is run before the metadata is stored.
func (s *Writer) WriteFile(file *storage.File, content *buffer.Buffer) error {
	s.log.Debug("storing content", "pathname", file.Path)

	urlID, err := s.storeURL(file.Path)
	if err != nil {
		return fault.Wrap(err, fmsg.With("storing URL"))
	}

	if file.Encodings == nil {
		file.Encodings = make(map[string]*buffer.Buffer)
	}
	file.Encodings["identity"] = content

	if len(file.ContentType) == 0 {
		file.ContentType = contentType(file.FSPath)
	}

	if len(file.Etag) == 0 {
		tag, tagErr := etag(content.Bytes())
		if tagErr != nil {
			return fault.Wrap(tagErr, fmsg.With("could not calculate file etag"))
		}
		file.Etag = tag
	}

	if err := content.SeekStart(); err != nil {
		return fault.Wrap(err, fmsg.With("seeking content start"))
	}

	if err := file.CalculateStyleHash(); err != nil {
		return fault.Wrap(err, fmsg.With("calculating file hash"))
	}

	fileID, err := s.storeFile(urlID, file)
	if err != nil {
		return fault.Wrap(err, fmsg.With("storing file"))
	}

	if err := s.storeEncoding(fileID, "identity", content.Bytes()); err != nil {
		return err
	}

	if !s.options.Compress {
		return nil
	}

	for _, enc := range encodings {
		compressed, err := compress(enc, content)
		if err != nil {
			return fault.Wrap(err, fmsg.With("compressing file"))
		}
		if err := s.storeEncoding(fileID, enc, compressed.Bytes()); err != nil {
			return err
		}
	}

	return nil
}

// compress encodes the whole of content with the named encoding ("gzip",
// "br" or "zstd") and returns a new buffer holding the compressed stream.
//
// content is rewound to its start before it is read. Any other encoding name
// returns an error. The compressor is closed explicitly, so the returned
// buffer contains the fully flushed stream and any error from that final
// flush is reported.
func compress(encoding string, content *buffer.Buffer) (*buffer.Buffer, error) {
	if err := content.SeekStart(); err != nil {
		return nil, fault.Wrap(err, fmsg.With("seeking to start of content buffer"))
	}

	compressed := new(buffer.Buffer)

	var w io.WriteCloser
	switch encoding {
	case "gzip":
		w = gzip.NewWriter(compressed)
	case "br":
		w = brotli.NewWriter(compressed)
	case "zstd":
		zw, err := zstd.NewWriter(compressed)
		if err != nil {
			return nil, fault.Wrap(err, fmsg.With("could not create zstd writer"))
		}
		w = zw
	default:
		// Without this case w would stay nil and closing it would panic.
		return nil, fault.Wrap(fmt.Errorf("unsupported encoding %q", encoding))
	}

	if _, err := io.Copy(w, content); err != nil {
		_ = w.Close() // still close the writer; the copy error is the one to report
		return nil, fault.Wrap(err, fmsg.With("compressing file"))
	}

	// Close writes the final block/trailer into compressed. Its error must be
	// checked, not discarded as a deferred call would do.
	if err := w.Close(); err != nil {
		return nil, fault.Wrap(err, fmsg.With("finishing compressed stream"))
	}

	return compressed, nil
}
// storeURL inserts path through InsertURL and returns the ID that query
// reports for it.
func (s *Writer) storeURL(path string) (int64, error) {
	ctx := context.TODO()

	urlID, err := s.queries.InsertURL(ctx, path)
	if err == nil {
		return urlID, nil
	}

	msg := fmt.Sprintf("inserting URL %s into database", path)

	return 0, fault.Wrap(err, fmsg.With(msg))
}

// storeFile inserts the metadata row for file under urlID and returns the new
// file ID.
//
// If file.ContentType is still empty, it is sniffed from the "identity" body
// with http.DetectContentType and written back to file, and a warning is
// logged. This relies on the "identity" encoding being present.
//
// Headers are stored as JSON, or as an empty byte slice when file.Headers is
// nil.
func (s *Writer) storeFile(urlID int64, file *storage.File) (int64, error) {
	if file.ContentType == "" {
		identity := file.Encodings["identity"]
		file.ContentType = http.DetectContentType(identity.Bytes())
		s.log.Warn("file has no content type, sniffing",
			"path", file.Path,
			"sniffed", file.ContentType,
		)
	}

	headers := []byte{}
	if file.Headers != nil {
		encoded, err := json.Marshal(file.Headers)
		if err != nil {
			return 0, fault.Wrap(err, fmsg.With("marshalling headers to JSON"))
		}
		headers = encoded
	}

	fileID, err := s.queries.InsertFile(context.TODO(), db.InsertFileParams{
		UrlID:        urlID,
		ContentType:  file.ContentType,
		LastModified: file.LastModified.Unix(),
		Etag:         file.Etag,
		Title:        file.Title,
		Headers:      headers,
	})
	if err != nil {
		return 0, fault.Wrap(err, fmsg.With("inserting file into database"))
	}

	return fileID, nil
}

// storeEncoding inserts data as the body of file fileID under the given
// encoding name. WriteFile passes "identity" or an entry from encodings.
func (s *Writer) storeEncoding(fileID int64, encoding string, data []byte) error {
	params := db.InsertContentParams{
		Fileid:   fileID,
		Encoding: encoding,
		Body:     data,
	}

	if err := s.queries.InsertContent(context.TODO(), params); err != nil {
		msg := fmt.Sprintf(
			"inserting encoding into database file_id: %d encoding: %s",
			fileID,
			encoding,
		)
		return fault.Wrap(err, fmsg.With(msg))
	}

	return nil
}

// etag returns a weak HTTP entity tag for content, W/"<hex>", where <hex> is
// the 64-bit FNV-1a hash of content as 16 zero-padded lowercase hex digits.
func etag(content []byte) (string, error) {
	h := fnv.New64a()
	if _, err := h.Write(content); err != nil {
		return "", fault.Wrap(err)
	}

	// %016x of Sum64 matches %x of the 8 big-endian bytes from Sum(nil).
	return fmt.Sprintf(`W/"%016x"`, h.Sum64()), nil
}

// contentType guesses a MIME type from the extension of the file name that
// pathname maps to via pathNameToFileName. It returns "" when the extension
// is absent or unknown to the mime package.
func contentType(pathname string) string {
	fileName := pathNameToFileName(pathname)
	ext := filepath.Ext(fileName)

	return mime.TypeByExtension(ext)
}