From 4084cad34b68dc1ebaa2f87e60517e0d9c4a120a Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Sat, 28 Mar 2015 02:54:17 +1100 Subject: [PATCH] iplist: Fail invalid IPs, they were always passing --- iplist/iplist.go | 11 +++++++++-- iplist/iplist_test.go | 26 ++++++++++++++++++++++---- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/iplist/iplist.go b/iplist/iplist.go index 958af5a9..93fb9b31 100644 --- a/iplist/iplist.go +++ b/iplist/iplist.go @@ -45,15 +45,22 @@ func (me *IPList) Lookup(ip net.IP) (r *Range) { // TODO: Perhaps all addresses should be converted to IPv6, if the future // of IP is to always be backwards compatible. But this will cost 4x the // memory for IPv4 addresses? - if v4 := ip.To4(); v4 != nil { + v4 := ip.To4() + if v4 != nil { r = me.lookup(v4) if r != nil { return } } - if v6 := ip.To16(); v6 != nil { + v6 := ip.To16() + if v6 != nil { return me.lookup(v6) } + if v4 == nil && v6 == nil { + return &Range{ + Description: fmt.Sprintf("unsupported IP: %s", ip), + } + } return nil } diff --git a/iplist/iplist_test.go b/iplist/iplist_test.go index 09d216fc..f21ecc3d 100644 --- a/iplist/iplist_test.go +++ b/iplist/iplist_test.go @@ -73,6 +73,22 @@ func connRemoteAddrIP(network, laddr string, dialHost string) net.IP { return ret } +func TestBadIP(t *testing.T) { + iplist := New(nil) + if iplist.Lookup(net.IP(make([]byte, 4))) != nil { + t.FailNow() + } + if iplist.Lookup(net.IP(make([]byte, 16))) != nil { + t.FailNow() + } + if iplist.Lookup(nil) == nil { + t.FailNow() + } + if iplist.Lookup(net.IP(make([]byte, 5))) == nil { + t.FailNow() + } +} + func TestSimple(t *testing.T) { ranges, err := sampleRanges(t) if err != nil { @@ -90,14 +106,16 @@ func TestSimple(t *testing.T) { {"1.2.3.255", false, ""}, {"1.2.8.0", true, "b"}, {"1.2.4.255", true, "a"}, - // Try to roll over to the next octet on the parse. - {"1.2.7.256", false, ""}, + // Try to roll over to the next octet on the parse. Note the final + // octet is overbounds. In the next case. + {"1.2.7.256", true, "unsupported IP: "}, {"1.2.8.254", true, "b"}, } { - r := iplist.Lookup(net.ParseIP(_case.IP)) + ip := net.ParseIP(_case.IP) + r := iplist.Lookup(ip) if !_case.Hit { if r != nil { - t.Fatalf("got hit when none was expected") + t.Fatalf("got hit when none was expected: %s", ip) } continue }