mirror of
https://gerrit.hackerspace.pl/hscloud
synced 2024-10-19 04:27:45 +00:00
299 lines
6.2 KiB
Go
299 lines
6.2 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/sha1"
|
||
|
"encoding/hex"
|
||
|
"flag"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"os"
|
||
|
"regexp"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"code.hackerspace.pl/hscloud/go/mirko"
|
||
|
"github.com/golang/glog"
|
||
|
"google.golang.org/grpc/codes"
|
||
|
"google.golang.org/grpc/status"
|
||
|
|
||
|
"code.hackerspace.pl/hscloud/games/factorio/modproxy/modportal"
|
||
|
pb "code.hackerspace.pl/hscloud/games/factorio/modproxy/proto"
|
||
|
)
|
||
|
|
||
|
// init makes glog log to stderr by default. The -logtostderr flag is
// registered by glog, whose package initialization runs before this function
// because main imports it. A value given on the command line still wins,
// since flag.Parse is only called later, in main.
func init() {
	// flag.Set only fails for an unregistered flag name; there is nothing
	// useful to do about that here, so the error is deliberately dropped.
	_ = flag.Set("logtostderr", "true")
}
|
||
|
|
||
|
// flagCASDirectory is the directory that holds the content-addressed store of
// mod files: casPath places each file directly in it, named by its SHA-1.
// It is populated from the -cas_directory flag in main.
var flagCASDirectory string
|
||
|
|
||
|
func main() {
|
||
|
flag.StringVar(&flagCASDirectory, "cas_directory", "cas", "directory in which to store cached files")
|
||
|
flag.Parse()
|
||
|
m := mirko.New()
|
||
|
if err := m.Listen(); err != nil {
|
||
|
glog.Exitf("Listen(): %v", err)
|
||
|
}
|
||
|
|
||
|
srv := &service{
|
||
|
cache: make(map[string]*cacheEntry),
|
||
|
}
|
||
|
|
||
|
pb.RegisterModProxyServer(m.GRPC(), srv)
|
||
|
|
||
|
if err := m.Serve(); err != nil {
|
||
|
glog.Exitf("Serve(): %v", err)
|
||
|
}
|
||
|
|
||
|
<-m.Done()
|
||
|
}
|
||
|
|
||
|
// reSha1 matches a lower-case, hex-encoded SHA-1 digest: exactly 40 hex
// characters. casPath relies on it so that client-supplied digests can never
// be turned into anything but a plain file name inside the CAS directory.
var reSha1 = regexp.MustCompile(`^[a-f0-9]{40}$`)
|
||
|
|
||
|
func casPath(sha1 string) string {
|
||
|
sha1 = strings.ToLower(sha1)
|
||
|
if !reSha1.MatchString(sha1) {
|
||
|
return ""
|
||
|
}
|
||
|
return fmt.Sprintf("%s/%s", flagCASDirectory, sha1)
|
||
|
}
|
||
|
|
||
|
type service struct {
|
||
|
mu sync.Mutex
|
||
|
|
||
|
// cache of sha1 -> cache entry
|
||
|
cache map[string]*cacheEntry
|
||
|
}
|
||
|
|
||
|
// cacheEntry records what the proxy knows about a single mod file.
type cacheEntry struct {
	// expires, if non-nil, is when this entry is meant to lapse (Download sets
	// it for negative mod-portal lookups). cacheGet only ever evicts entries
	// whose expires is non-nil.
	expires *time.Time
	// modName is the mod the file belongs to; Mirror uses it to look the
	// release up on the mod portal.
	modName string

	// found is true once the file has been confirmed to exist on the mod portal.
	found bool
	// mirrored is true when the file is considered ready to be served to users
	// from the local CAS.
	mirrored bool
}
|
||
|
|
||
|
func (s *service) Mirror(ctx context.Context, req *pb.MirrorRequest) (*pb.MirrorResponse, error) {
|
||
|
|
||
|
// build map of sha1->modName for needed downloads
|
||
|
modNames := make(map[string]string)
|
||
|
s.mu.Lock()
|
||
|
for sha, e := range s.cache {
|
||
|
if e == nil {
|
||
|
continue
|
||
|
}
|
||
|
if e.found == false {
|
||
|
continue
|
||
|
}
|
||
|
if e.mirrored == true {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
modNames[sha] = e.modName
|
||
|
}
|
||
|
s.mu.Unlock()
|
||
|
|
||
|
okays := make(map[string]bool)
|
||
|
errors := make(map[string]error)
|
||
|
|
||
|
for sha, modName := range modNames {
|
||
|
k := fmt.Sprintf("%s/%s", modName, sha)
|
||
|
mod, err := modportal.GetMod(ctx, modName)
|
||
|
if err != nil {
|
||
|
errors[k] = err
|
||
|
continue
|
||
|
}
|
||
|
release := mod.ReleaseBySHA1(sha)
|
||
|
if release == nil {
|
||
|
errors[k] = fmt.Errorf("could not find sha1 in modportal - deleted?")
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
r, err := release.Download(ctx, req.Username, req.Token)
|
||
|
if err != nil {
|
||
|
errors[k] = fmt.Errorf("could not download: %v", err)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
path := casPath(sha)
|
||
|
pathIncoming := path + ".incoming"
|
||
|
|
||
|
out, err := os.Create(pathIncoming)
|
||
|
if err != nil {
|
||
|
errors[k] = fmt.Errorf("could not create file: %v", err)
|
||
|
continue
|
||
|
}
|
||
|
_, err = io.Copy(out, r)
|
||
|
if err != nil {
|
||
|
errors[k] = fmt.Errorf("could not save: %v", err)
|
||
|
continue
|
||
|
}
|
||
|
err = os.Rename(pathIncoming, path)
|
||
|
if err != nil {
|
||
|
errors[k] = fmt.Errorf("could not commit file: %v", err)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
okays[k] = true
|
||
|
s.cacheFeed(sha, modName, nil, true, true)
|
||
|
}
|
||
|
|
||
|
res := &pb.MirrorResponse{
|
||
|
ModsErrors: make(map[string]string),
|
||
|
}
|
||
|
for m, _ := range okays {
|
||
|
glog.Infof("Downloaded %q", m)
|
||
|
res.ModsOkay = append(res.ModsOkay, m)
|
||
|
}
|
||
|
for m, err := range errors {
|
||
|
glog.Errorf("Could not download %q: %v", m, err)
|
||
|
res.ModsErrors[m] = fmt.Sprintf("%v", err)
|
||
|
}
|
||
|
|
||
|
return res, nil
|
||
|
}
|
||
|
|
||
|
func (s *service) cacheGet(sha1 string) (hit, found, mirrored bool) {
|
||
|
s.mu.Lock()
|
||
|
defer s.mu.Unlock()
|
||
|
|
||
|
entry, ok := s.cache[sha1]
|
||
|
if !ok || entry == nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if entry.expires != nil && time.Now().Before(*entry.expires) {
|
||
|
delete(s.cache, sha1)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
hit = true
|
||
|
found = entry.found
|
||
|
mirrored = entry.mirrored
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (s *service) cacheFeed(sha1, modName string, expires *time.Time, found, mirrored bool) {
|
||
|
s.mu.Lock()
|
||
|
defer s.mu.Unlock()
|
||
|
|
||
|
s.cache[sha1] = &cacheEntry{
|
||
|
expires: expires,
|
||
|
modName: modName,
|
||
|
found: found,
|
||
|
mirrored: mirrored,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *service) serve(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
|
||
|
cas := casPath(req.FileSha1)
|
||
|
if cas == "" {
|
||
|
// Invalid sha1? Fail.
|
||
|
return status.Error(codes.Aborted, "invalid sha1")
|
||
|
}
|
||
|
|
||
|
file, err := os.Open(cas)
|
||
|
if err != nil {
|
||
|
// not in CAS, update cache and fail
|
||
|
s.cacheFeed(req.FileSha1, req.ModName, nil, true, false)
|
||
|
return srv.Send(&pb.DownloadResponse{
|
||
|
Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
|
||
|
})
|
||
|
}
|
||
|
defer file.Close()
|
||
|
|
||
|
err = srv.Send(&pb.DownloadResponse{
|
||
|
Status: pb.DownloadResponse_STATUS_OKAY,
|
||
|
})
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
buf := make([]byte, 1024*1024)
|
||
|
hash := sha1.New()
|
||
|
|
||
|
for {
|
||
|
n, err := file.Read(buf)
|
||
|
if err == io.EOF {
|
||
|
break
|
||
|
}
|
||
|
if err != nil {
|
||
|
return status.Errorf(codes.Unavailable, "error reading file: %v", err)
|
||
|
}
|
||
|
hash.Write(buf[:n])
|
||
|
err = srv.Send(&pb.DownloadResponse{
|
||
|
Chunk: buf[:n],
|
||
|
})
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// entire file send, double-check shasum
|
||
|
sum := hex.EncodeToString(hash.Sum(nil))
|
||
|
if sum != req.FileSha1 {
|
||
|
glog.Errorf("CAS corruption: wanted %q, got %q", req.FileSha1, sum)
|
||
|
return status.Error(codes.Aborted, "CAS corruption")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *service) Download(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
|
||
|
ctx := srv.Context()
|
||
|
|
||
|
modName := req.ModName
|
||
|
if modName == "" {
|
||
|
return status.Error(codes.InvalidArgument, "mod name must be set")
|
||
|
}
|
||
|
sha1 := req.FileSha1
|
||
|
if sha1 == "" {
|
||
|
return status.Error(codes.InvalidArgument, "sha1 must be set")
|
||
|
}
|
||
|
sha1 = strings.ToLower(sha1)
|
||
|
req.FileSha1 = sha1
|
||
|
|
||
|
cacheHit, found, mirrored := s.cacheGet(sha1)
|
||
|
if cacheHit {
|
||
|
if !found {
|
||
|
return status.Error(codes.NotFound, "sha1 not found for mod")
|
||
|
}
|
||
|
if !mirrored {
|
||
|
return srv.Send(&pb.DownloadResponse{
|
||
|
Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// we have the file, serve it
|
||
|
return s.serve(req, srv)
|
||
|
}
|
||
|
|
||
|
// cache not hit, check mod portal
|
||
|
mod, err := modportal.GetMod(ctx, modName)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
release := mod.ReleaseBySHA1(sha1)
|
||
|
|
||
|
// release not found in mod portal, cache and answer
|
||
|
if release == nil {
|
||
|
expires := time.Now().Add(1 * time.Minute)
|
||
|
s.cacheFeed(sha1, modName, &expires, false, false)
|
||
|
return status.Error(codes.InvalidArgument, "sha1 not found for mod")
|
||
|
}
|
||
|
|
||
|
// we assume it's mirrored - the first cas serve will prove us wrong otherwise and
|
||
|
// update the cache.
|
||
|
s.cacheFeed(sha1, modName, nil, true, true)
|
||
|
// call ourselves again now that the cache is fed. computers - it's like magic!
|
||
|
return s.Download(req, srv)
|
||
|
}
|