hscloud/games/factorio/modproxy/main.go
Sergiusz Bazanski 0581bbf8a0 games/factorio: add modproxy
This adds a mod proxy system, called, well, modproxy.

It sits between Factorio server instances and the Factorio mod portal,
allowing for arbitrary mod download without needing the servers to know
Factorio credentials.

Change-Id: I7bc405a25b6f9559cae1f23295249f186761f212
2020-08-14 13:03:46 +02:00

298 lines
6.2 KiB
Go

package main
import (
"context"
"crypto/sha1"
"encoding/hex"
"flag"
"fmt"
"io"
"os"
"regexp"
"strings"
"sync"
"time"
"code.hackerspace.pl/hscloud/go/mirko"
"github.com/golang/glog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"code.hackerspace.pl/hscloud/games/factorio/modproxy/modportal"
pb "code.hackerspace.pl/hscloud/games/factorio/modproxy/proto"
)
// init makes glog log to stderr by default. Because init runs before main
// calls flag.Parse, a -logtostderr value given on the command line still
// takes precedence.
func init() {
	// The error is deliberately ignored: failing to change a logging default
	// must not prevent the proxy from starting.
	_ = flag.Set("logtostderr", "true")
}
// flagCASDirectory is the directory of the content-addressed store (CAS):
// each mirrored file lives at <flagCASDirectory>/<lowercase sha1>. It is set
// from the -cas_directory flag in main.
var flagCASDirectory string
// main parses flags, makes sure the CAS directory exists, registers the
// ModProxy gRPC service with mirko and serves until mirko reports that it
// is done.
func main() {
	flag.StringVar(&flagCASDirectory, "cas_directory", "cas", "directory in which to store cached files")
	flag.Parse()

	// Mirror writes files into the CAS with os.Create, which does not create
	// missing parent directories. Create the directory up front so that a
	// fresh deployment does not fail every download.
	if err := os.MkdirAll(flagCASDirectory, 0755); err != nil {
		glog.Exitf("MkdirAll(%q): %v", flagCASDirectory, err)
	}

	m := mirko.New()
	if err := m.Listen(); err != nil {
		glog.Exitf("Listen(): %v", err)
	}

	srv := &service{
		cache: make(map[string]*cacheEntry),
	}
	pb.RegisterModProxyServer(m.GRPC(), srv)

	if err := m.Serve(); err != nil {
		glog.Exitf("Serve(): %v", err)
	}
	<-m.Done()
}
// reSha1 accepts a non-empty run of lowercase hexadecimal digits. It does not
// enforce the 40-character length of a real sha1 digest; it only guarantees
// that the string is safe to use as a CAS file name.
var reSha1 = regexp.MustCompile(`^[a-f0-9]+$`)
// casPath returns the CAS file path for the given hex sha1. The digest is
// lowercased first. An empty string is returned when the digest is not
// purely hexadecimal, which also keeps path separators out of the result.
func casPath(sha1 string) string {
	normalized := strings.ToLower(sha1)
	if reSha1.MatchString(normalized) {
		return fmt.Sprintf("%s/%s", flagCASDirectory, normalized)
	}
	return ""
}
// service is the implementation registered as the ModProxy gRPC server.
type service struct {
	// mu guards cache.
	mu sync.Mutex
	// cache records, per file sha1, what is known about that file.
	cache map[string]*cacheEntry
}
// cacheEntry is the service's knowledge about a single file (one sha1).
type cacheEntry struct {
	// expires is an optional expiry time checked by cacheGet; nil means the
	// entry never expires.
	expires *time.Time
	// modName is the mod the file was requested for.
	modName string
	// found is set when the mod portal has confirmed that modName has a
	// release with this sha1.
	found bool
	// mirrored is set when the file is believed to be present in the CAS and
	// ready to be served to clients.
	mirrored bool
}
// Mirror downloads into the CAS every file that the cache marks as found on
// the mod portal but not yet mirrored. The mod portal credentials are taken
// from the request (req.Username, req.Token).
//
// Failures of individual files do not fail the RPC. Each file is reported
// under the key "modName/sha1": successes are listed in ModsOkay, and
// failures are listed in ModsErrors together with their error text.
func (s *service) Mirror(ctx context.Context, req *pb.MirrorRequest) (*pb.MirrorResponse, error) {
	// Snapshot sha1 -> modName of pending downloads, so that the lock is not
	// held across network and disk I/O.
	pending := make(map[string]string)
	s.mu.Lock()
	for sha, e := range s.cache {
		if e == nil || !e.found || e.mirrored {
			continue
		}
		pending[sha] = e.modName
	}
	s.mu.Unlock()

	res := &pb.MirrorResponse{
		ModsErrors: make(map[string]string),
	}
	for sha, modName := range pending {
		k := fmt.Sprintf("%s/%s", modName, sha)
		if err := mirrorRelease(ctx, req, sha, modName); err != nil {
			glog.Errorf("Could not download %q: %v", k, err)
			res.ModsErrors[k] = err.Error()
			continue
		}
		glog.Infof("Downloaded %q", k)
		res.ModsOkay = append(res.ModsOkay, k)
		s.cacheFeed(sha, modName, nil, true, true)
	}
	return res, nil
}

// mirrorRelease fetches the release of modName whose sha1 is sha from the mod
// portal and commits it to the CAS.
//
// The data is written to "<casPath>.incoming". That file is renamed into
// place only after it has been fully written, closed and sha1-verified, so a
// partial or corrupt download never appears at the CAS path. The temporary
// file is removed on failure.
func mirrorRelease(ctx context.Context, req *pb.MirrorRequest, sha, modName string) error {
	path := casPath(sha)
	if path == "" {
		return fmt.Errorf("invalid sha1")
	}

	mod, err := modportal.GetMod(ctx, modName)
	if err != nil {
		return err
	}
	release := mod.ReleaseBySHA1(sha)
	if release == nil {
		return fmt.Errorf("could not find sha1 in modportal - deleted?")
	}
	r, err := release.Download(ctx, req.Username, req.Token)
	if err != nil {
		return fmt.Errorf("could not download: %w", err)
	}
	// NOTE(review): Download's concrete return type is not visible here. If it
	// is closable (e.g. an HTTP body), close it so the underlying connection
	// is released.
	src := io.Reader(r)
	if c, ok := src.(io.Closer); ok {
		defer c.Close()
	}

	pathIncoming := path + ".incoming"
	out, err := os.Create(pathIncoming)
	if err != nil {
		return fmt.Errorf("could not create file: %w", err)
	}
	// abandon discards the partially written temporary file. Errors are
	// ignored: this is best-effort cleanup on a path that already failed.
	abandon := func() {
		_ = out.Close()
		_ = os.Remove(pathIncoming)
	}

	hash := sha1.New()
	if _, err := io.Copy(io.MultiWriter(out, hash), src); err != nil {
		abandon()
		return fmt.Errorf("could not save: %w", err)
	}
	// Close before renaming, so that delayed write errors are not lost.
	if err := out.Close(); err != nil {
		_ = os.Remove(pathIncoming)
		return fmt.Errorf("could not save: %w", err)
	}
	// Verify the digest before committing. serve trusts the CAS entry, and the
	// cache is marked mirrored after success, so a bad file would never be
	// re-fetched.
	if got := hex.EncodeToString(hash.Sum(nil)); !strings.EqualFold(got, sha) {
		_ = os.Remove(pathIncoming)
		return fmt.Errorf("downloaded file has sha1 %s, want %s", got, sha)
	}
	if err := os.Rename(pathIncoming, path); err != nil {
		_ = os.Remove(pathIncoming)
		return fmt.Errorf("could not commit file: %w", err)
	}
	return nil
}
// cacheGet looks up the cache entry for fileSHA1.
//
// hit reports whether a live entry exists. found and mirrored are that
// entry's flags, and all three results are false on a miss. An entry whose
// expiry time has been reached is evicted and reported as a miss. Entries
// without an expiry never expire.
func (s *service) cacheGet(fileSHA1 string) (hit, found, mirrored bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	entry, ok := s.cache[fileSHA1]
	if !ok || entry == nil {
		return false, false, false
	}
	// Evict once "now" is at or past the expiry instant. The previous check
	// was inverted: it evicted entries that were still valid and kept
	// expired ones forever.
	if entry.expires != nil && !time.Now().Before(*entry.expires) {
		delete(s.cache, fileSHA1)
		return false, false, false
	}
	return true, entry.found, entry.mirrored
}
// cacheFeed stores a fresh entry for key, replacing anything recorded for it
// before. A nil expires makes the entry permanent.
func (s *service) cacheFeed(key, modName string, expires *time.Time, found, mirrored bool) {
	entry := &cacheEntry{
		modName:  modName,
		found:    found,
		mirrored: mirrored,
		expires:  expires,
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.cache[key] = entry
}
// serve streams the CAS copy of req.FileSha1 to the client.
//
// If the file cannot be opened from the CAS, the cache entry is marked as
// found but not mirrored, and a single STATUS_NOT_AVAILABLE response is sent.
// Otherwise a STATUS_OKAY response is sent, followed by the contents in
// Chunk responses of up to 1 MiB each.
//
// After the last chunk, the sha1 of the sent data is compared with
// req.FileSha1. On mismatch, an Aborted "CAS corruption" error is returned.
// The client has already received the data by then, so it should check the
// final status before trusting the file.
func (s *service) serve(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
	cas := casPath(req.FileSha1)
	if cas == "" {
		return status.Error(codes.Aborted, "invalid sha1")
	}

	file, err := os.Open(cas)
	if err != nil {
		// Not in the CAS: remember that, so Mirror will fetch it.
		s.cacheFeed(req.FileSha1, req.ModName, nil, true, false)
		return srv.Send(&pb.DownloadResponse{
			Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
		})
	}
	defer file.Close()

	if err := srv.Send(&pb.DownloadResponse{
		Status: pb.DownloadResponse_STATUS_OKAY,
	}); err != nil {
		return err
	}

	const chunkSize = 1024 * 1024
	hash := sha1.New()
	for {
		// Allocate a fresh buffer per chunk. gRPC does not allow a message to
		// be modified after Send, and Chunk aliases this buffer.
		buf := make([]byte, chunkSize)
		n, err := file.Read(buf)
		// Per the io.Reader contract, consume the n > 0 bytes before looking
		// at err, which may be io.EOF on the same call.
		if n > 0 {
			hash.Write(buf[:n])
			if err := srv.Send(&pb.DownloadResponse{
				Chunk: buf[:n],
			}); err != nil {
				return err
			}
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			return status.Errorf(codes.Unavailable, "error reading file: %v", err)
		}
	}

	// The entire file has been sent; double-check the digest.
	sum := hex.EncodeToString(hash.Sum(nil))
	if !strings.EqualFold(sum, req.FileSha1) {
		glog.Errorf("CAS corruption: wanted %q, got %q", req.FileSha1, sum)
		return status.Error(codes.Aborted, "CAS corruption")
	}
	return nil
}
// Download streams a mod file, identified by req.ModName and req.FileSha1,
// to the client.
//
// The first response carries a status:
//   - STATUS_OKAY is followed by the file contents in Chunk responses.
//   - STATUS_NOT_AVAILABLE is sent alone. It means the mod portal lists the
//     file but it is not in the local CAS yet; Mirror is what fetches it.
//
// A sha1 that the mod portal does not list for the mod is rejected with
// codes.NotFound, whether the answer came from the cache or from the portal.
// Missing ModName or FileSha1 yields codes.InvalidArgument.
//
// Answers are cached per sha1. A negative answer from the mod portal is
// cached with a one-minute expiry.
func (s *service) Download(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
	modName := req.ModName
	if modName == "" {
		return status.Error(codes.InvalidArgument, "mod name must be set")
	}
	if req.FileSha1 == "" {
		return status.Error(codes.InvalidArgument, "sha1 must be set")
	}
	// Normalize case so that the cache key, the CAS path and serve's checksum
	// comparison agree. serve reads the normalized value back from req.
	fileSHA1 := strings.ToLower(req.FileSha1)
	req.FileSha1 = fileSHA1

	if hit, found, mirrored := s.cacheGet(fileSHA1); hit {
		switch {
		case !found:
			return status.Error(codes.NotFound, "sha1 not found for mod")
		case !mirrored:
			return srv.Send(&pb.DownloadResponse{
				Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
			})
		default:
			return s.serve(req, srv)
		}
	}

	// Cache miss: ask the mod portal whether this mod has a release with
	// this sha1.
	mod, err := modportal.GetMod(srv.Context(), modName)
	if err != nil {
		return err
	}
	if mod.ReleaseBySHA1(fileSHA1) == nil {
		expires := time.Now().Add(1 * time.Minute)
		s.cacheFeed(fileSHA1, modName, &expires, false, false)
		// Same code as the cached negative answer above, so clients see a
		// consistent status regardless of cache state.
		return status.Error(codes.NotFound, "sha1 not found for mod")
	}

	// Optimistically assume the file is already mirrored and serve it
	// directly. If it is missing from the CAS, serve downgrades the cache
	// entry to not-mirrored and replies STATUS_NOT_AVAILABLE.
	s.cacheFeed(fileSHA1, modName, nil, true, true)
	return s.serve(req, srv)
}