cluster/identd/kubenat: implement

This is a library to find pod information for a given TCP 4-tuple.

Change-Id: I254983e579e3aaa04c0c5491851f4af94a3f4249
This commit is contained in:
q3k 2021-05-24 15:09:25 +02:00 committed by q3k
parent ae052f0804
commit 6b649f8234
7 changed files with 896 additions and 0 deletions

View file

@ -0,0 +1,34 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
srcs = [
"kubenat.go",
"pods.go",
"translation.go",
],
importpath = "code.hackerspace.pl/hscloud/cluster/identd/kubenat",
visibility = ["//visibility:public"],
deps = [
"//cluster/identd/cri:go_default_library",
"@com_github_cenkalti_backoff//:go_default_library",
"@com_github_golang_glog//:go_default_library",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes:go_default_library",
"@org_golang_google_grpc//status:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = [
"kubenat_test.go",
"pods_test.go",
"translation_test.go",
],
embed = [":go_default_library"],
deps = [
"@com_github_go_test_deep//:go_default_library",
"@com_github_golang_glog//:go_default_library",
],
)

View file

@ -0,0 +1,130 @@
// kubenat implements a data source for undoing NAT on hosts running
// Kubernetes/containerd workloads.
//
// It parses the kernel conntrack NAT translation table to figure out the IP
// address of the pod that was making the connection.
//
// It then uses the containerd API to figure out what pod runs under what IP
// address.
//
// Both conntrack and containerd access is cached and only updated when needed.
// This means that as long as a TCP connection is open, identd will be able to
// respond about its information without having to perform any OS/containerd
// queries.
//
// Unfortunately, there is very little in terms of development/test harnesses
// for kubenat. You will have to have a locally running containerd, or do some
// mounts/forwards from a remote host.
package kubenat
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/cenkalti/backoff"
"github.com/golang/glog"
)
// Resolver is the main interface for kubenat. It runs background processing to
// update conntrack/containerd state, and resolves Tuple4s into PodInfo.
type Resolver struct {
conntrackPath string
criPath string
translationC chan *translationReq
podInfoC chan *podInfoReq
}
// Tuple4 is a 4-tuple of a TCP connection. Local describes the machine running
// this code, not the listen/connect 'ends' of TCP.
type Tuple4 struct {
RemoteIP net.IP
RemotePort uint16
LocalIP net.IP
LocalPort uint16
}
func (t *Tuple4) String() string {
local := net.JoinHostPort(t.LocalIP.String(), fmt.Sprintf("%d", t.LocalPort))
remote := net.JoinHostPort(t.RemoteIP.String(), fmt.Sprintf("%d", t.RemotePort))
return fmt.Sprintf("L: %s R: %s", local, remote)
}
// PodInfo describes a Kubernetes pod which terminates a given Tuple4 connection.
type PodInfo struct {
// PodIP is the IP address of the pod within the pod network.
PodIP net.IP
// PodTranslatedPort is the port on the PodIP corresponding to the Tuple4
// that this PodInfo was requested for.
PodTranslatedPort uint16
// KubernetesNamespace is the kubernetes namespace in which this pod is
// running.
KubernetesNamespace string
// Name is the name of the pod, as seen by kubernetes.
Name string
}
// NewResolver startss a resolver with a given path to /paroc/net/nf_conntrack
// and a CRI gRPC domain socket.
func NewResolver(ctx context.Context, conntrackPath, criPath string) (*Resolver, error) {
r := Resolver{
conntrackPath: conntrackPath,
criPath: criPath,
translationC: make(chan *translationReq),
podInfoC: make(chan *podInfoReq),
}
// TODO(q3k): bubble up errors from the translation worker into here?
go r.runTranslationWorker(ctx)
// The pod worker might fail on CRI connectivity issues, so we attempt to
// restart it with a backoff if needed.
go func() {
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
bo.Reset()
for {
err := r.runPodWorker(ctx)
if err == nil || errors.Is(err, ctx.Err()) {
glog.Infof("podWorker exiting")
return
}
glog.Errorf("podWorker failed: %v", err)
wait := bo.NextBackOff()
glog.Errorf("restarting podWorker in %v", wait)
time.Sleep(wait)
}
}()
return &r, nil
}
// ResolvePod returns information about a running pod for a given TCP 4-tuple.
// If the 4-tuple or pod cannot be resolved, an error will be returned.
func (r *Resolver) ResolvePod(ctx context.Context, t *Tuple4) (*PodInfo, error) {
// TODO(q3k): expose translation/pod not found errors as package-level
// vars, or use gRPC statuses?
podAddr, err := r.translate(ctx, t)
if err != nil {
return nil, fmt.Errorf("translate: %w", err)
}
if podAddr == nil {
return nil, fmt.Errorf("translation not found")
}
podInfo, err := r.getPodInfo(ctx, podAddr.localIP)
if err != nil {
return nil, fmt.Errorf("getPodInfo: %w", err)
}
if podInfo == nil {
return nil, fmt.Errorf("pod not found")
}
return &PodInfo{
PodIP: podAddr.localIP,
PodTranslatedPort: podAddr.localPort,
KubernetesNamespace: podInfo.namespace,
Name: podInfo.name,
}, nil
}

View file

@ -0,0 +1,43 @@
package kubenat
import (
"context"
"flag"
"net"
"testing"
)
func TestResolvePod(t *testing.T) {
t.Skip("needs containerd running on host and unhardcoded test data")
flag.Set("logtostderr", "true")
ctx, ctxC := context.WithCancel(context.Background())
defer ctxC()
r, err := NewResolver(ctx, "/tmp/conntrack", "/tmp/containerd.sock")
if err != nil {
t.Fatalf("NewResolver: %v", err)
}
pi, err := r.ResolvePod(ctx, &Tuple4{
RemoteIP: net.IPv4(185, 191, 225, 10),
RemotePort: 6697,
LocalIP: net.IPv4(185, 236, 240, 36),
LocalPort: 53449,
})
if err != nil {
t.Fatalf("ResolvePod: %v", err)
}
if want, got := net.IPv4(10, 10, 26, 23), pi.PodIP; !want.Equal(got) {
t.Errorf("Wanted pod IP %v, got %v", want, got)
}
if want, got := uint16(54782), pi.PodTranslatedPort; want != got {
t.Errorf("Wanted pod port %d, got %d", want, got)
}
if want, got := "matrix", pi.KubernetesNamespace; want != got {
t.Errorf("Wanted pod namespace %q, got %q", want, got)
}
if want, got := "appservice-irc-freenode-68977cdd5f-kfzl6", pi.Name; want != got {
t.Errorf("Wanted pod name %q, got %q", want, got)
}
}

View file

@ -0,0 +1,176 @@
package kubenat
import (
"context"
"fmt"
"net"
"github.com/golang/glog"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"code.hackerspace.pl/hscloud/cluster/identd/cri"
)
// podInfoReq is a request passed to the podWorker.
type podInfoReq struct {
local net.IP
res chan *podInfoResp
}
// podInfoResp is a response from a podWorker, sent over the res channel in a
// podInfoReq.
type podInfoResp struct {
name string
namespace string
}
// reply sends a reply to the given podInfoReq based on a CRI PodSandboxStatus,
// sending nil if the status is nil.
func (r *podInfoReq) reply(s *cri.PodSandboxStatus) {
if s == nil {
r.res <- nil
return
}
r.res <- &podInfoResp{
name: s.Metadata.Name,
namespace: s.Metadata.Namespace,
}
}
// getPodInfo performs a podInfoReq/podInfoResp exchange under a context that
// can be used to time out the query.
func (r *Resolver) getPodInfo(ctx context.Context, local net.IP) (*podInfoResp, error) {
resC := make(chan *podInfoResp, 1)
r.podInfoC <- &podInfoReq{
local: local,
res: resC,
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case res := <-resC:
return res, nil
}
}
// podStatus is a cache of data retrieved from CRI.
type podStatus struct {
// info is a map from pod sandbox ID to PodSandboxStatus as retrieved from
// CRI.
info map[string]*cri.PodSandboxStatus
// byIP is a map from pod IP (as string) to pod sandbox ID.
byIP map[string]string
}
// update performs an update of the podStatus from CRI. It only retrieves
// information about pods that it doesn't yet have, and ensures that pods which
// do not exist in CRI are also removed from podStatus.
// TODO(q3k): make sure we don't cache PodSandboxStatus too early, eg. when
// it's not yet fully running?
func (p *podStatus) update(ctx context.Context, client cri.RuntimeServiceClient) error {
res, err := client.ListPodSandbox(ctx, &cri.ListPodSandboxRequest{})
if err != nil {
return fmt.Errorf("ListPodSandbox: %w", err)
}
// set of all pod sandbox IDs in CRI.
want := make(map[string]bool)
// set of pod sandbox IDs in CRI that are not in podStatus.
missing := make(map[string]bool)
for _, item := range res.Items {
want[item.Id] = true
if _, ok := p.info[item.Id]; ok {
continue
}
missing[item.Id] = true
}
// Get information about missing pod IDs into podStatus.
for id, _ := range missing {
res, err := client.PodSandboxStatus(ctx, &cri.PodSandboxStatusRequest{
PodSandboxId: id,
})
if err != nil {
if st, ok := status.FromError(err); ok && st.Code() == codes.NotFound {
continue
} else {
return fmt.Errorf("while getting sandbox %s: %v", id, err)
}
}
p.info[id] = res.Status
}
// byIP is fully repopulated on each update.
p.byIP = make(map[string]string)
// remove is the set of pods sandbox IDs that should be removed from podStatus.
remove := make(map[string]bool)
// Populate remove and p.byId in a single pass.
for id, info := range p.info {
if _, ok := want[id]; !ok {
remove[id] = true
continue
}
if info.Network == nil {
continue
}
if info.Network.Ip == "" {
continue
}
p.byIP[info.Network.Ip] = id
}
// Remove stale pod sandbox IDs from podStatus.
for id, _ := range remove {
delete(p.info, id)
}
return nil
}
// findByPodID returns a PodSandboxStatus for the pod running under a given pod
// IP address, or nil if not found.
func (p *podStatus) findByPodIP(ip net.IP) *cri.PodSandboxStatus {
id, ok := p.byIP[ip.String()]
if !ok {
return nil
}
return p.info[id]
}
// runPodWorker runs the CRI cache 'pod worker'. It responds to requests over
// podInfoC until ctx is canceled.
func (r *Resolver) runPodWorker(ctx context.Context) error {
conn, err := grpc.Dial(fmt.Sprintf("unix://%s", r.criPath), grpc.WithInsecure())
if err != nil {
return fmt.Errorf("Dial: %w", err)
}
defer conn.Close()
client := cri.NewRuntimeServiceClient(conn)
ps := &podStatus{
info: make(map[string]*cri.PodSandboxStatus),
}
if err := ps.update(ctx, client); err != nil {
return fmt.Errorf("initial pod update: %w", err)
}
for {
select {
case req := <-r.podInfoC:
info := ps.findByPodIP(req.local)
if info != nil {
req.reply(info)
continue
}
err := ps.update(ctx, client)
if err != nil {
glog.Errorf("Updating pods failed: %v", err)
continue
}
req.reply(ps.findByPodIP(req.local))
case <-ctx.Done():
return ctx.Err()
}
}
}

View file

@ -0,0 +1,42 @@
package kubenat
import (
"context"
"flag"
"net"
"testing"
"github.com/golang/glog"
)
func TestPodWorker(t *testing.T) {
t.Skip("needs containerd running on host and unhardcoded test data")
flag.Set("logtostderr", "true")
r := &Resolver{
criPath: "/tmp/containerd.sock",
podInfoC: make(chan *podInfoReq),
}
ctx, ctxC := context.WithCancel(context.Background())
defer ctxC()
go func() {
err := r.runPodWorker(ctx)
if err != nil && err != ctx.Err() {
glog.Errorf("runPodWorker: %v", err)
}
}()
res, err := r.getPodInfo(ctx, net.IPv4(10, 10, 26, 23))
if err != nil {
t.Fatalf("got err: %v", err)
}
if res == nil {
t.Fatalf("got nil pod response")
}
if want, got := "matrix", res.namespace; want != got {
t.Errorf("namespace: got %q, wanted %q", want, got)
}
}

View file

@ -0,0 +1,341 @@
package kubenat
import (
"bufio"
"bytes"
"context"
"fmt"
"io/ioutil"
"net"
"strconv"
"strings"
"github.com/golang/glog"
)
// translationReq is a request passed to the translationWorker.
type translationReq struct {
t *Tuple4
res chan *translationResp
}
// translationResp is a response from the translationWorker, sent over the res
// channel in a translationReq.
type translationResp struct {
localIP net.IP
localPort uint16
}
// reply sends a reply to the given translationReq based on a conntrackEntry,
// sending nil if the entry is nil.
func (r *translationReq) reply(e *conntrackEntry) {
if e == nil {
r.res <- nil
return
}
localPort, err := strconv.ParseUint(e.request["sport"], 10, 16)
if err != nil {
r.res <- nil
return
}
r.res <- &translationResp{
localIP: net.ParseIP(e.request["src"]),
localPort: uint16(localPort),
}
}
// translate performs a translationReq/translationResp exchange under a context
// that can be used to time out the query.
func (r *Resolver) translate(ctx context.Context, t *Tuple4) (*translationResp, error) {
resC := make(chan *translationResp, 1)
r.translationC <- &translationReq{
t: t,
res: resC,
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case res := <-resC:
return res, nil
}
}
// conntrackEntry is an entry parsed from /proc/net/nf_conntrack. The format is
// not well documented, and the best resource I could find is:
// https://stackoverflow.com/questions/16034698/details-of-proc-net-ip-conntrack-and-proc-net-nf-conntrack
type conntrackEntry struct {
// networkProtocol is currently always "ipv4".
networkProtocol string
// transmissionProtocol is currently "tcp" or "udp".
transmissionProtocol string
invalidateTimeout int64
state string
// request key-value pairs. For NAT, these are entries relating to the
// connection as seen as the 'inside' of the NAT, eg. the pod-originated
// connection.
request map[string]string
// response key-value parirs. For NAT, these are entries relating to the
// connection as seen by the 'outside' of the NAT, eg. the internet.
response map[string]string
tags map[string]bool
}
// conntrackParseEntry parses a line from /proc/net/nf_conntrack into a conntrackEntry.
func conntrackParseEntry(line string) (*conntrackEntry, error) {
entry := conntrackEntry{
request: make(map[string]string),
response: make(map[string]string),
tags: make(map[string]bool),
}
fields := strings.Fields(line)
if len(fields) < 5 {
// This should never happen unless the file format drastically
// changed. Don't bother to parse the rest, error early, and let
// someone debug this.
return nil, fmt.Errorf("invalid field count: %v", fields)
}
switch fields[0] {
case "ipv4":
if fields[1] != "2" {
return nil, fmt.Errorf("ipv4 with proto number %q, wanted 2", fields[1])
}
// TODO(q3k): support IPv6 when we get it on prod.
default:
return nil, nil
}
entry.networkProtocol = fields[0]
rest := fields[5:]
switch fields[2] {
case "tcp":
if fields[3] != "6" {
return nil, fmt.Errorf("tcp with proto number %q, wanted 6", fields[3])
}
if len(fields) < 6 {
return nil, fmt.Errorf("tcp with missing state field")
}
entry.state = fields[5]
rest = fields[6:]
case "udp":
if fields[3] != "17" {
return nil, fmt.Errorf("udp with proto number %q, wanted 17", fields[3])
}
default:
return nil, nil
}
entry.transmissionProtocol = fields[2]
invalidateTimeout, err := strconv.ParseInt(fields[4], 10, 64)
if err != nil {
return nil, fmt.Errorf("unparseable timeout %q", fields[4])
}
entry.invalidateTimeout = invalidateTimeout
for _, el := range rest {
parts := strings.Split(el, "=")
switch len(parts) {
case 1:
// This is a tag.
tag := parts[0]
// Ensure the tag starts and ends with [] (eg. [ASSURED].
if !strings.HasPrefix(tag, "[") || !strings.HasSuffix(tag, "]") {
continue
}
// Strip [ and ].
tag = tag[1:]
tag = tag[:len(tag)-1]
if _, ok := entry.tags[tag]; ok {
return nil, fmt.Errorf("repeated tag %q", tag)
}
entry.tags[tag] = true
case 2:
// This is a k/v field.
k := parts[0]
v := parts[1]
if _, ok := entry.request[k]; ok {
if _, ok := entry.response[k]; ok {
return nil, fmt.Errorf("field %q encountered more than twice", k)
} else {
entry.response[k] = v
}
} else {
entry.request[k] = v
}
default:
return nil, fmt.Errorf("unparseable column %q", el)
}
}
return &entry, nil
}
// conntrackParse parses the contents of a /proc/net/nf_conntrack file into
// multiple entries. If the majority of the entries could not be parsed, an
// error is returned.
func conntrackParse(data []byte) ([]conntrackEntry, error) {
buf := bytes.NewBuffer(data)
scanner := bufio.NewScanner(buf)
var res []conntrackEntry
var errors []error
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
entry, err := conntrackParseEntry(line)
if err != nil {
glog.Errorf("Error while parsing %q: %v", line, err)
errors = append(errors, err)
} else if entry != nil {
res = append(res, *entry)
}
}
if len(errors) == 0 || len(errors) < len(res) {
return res, nil
} else {
return nil, fmt.Errorf("encountered too many errors during conntrack parse, check logs; first error: %w", errors[0])
}
}
// contrackIndex is an index into a list of conntrackEntries. It allows lookup
// by request/response k/v pairs.
type conntrackIndex struct {
entries []conntrackEntry
// byRequest is a map from key to value to list of indixes into entries.
byRequest map[string]map[string][]int
// byResponse is a map from key to value to list of indixes into entries.
byResponse map[string]map[string][]int
}
// buildIndex builds a conntrackIndex from a list of conntrackEntries.
func buildIndex(entries []conntrackEntry) *conntrackIndex {
ix := conntrackIndex{
entries: entries,
byRequest: make(map[string]map[string][]int),
byResponse: make(map[string]map[string][]int),
}
for i, entry := range ix.entries {
for k, v := range entry.request {
if _, ok := ix.byRequest[k]; !ok {
ix.byRequest[k] = make(map[string][]int)
}
ix.byRequest[k][v] = append(ix.byRequest[k][v], i)
}
for k, v := range entry.response {
if _, ok := ix.byResponse[k]; !ok {
ix.byResponse[k] = make(map[string][]int)
}
ix.byResponse[k][v] = append(ix.byResponse[k][v], i)
}
}
return &ix
}
// getByRequest returns conntrackEntries that match a given k/v pair in their
// request fields.
func (c *conntrackIndex) getByRequest(k, v string) []*conntrackEntry {
m, ok := c.byRequest[k]
if !ok {
return nil
}
ixes, ok := m[v]
if !ok {
return nil
}
res := make([]*conntrackEntry, len(ixes))
for i, ix := range ixes {
res[i] = &c.entries[ix]
}
return res
}
// getByResponse returns conntrackEntries that match a given k/v pair in their
// response fields.
func (c *conntrackIndex) getByResponse(k, v string) []*conntrackEntry {
m, ok := c.byResponse[k]
if !ok {
return nil
}
ixes, ok := m[v]
if !ok {
return nil
}
res := make([]*conntrackEntry, len(ixes))
for i, ix := range ixes {
res[i] = &c.entries[ix]
}
return res
}
// find returns a conntrackEntry corresponding to a TCP connection defined on
// the 'outside' of the NAT by a 4-tuple, or nil if no such connection is
// found.
func (c *conntrackIndex) find(t *Tuple4) *conntrackEntry {
// TODO(q3k): support IPv6
if t.RemoteIP.To4() == nil || t.LocalIP.To4() == nil {
return nil
}
entries := c.getByResponse("src", t.RemoteIP.String())
for _, entry := range entries {
if entry.transmissionProtocol != "tcp" {
continue
}
if entry.response["sport"] != fmt.Sprintf("%d", t.RemotePort) {
continue
}
if entry.response["dst"] != t.LocalIP.String() {
continue
}
if entry.response["dport"] != fmt.Sprintf("%d", t.LocalPort) {
continue
}
return entry
}
return nil
}
// runTranslationWorker runs the conntrack 'translation worker'. It responds to
// requests over translationC until ctx is canceled.
func (r *Resolver) runTranslationWorker(ctx context.Context) {
var ix *conntrackIndex
readConntrack := func() {
var entries []conntrackEntry
data, err := ioutil.ReadFile(r.conntrackPath)
if err != nil {
glog.Errorf("Failed to read conntrack file: %v", err)
} else {
entries, err = conntrackParse(data)
if err != nil {
glog.Errorf("failed to parse conntrack entries: %v", err)
}
}
ix = buildIndex(entries)
}
readConntrack()
for {
select {
case req := <-r.translationC:
entry := ix.find(req.t)
if entry != nil {
req.reply(entry)
} else {
readConntrack()
entry = ix.find(req.t)
if entry != nil {
req.reply(entry)
} else {
req.reply(nil)
}
}
case <-ctx.Done():
return
}
}
}

View file

@ -0,0 +1,130 @@
package kubenat
import (
"context"
"flag"
"io/ioutil"
"net"
"os"
"testing"
"github.com/go-test/deep"
)
// testConntrack is the anonymized content of a production host.
// The first entry is an appservice-irc connection from a pod to an IRC server.
// The second connection is an UDP connection between two pods.
// The third to last entry is not a NAT entry, but an incoming external
// connection.
// The fourth connection has a mangled/incomplete entry.
const testConntrack = `
ipv4 2 tcp 6 86384 ESTABLISHED src=10.10.26.23 dst=192.0.2.180 sport=51336 dport=6697 src=192.0.2.180 dst=185.236.240.36 sport=6697 dport=28706 [ASSURED] mark=0 zone=0 use=2
ipv4 2 udp 17 35 src=10.10.24.162 dst=10.10.26.108 sport=49347 dport=53 src=10.10.26.108 dst=10.10.24.162 sport=53 dport=49347 [ASSURED] mark=0 zone=0 use=2
ipv4 2 tcp 6 2 SYN_SENT src=198.51.100.67 dst=185.236.240.56 sport=51053 dport=3359 [UNREPLIED] src=185.236.240.56 dst=198.51.100.67 sport=3359 dport=51053 mark=0 zone=0 use=2
ipv4 2 tcp 6 2
`
// TestConntrackParse exercises the conntrack parser for all entries in testConntrack.
func TestConntrackParse(t *testing.T) {
// Last line is truncated and should be ignored.
got, err := conntrackParse([]byte(testConntrack))
if err != nil {
t.Fatalf("conntrackParse: %v", err)
}
want := []conntrackEntry{
{
"ipv4", "tcp", 86384, "ESTABLISHED",
map[string]string{
"src": "10.10.26.23", "dst": "192.0.2.180", "sport": "57640", "dport": "6697",
"mark": "0", "zone": "0", "use": "2",
},
map[string]string{
"src": "192.0.2.180", "dst": "185.236.240.36", "sport": "6697", "dport": "28706",
},
map[string]bool{
"ASSURED": true,
},
},
{
"ipv4", "udp", 35, "",
map[string]string{
"src": "10.10.24.162", "dst": "10.10.26.108", "sport": "49347", "dport": "53",
"mark": "0", "zone": "0", "use": "2",
},
map[string]string{
"src": "10.10.26.108", "dst": "10.10.24.162", "sport": "53", "dport": "49347",
},
map[string]bool{
"ASSURED": true,
},
},
{
"ipv4", "tcp", 2, "SYN_SENT",
map[string]string{
"src": "198.51.100.67", "dst": "185.236.240.56", "sport": "51053", "dport": "3359",
"mark": "0", "zone": "0", "use": "2",
},
map[string]string{
"src": "185.236.240.56", "dst": "198.51.100.67", "sport": "3359", "dport": "51053",
},
map[string]bool{
"UNREPLIED": true,
},
},
}
if diff := deep.Equal(want, got); diff != nil {
t.Error(diff)
}
ix := buildIndex(got)
if want, got := 0, len(ix.getByRequest("src", "1.2.3.4")); want != got {
t.Errorf("by request, src, 1.2.3.4 should have returned %d result, wanted %d", want, got)
}
if want, got := 1, len(ix.getByRequest("src", "10.10.26.23")); want != got {
t.Errorf("by request, src, 1.2.3.4 should have returned %d result, wanted %d", want, got)
}
if want, got := "10.10.26.23", ix.getByRequest("src", "10.10.26.23")[0].request["src"]; want != got {
t.Errorf("by request, wanted src %q, got %q", want, got)
}
if want, got := 3, len(ix.getByRequest("mark", "0")); want != got {
t.Errorf("by request, mark, 0 should have returned %d result, wanted %d", want, got)
}
}
// TestTranslationWorker exercises a translation worker with a
// testConntrack-backed conntrack file.
func TestTranslationWorker(t *testing.T) {
flag.Set("logtostderr", "true")
tmpfile, err := ioutil.TempFile("", "conntack")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write([]byte(testConntrack)); err != nil {
t.Fatal(err)
}
r := &Resolver{
conntrackPath: tmpfile.Name(),
translationC: make(chan *translationReq),
}
ctx, ctxC := context.WithCancel(context.Background())
defer ctxC()
go r.runTranslationWorker(ctx)
res, err := r.translate(ctx, &Tuple4{
RemoteIP: net.ParseIP("192.0.2.180"),
RemotePort: 6697,
LocalIP: net.ParseIP("185.236.240.36"),
LocalPort: 28706,
})
if err != nil {
t.Fatalf("translate: %v", err)
}
if want, got := net.ParseIP("10.10.26.23"), res.localIP; !want.Equal(got) {
t.Errorf("local ip: wanted %v, got %v", want, got)
}
if want, got := uint16(51336), res.localPort; want != got {
t.Errorf("local port: wanted %d, got %d", want, got)
}
}