close
The Wayback Machine - https://web.archive.org/web/20220725135416/https://github.com/tsenart/vegeta/commit/1be3214ba0ef7ea154c3ebed0f0fd017ffcf5cfc
Skip to content
Permalink
Browse files
Reworking items per tsenart
Renamed instances of address / addresses to addr / addrs

Removed dependency on google.com for test, using dnsmasq instead
if the VEGETA_TESTDNSMASQ_ENABLE environment variable is present.

Updated the travis build to set this environment variable, since
travis should actually have dnsmasq installed.

De-optimized address()
  • Loading branch information
nathanejohnson committed Aug 28, 2018
1 parent b3796f5 commit 1be3214ba0ef7ea154c3ebed0f0fd017ffcf5cfc
Showing 5 changed files with 124 additions and 70 deletions.
@@ -9,4 +9,4 @@ install:
script:
- go vet ./...
- $HOME/gopath/bin/golint .
- go test -v ./...
- VEGETA_TESTDNSMASQ_ENABLE=1 go test -v ./...
@@ -5,5 +5,5 @@ package main
import "flag"

func systemSpecificFlags(fs *flag.FlagSet, opts *attackOpts) {
fs.Var(&opts.resolvers, "resolvers", "Override system dns resolution (comma separated list)")
fs.Var(&opts.resolvers, "resolvers", "List of addresses (ip:port) to use for DNS resolution. Disables use of local system DNS. (comma separated list)")
}
@@ -0,0 +1 @@
127.0.0.1 acme.notadomain
@@ -11,74 +11,72 @@ import (
)

type resolver struct {
addresses []string
dialer *net.Dialer
idx uint64
addrs []string
dialer *net.Dialer
idx uint64
}

// NewResolver - create a new instance of a dns resolver for plugging
// into net.DefaultResolver. Addresses should be a list of
// ip addresses and optional port numbers, separated by colon.
// ip addrs and optional port numbers, separated by colon.
// For example: 1.2.3.4:53 and 1.2.3.4 are both valid. In the absence
// of a port number, 53 will be used instead.
func NewResolver(addresses []string) (*net.Resolver, error) {
normalAddresses, err := normalizeResolverAddresses(addresses)
func NewResolver(addrs []string) (*net.Resolver, error) {
if len(addrs) == 0 {
return nil, errors.New("must specify at least resolver address")
}
cleanAddrs, err := normalizeAddrs(addrs)
if err != nil {
return nil, err
}
return &net.Resolver{
PreferGo: true,
Dial: (&resolver{addresses: normalAddresses, dialer: &net.Dialer{}}).dial,
Dial: (&resolver{addrs: cleanAddrs, dialer: &net.Dialer{}}).dial,
}, nil
}

func normalizeResolverAddresses(addresses []string) ([]string, error) {
if len(addresses) == 0 {
return nil, errors.New("must specify at least resolver address")
}
normalAddresses := make([]string, len(addresses))
for i, addr := range addresses {
ipPort := strings.Split(addr, ":")
port := 53
var host string
func normalizeAddrs(addrs []string) ([]string, error) {
normal := make([]string, len(addrs))
for i, addr := range addrs {

switch len(ipPort) {
case 2:
pu16, err := strconv.ParseUint(ipPort[1], 10, 16)
if err != nil {
return nil, err
}
port = int(pu16)
fallthrough
case 1:
host = ipPort[0]
default:
return nil, fmt.Errorf("invalid ip:port specified: %s", addr)
// if addr has no port, give it 53
if !strings.Contains(addr, ":") {
addr += ":53"
}

// validate addr is a valid host:port
host, portstr, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}

// validate valid port.
port, err := strconv.ParseUint(portstr, 10, 16)
if err != nil {
return nil, err
}

if port <= 0 {
return nil, errors.New("invalid port")
}

// make sure host is an ip.
ip := net.ParseIP(host)
if ip == nil {
return nil, fmt.Errorf("host %s is not an IP address", host)
}

normalAddresses[i] = fmt.Sprintf("%s:%d", host, port)
normal[i] = addr
}
return normalAddresses, nil
return normal, nil
}

// ignore the third parameter, as this represents the dns server address that
// are overriding.
// we are overriding.
func (r *resolver) dial(ctx context.Context, network, _ string) (net.Conn, error) {
return r.dialer.DialContext(ctx, network, r.address())
}

func (r *resolver) address() string {
var address string
if l := uint64(len(r.addresses)); l > 1 {
idx := atomic.AddUint64(&r.idx, 1)
address = r.addresses[idx%l]
} else {
address = r.addresses[0]
}
return address
return r.addrs[atomic.AddUint64(&r.idx, 1)%uint64(len(r.addrs))]
}
@@ -1,27 +1,101 @@
package resolver

import (
"bytes"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"os/exec"
"strings"
"testing"
"time"
)

func TestResolve8888(t *testing.T) {
r, err := NewResolver([]string{"8.8.8.8:53"})
if err != nil {
t.FailNow()
const (
dnsmasqRunEnv = "VEGETA_TESTDNSMASQ_ENABLE"
dnsmasqPathEnv = "VEGETA_TESTDNSMASQ_PATH"
dnsmasqPortEnv = "VEGETA_TESTDNSMASQ_PORT"
)

func TestResolveDNSMasq(t *testing.T) {
const payload = "there is no cloud, just someone else's computer"

var (
path = "dnsmasq"
port = "5300"
)
if _, ok := os.LookupEnv(dnsmasqRunEnv); !ok {
t.Skipf("skipping test becuase %s is not set", dnsmasqRunEnv)
}
if ePort, ok := os.LookupEnv(dnsmasqPortEnv); ok {
port = ePort
}
if ePath, ok := os.LookupEnv(dnsmasqPathEnv); ok {
path = ePath
}

net.DefaultResolver = r
stdout := bytes.Buffer{}
stderr := bytes.Buffer{}
cmd := exec.Command(path, "-h", "-H", "./hosts", "-p", port, "-d")

resp, err := http.Get("https://www.google.com/")
cmd.Stdout = &stdout
cmd.Stderr = &stderr

err := cmd.Start()
if err != nil {
t.Logf("error from http.Get(): %s", err)
t.FailNow()
t.Fatalf("failed starting dnsmasq: %s", err)
}
defer func() {
err := cmd.Wait()
if err != nil {
t.Logf("unclean shutdown of dnsmasq: %s", err)
}
t.Log(stdout.String())
t.Log(stderr.String())
}()
time.Sleep(time.Second)

defer func() {
_ = cmd.Process.Kill()
}()

res, err := NewResolver([]string{net.JoinHostPort("127.0.0.1", port)})
if err != nil {
t.Errorf("error from NewResolver: %s", err)
return
}
net.DefaultResolver = res

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, payload)
}))
defer ts.Close()

tsurl, _ := url.Parse(ts.URL)

resp.Body.Close()
_, hport, err := net.SplitHostPort(tsurl.Host)
if err != nil {
t.Errorf("could not parse port from httptest url %s: %s", ts.URL, err)
return
}
tsurl.Host = net.JoinHostPort("acme.notadomain", hport)
resp, err := http.Get(tsurl.String())
if err != nil {
t.Errorf("failed resolver round trip: %s", err)
return
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("failed to read respose body")
return
}
if strings.TrimSpace(string(body)) != payload {
t.Errorf("body mismatch, got: '%s', expected: '%s'", body, payload)
}
}

func TestResolveAddresses(t *testing.T) {
@@ -64,7 +138,7 @@ func TestResolveAddresses(t *testing.T) {
}
for subtest, tdata := range table {
t.Run(subtest, func(t *testing.T) {
addrs, err := normalizeResolverAddresses(tdata.input)
addrs, err := normalizeAddrs(tdata.input)
if tdata.expectError {
if err == nil {
t.Error("expected error, got none")
@@ -97,22 +171,3 @@ func TestResolveAddresses(t *testing.T) {
}

}

func TestResolverOverflow(t *testing.T) {
res := &resolver{
addresses: []string{"8.8.8.8:53", "9.9.9.9:53"},
idx: ^uint64(0),
}
_ = res.address()
if res.idx != 0 {
t.Error("overflow not handled gracefully")
}
// throw away another one to make sure we're back to 0
_ = res.address()
for i := 0; i < 5; i++ {
addr := res.address()
if expectedAddr := res.addresses[i%len(res.addresses)]; expectedAddr != addr {
t.Errorf("address mismatch, have: %s, want: %s", addr, expectedAddr)
}
}
}

0 comments on commit 1be3214

Please sign in to comment.