// Command modproxy runs a gRPC ModProxy service. It answers Download requests
// from a local content-addressed store (CAS) directory keyed by SHA-1, and
// fills that directory from the mod portal when Mirror is called.
package main

import (
	"context"
	"crypto/sha1"
	"encoding/hex"
	"flag"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"sync"
	"time"

	"code.hackerspace.pl/hscloud/go/mirko"
	"github.com/golang/glog"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"code.hackerspace.pl/hscloud/games/factorio/modproxy/modportal"
	pb "code.hackerspace.pl/hscloud/games/factorio/modproxy/proto"
)

func init() {
	flag.Set("logtostderr", "true")
}

var (
	// flagCASDirectory is the directory holding mirrored files, one file per
	// SHA-1, named after the lowercase hex digest.
	flagCASDirectory string
)

// main parses flags, starts the mirko listeners, registers the ModProxy
// service on the mirko gRPC server, serves, and then waits on m.Done().
func main() {
	flag.StringVar(&flagCASDirectory, "cas_directory", "cas", "directory in which to store cached files")
	flag.Parse()

	m := mirko.New()
	if err := m.Listen(); err != nil {
		glog.Exitf("Listen(): %v", err)
	}

	srv := &service{
		cache: make(map[string]*cacheEntry),
	}
	pb.RegisterModProxyServer(m.GRPC(), srv)

	if err := m.Serve(); err != nil {
		glog.Exitf("Serve(): %v", err)
	}

	<-m.Done()
}

var (
	// reSha1 matches a non-empty string of lowercase hex digits. It does not
	// enforce the 40-character length of a SHA-1 digest.
	reSha1 = regexp.MustCompile(`^[a-f0-9]+$`)
)

// casPath returns the path of the CAS file for the given hex digest, or ""
// if the (lowercased) digest contains anything other than hex digits. The
// hex-only check is what keeps request-supplied digests from escaping the
// CAS directory.
func casPath(sha string) string {
	sha = strings.ToLower(sha)
	if !reSha1.MatchString(sha) {
		return ""
	}
	// filepath.Join (rather than "%s/%s") so that an empty directory flag
	// yields a relative path instead of a path in the filesystem root.
	return filepath.Join(flagCASDirectory, sha)
}

// service implements the ModProxy gRPC server.
type service struct {
	// mu guards cache.
	mu sync.Mutex
	// cache of sha1 -> cache entry
	cache map[string]*cacheEntry
}

// cacheEntry is what the service remembers about a single sha1.
type cacheEntry struct {
	// expires is the instant after which the entry must be ignored and
	// dropped; nil means the entry never expires.
	expires *time.Time
	// modName is the mod name that was requested together with this sha1.
	modName string
	// found means that this is an entry confirmed on the mod portal
	found bool
	// mirrored means we are ready to serve this file to users
	mirrored bool
}

// Mirror downloads every cached entry that is known to the mod portal but not
// yet mirrored, using the credentials from req, and stores it in the CAS.
//
// Per-mod failures are reported in the response (ModsErrors, keyed by
// "modName/sha1") and logged; the returned error is always nil.
func (s *service) Mirror(ctx context.Context, req *pb.MirrorRequest) (*pb.MirrorResponse, error) {
	// Build map of sha1->modName for needed downloads while holding the lock,
	// so that the downloads themselves run without it.
	modNames := make(map[string]string)
	s.mu.Lock()
	for sha, e := range s.cache {
		if e == nil || !e.found || e.mirrored {
			continue
		}
		modNames[sha] = e.modName
	}
	s.mu.Unlock()

	res := &pb.MirrorResponse{
		ModsErrors: make(map[string]string),
	}
	for sha, modName := range modNames {
		k := fmt.Sprintf("%s/%s", modName, sha)
		if err := s.mirrorOne(ctx, req, sha, modName); err != nil {
			glog.Errorf("Could not download %q: %v", k, err)
			res.ModsErrors[k] = err.Error()
			continue
		}
		glog.Infof("Downloaded %q", k)
		res.ModsOkay = append(res.ModsOkay, k)
	}
	return res, nil
}

// mirrorOne fetches the release of modName identified by sha from the mod
// portal, commits it to the CAS and marks the cache entry as mirrored.
func (s *service) mirrorOne(ctx context.Context, req *pb.MirrorRequest, sha, modName string) error {
	mod, err := modportal.GetMod(ctx, modName)
	if err != nil {
		return err
	}
	release := mod.ReleaseBySHA1(sha)
	if release == nil {
		return fmt.Errorf("could not find sha1 in modportal - deleted?")
	}

	// Resolve the destination before downloading: an empty path would
	// otherwise make us write to a stray ".incoming" file.
	path := casPath(sha)
	if path == "" {
		return fmt.Errorf("invalid sha1 %q", sha)
	}

	r, err := release.Download(ctx, req.Username, req.Token)
	if err != nil {
		return fmt.Errorf("could not download: %v", err)
	}
	// Release the download stream if it is closable, so that repeated Mirror
	// calls do not accumulate open connections.
	if c, ok := io.Reader(r).(io.Closer); ok {
		defer c.Close()
	}

	if err := writeCAS(path, r); err != nil {
		return err
	}
	s.cacheFeed(sha, modName, nil, true, true)
	return nil
}

// writeCAS streams r into path+".incoming" and renames the result to path
// once it has been written and closed successfully. On any failure after the
// temporary file was created, the temporary file is removed.
func writeCAS(path string, r io.Reader) error {
	pathIncoming := path + ".incoming"
	out, err := os.Create(pathIncoming)
	if err != nil {
		return fmt.Errorf("could not create file: %v", err)
	}
	if _, err := io.Copy(out, r); err != nil {
		// Best-effort cleanup; the copy error is the one worth reporting.
		out.Close()
		os.Remove(pathIncoming)
		return fmt.Errorf("could not save: %v", err)
	}
	// Close before renaming: a failed close can mean the data never made it
	// to disk, and such a file must not be committed.
	if err := out.Close(); err != nil {
		os.Remove(pathIncoming)
		return fmt.Errorf("could not save: %v", err)
	}
	if err := os.Rename(pathIncoming, path); err != nil {
		os.Remove(pathIncoming)
		return fmt.Errorf("could not commit file: %v", err)
	}
	return nil
}

// cacheGet looks up sha in the cache.
//
// hit reports whether a live entry exists; found and mirrored are copied from
// that entry and are false when hit is false. An entry whose expiry time has
// passed is removed and reported as a miss.
func (s *service) cacheGet(sha string) (hit, found, mirrored bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	entry, ok := s.cache[sha]
	if !ok || entry == nil {
		return false, false, false
	}
	// Drop the entry only once its deadline is behind us.
	if entry.expires != nil && time.Now().After(*entry.expires) {
		delete(s.cache, sha)
		return false, false, false
	}
	return true, entry.found, entry.mirrored
}

// cacheFeed stores (or overwrites) the cache entry for sha. A nil expires
// makes the entry permanent.
func (s *service) cacheFeed(sha, modName string, expires *time.Time, found, mirrored bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.cache[sha] = &cacheEntry{
		expires:  expires,
		modName:  modName,
		found:    found,
		mirrored: mirrored,
	}
}

// serve streams the CAS file for req.FileSha1 to the client.
//
// If the file cannot be opened, the cache entry is updated to "found but not
// mirrored" and a single STATUS_NOT_AVAILABLE response is sent. Otherwise a
// STATUS_OKAY response is sent, followed by the file contents in chunks of up
// to 1 MiB. After the last chunk the SHA-1 of the streamed data is compared
// with req.FileSha1 and a mismatch is reported as an Aborted error.
func (s *service) serve(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
	cas := casPath(req.FileSha1)
	if cas == "" {
		// Invalid sha1? Fail.
		return status.Error(codes.Aborted, "invalid sha1")
	}

	file, err := os.Open(cas)
	if err != nil {
		// not in CAS, update cache and fail
		s.cacheFeed(req.FileSha1, req.ModName, nil, true, false)
		return srv.Send(&pb.DownloadResponse{
			Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
		})
	}
	defer file.Close()

	err = srv.Send(&pb.DownloadResponse{
		Status: pb.DownloadResponse_STATUS_OKAY,
	})
	if err != nil {
		return err
	}

	buf := make([]byte, 1024*1024)
	hash := sha1.New()
	for {
		n, rerr := file.Read(buf)
		// Handle the data before the error: a Read may return bytes together
		// with io.EOF, and those bytes must still be hashed and sent.
		if n > 0 {
			hash.Write(buf[:n])
			if err := srv.Send(&pb.DownloadResponse{Chunk: buf[:n]}); err != nil {
				return err
			}
		}
		if rerr == io.EOF {
			break
		}
		if rerr != nil {
			return status.Errorf(codes.Unavailable, "error reading file: %v", rerr)
		}
	}

	// entire file sent, double-check shasum
	sum := hex.EncodeToString(hash.Sum(nil))
	if sum != req.FileSha1 {
		glog.Errorf("CAS corruption: wanted %q, got %q", req.FileSha1, sum)
		return status.Error(codes.Aborted, "CAS corruption")
	}
	return nil
}

// Download serves the file identified by req.FileSha1 for mod req.ModName.
//
// Both fields are required; req.FileSha1 is lowercased in place. On a cache
// hit the cached verdict is used directly. On a miss the mod portal is asked
// whether the mod has a release with that sha1: if not, the miss is cached
// for one minute and NotFound is returned; if so, the entry is cached as
// mirrored and the file is served from the CAS.
func (s *service) Download(req *pb.DownloadRequest, srv pb.ModProxy_DownloadServer) error {
	ctx := srv.Context()

	modName := req.ModName
	if modName == "" {
		return status.Error(codes.InvalidArgument, "mod name must be set")
	}
	if req.FileSha1 == "" {
		return status.Error(codes.InvalidArgument, "sha1 must be set")
	}
	sha := strings.ToLower(req.FileSha1)
	req.FileSha1 = sha

	cacheHit, found, mirrored := s.cacheGet(sha)
	if cacheHit {
		if !found {
			return status.Error(codes.NotFound, "sha1 not found for mod")
		}
		if !mirrored {
			return srv.Send(&pb.DownloadResponse{
				Status: pb.DownloadResponse_STATUS_NOT_AVAILABLE,
			})
		}
		// we have the file, serve it
		return s.serve(req, srv)
	}

	// cache not hit, check mod portal
	mod, err := modportal.GetMod(ctx, modName)
	if err != nil {
		return err
	}
	release := mod.ReleaseBySHA1(sha)
	if release == nil {
		// release not found in mod portal, cache and answer with the same
		// status code the cached path uses for this condition
		expires := time.Now().Add(1 * time.Minute)
		s.cacheFeed(sha, modName, &expires, false, false)
		return status.Error(codes.NotFound, "sha1 not found for mod")
	}

	// we assume it's mirrored - the first cas serve will prove us wrong
	// otherwise and update the cache.
	s.cacheFeed(sha, modName, nil, true, true)
	return s.serve(req, srv)
}