forked from hswaw/hscloud
191 lines
4.0 KiB
Go
191 lines
4.0 KiB
Go
package main
|
|
|
|
import (
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"regexp"
|
|
|
|
"code.hackerspace.pl/hscloud/go/mirko"
|
|
"github.com/dgraph-io/ristretto"
|
|
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api"
|
|
"github.com/golang/glog"
|
|
"github.com/ulule/limiter/v3"
|
|
"github.com/ulule/limiter/v3/drivers/store/memory"
|
|
)
|
|
|
|
// init makes glog log to stderr by default. Because main calls flag.Parse
// afterwards, an explicit -logtostderr on the command line still wins. The
// error returned by flag.Set is intentionally discarded: this is only a
// default.
func init() {
	_ = flag.Set("logtostderr", "true")
}
|
|
|
|
var (
	// flagPublicListen is the listen address of the public HTTP handler.
	flagPublicListen string
	// flagTelegramToken is the Telegram Bot API token; main refuses to start
	// without it.
	flagTelegramToken string

	// reTelegram matches public request paths of the form
	// /fileid/<file ID>.<extension>. Group 1 is the file ID and group 2 the
	// extension. The dot is escaped and the pattern is anchored, so malformed
	// paths (e.g. /fileid/abc_jpg or /fileid/abc.jpg/extra) are rejected
	// instead of being split at an arbitrary character.
	reTelegram = regexp.MustCompile(`^/fileid/([a-zA-Z0-9_-]+)\.([a-z0-9]+)$`)
)
|
|
|
|
type server struct {
|
|
cache *ristretto.Cache
|
|
limiter *limiter.Limiter
|
|
tel *tgbotapi.BotAPI
|
|
}
|
|
|
|
func main() {
|
|
flag.StringVar(&flagPublicListen, "public_listen", "127.0.0.1:5000", "Listen address for public HTTP handler")
|
|
flag.StringVar(&flagTelegramToken, "telegram_token", "", "Telegram Bot API Token")
|
|
flag.Parse()
|
|
|
|
if flagTelegramToken == "" {
|
|
glog.Exitf("telegram_token must be set")
|
|
}
|
|
|
|
cache, err := ristretto.NewCache(&ristretto.Config{
|
|
NumCounters: 1e7, // number of keys to track frequency of (10M).
|
|
MaxCost: 1 << 30, // maximum cost of cache (1GB).
|
|
BufferItems: 64, // number of keys per Get buffer.
|
|
})
|
|
if err != nil {
|
|
glog.Exit(err)
|
|
}
|
|
|
|
tel, err := tgbotapi.NewBotAPI(flagTelegramToken)
|
|
if err != nil {
|
|
glog.Exitf("Error when creating telegram bot: %v", err)
|
|
}
|
|
|
|
rate, err := limiter.NewRateFromFormatted("10-M")
|
|
if err != nil {
|
|
glog.Exit(err)
|
|
}
|
|
|
|
store := memory.NewStore()
|
|
instance := limiter.New(store, rate, limiter.WithTrustForwardHeader(true))
|
|
|
|
s := &server{
|
|
cache: cache,
|
|
limiter: instance,
|
|
tel: tel,
|
|
}
|
|
|
|
m := mirko.New()
|
|
if err := m.Listen(); err != nil {
|
|
glog.Exitf("Listen(): %v", err)
|
|
}
|
|
|
|
if err := m.Serve(); err != nil {
|
|
glog.Exitf("Serve(): %v", err)
|
|
}
|
|
|
|
publicMux := http.NewServeMux()
|
|
publicMux.HandleFunc("/", s.publicHandler)
|
|
publicSrv := http.Server{
|
|
Addr: flagPublicListen,
|
|
Handler: publicMux,
|
|
}
|
|
go func() {
|
|
if err := publicSrv.ListenAndServe(); err != nil {
|
|
glog.Exitf("public ListenAndServe: %v", err)
|
|
}
|
|
}()
|
|
|
|
<-m.Done()
|
|
}
|
|
|
|
// setMime sets the Content-Type header for the extensions this proxy knows
// about: "jpg" → image/jpeg and "mp4" → video/mp4. For any other extension
// the header is left untouched.
func setMime(w http.ResponseWriter, ext string) {
	var contentType string
	switch ext {
	case "jpg":
		contentType = "image/jpeg"
	case "mp4":
		contentType = "video/mp4"
	default:
		return
	}
	w.Header().Set("Content-Type", contentType)
}
|
|
|
|
func (s *server) publicHandler(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
if !reTelegram.MatchString(r.URL.Path) {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
parts := reTelegram.FindStringSubmatch(r.URL.Path)
|
|
fileid := parts[1]
|
|
fileext := parts[2]
|
|
glog.Infof("FileID: %s", fileid)
|
|
|
|
c, ok := s.cache.Get(fileid)
|
|
if ok {
|
|
glog.Infof("Get %q - cache hit", fileid)
|
|
// cache hit
|
|
setMime(w, fileext)
|
|
w.Write(c.([]byte))
|
|
return
|
|
}
|
|
|
|
glog.Infof("Get %q - cache miss", fileid)
|
|
|
|
limit, err := s.limiter.Get(ctx, s.limiter.GetIPKey(r))
|
|
if err != nil {
|
|
w.WriteHeader(500)
|
|
fmt.Fprintf(w, ":(")
|
|
glog.Errorf("limiter.Get(%q): %v", s.limiter.GetIPKey(r), err)
|
|
return
|
|
}
|
|
|
|
if limit.Reached {
|
|
w.WriteHeader(420)
|
|
fmt.Fprintf(w, "enhance your calm")
|
|
glog.Warningf("Limit reached by %q", s.limiter.GetIPKey(r))
|
|
return
|
|
}
|
|
|
|
f, err := s.tel.GetFile(tgbotapi.FileConfig{fileid})
|
|
if err != nil {
|
|
w.WriteHeader(502)
|
|
fmt.Fprintf(w, "telegram mumbles.")
|
|
glog.Errorf("tel.GetFile(%q): %v", fileid, err)
|
|
return
|
|
}
|
|
|
|
target := f.Link(flagTelegramToken)
|
|
|
|
req, err := http.NewRequest("GET", target, nil)
|
|
if err != nil {
|
|
w.WriteHeader(500)
|
|
fmt.Fprintf(w, ":(")
|
|
glog.Errorf("NewRequest(GET, %q, nil): %v", target, err)
|
|
return
|
|
}
|
|
|
|
req = req.WithContext(ctx)
|
|
res, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
w.WriteHeader(500)
|
|
fmt.Fprintf(w, ":(")
|
|
glog.Errorf("GET(%q): %v", target, err)
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
if res.StatusCode != 200 {
|
|
// do not cache errors
|
|
w.WriteHeader(res.StatusCode)
|
|
io.Copy(w, res.Body)
|
|
return
|
|
}
|
|
|
|
b, err := ioutil.ReadAll(res.Body)
|
|
if err != nil {
|
|
w.WriteHeader(500)
|
|
fmt.Fprintf(w, ":(")
|
|
glog.Errorf("Read(%q): %v", target, err)
|
|
return
|
|
}
|
|
|
|
s.cache.Set(fileid, b, int64(len(b)))
|
|
|
|
setMime(w, fileext)
|
|
w.Write(b)
|
|
}
|