diff --git a/nat/pmp.go b/nat/pmp.go new file mode 100644 index 00000000..f35ce5c0 --- /dev/null +++ b/nat/pmp.go @@ -0,0 +1,117 @@ +package nat + +import ( + "fmt" + "net" + "sync" + "time" + + "github.com/jackpal/gateway" + natpmp "github.com/jackpal/go-nat-pmp" +) + +// Compile-time check to ensure PMP implements the Traversal interface. +var _ Traversal = (*PMP)(nil) + +// PMP is a concrete implementation of the Traversal interface that uses the +// NAT-PMP technique. +type PMP struct { + client *natpmp.Client + + forwardedPortsMtx sync.Mutex + forwardedPorts map[uint16]struct{} +} + +// DiscoverPMP attempts to scan the local network for a NAT-PMP enabled device +// within the given timeout. +func DiscoverPMP(timeout time.Duration) (*PMP, error) { + // Retrieve the gateway IP address of the local network. + gatewayIP, err := gateway.DiscoverGateway() + if err != nil { + return nil, err + } + + pmp := &PMP{ + client: natpmp.NewClientWithTimeout(gatewayIP, timeout), + forwardedPorts: make(map[uint16]struct{}), + } + + // We'll then attempt to retrieve the external IP address of this + // device to ensure it is not behind multiple NATs. + if _, err := pmp.ExternalIP(); err != nil { + return nil, err + } + + return pmp, nil +} + +// ExternalIP returns the external IP address of the NAT-PMP enabled device. +func (p *PMP) ExternalIP() (net.IP, error) { + res, err := p.client.GetExternalAddress() + if err != nil { + return nil, err + } + + ip := net.IP(res.ExternalIPAddress[:]) + if isPrivateIP(ip) { + return nil, ErrMultipleNAT + } + + return ip, nil +} + +// AddPortMapping enables port forwarding for the given port. +func (p *PMP) AddPortMapping(port uint16) error { + p.forwardedPortsMtx.Lock() + defer p.forwardedPortsMtx.Unlock() + + if _, exists := p.forwardedPorts[port]; exists { + return nil + } + + _, err := p.client.AddPortMapping("tcp", int(port), int(port), 0) + if err != nil { + return err + } + + p.forwardedPorts[port] = struct{}{} + + return nil +} + +// DeletePortMapping disables port forwarding for the given port. +func (p *PMP) DeletePortMapping(port uint16) error { + p.forwardedPortsMtx.Lock() + defer p.forwardedPortsMtx.Unlock() + + if _, exists := p.forwardedPorts[port]; !exists { + return fmt.Errorf("port %d is not being forwarded", port) + } + + _, err := p.client.AddPortMapping("tcp", int(port), 0, 0) + if err != nil { + return err + } + + delete(p.forwardedPorts, port) + + return nil +} + +// ForwardedPorts returns a list of ports currently being forwarded. +func (p *PMP) ForwardedPorts() []uint16 { + p.forwardedPortsMtx.Lock() + defer p.forwardedPortsMtx.Unlock() + + ports := make([]uint16, 0, len(p.forwardedPorts)) + for port := range p.forwardedPorts { + ports = append(ports, port) + } + + return ports +} + +// Name returns the name of the specific NAT traversal technique used. +func (p *PMP) Name() string { + return "NAT-PMP" +} diff --git a/nat/traversal.go b/nat/traversal.go new file mode 100644 index 00000000..852a28cb --- /dev/null +++ b/nat/traversal.go @@ -0,0 +1,58 @@ +package nat + +import ( + "errors" + "net" +) + +var ( + // private24BitBlock contains the set of private IPv4 addresses within + // the 10.0.0.0/8 adddress space. + private24BitBlock *net.IPNet + + // private20BitBlock contains the set of private IPv4 addresses within + // the 172.16.0.0/12 address space. + private20BitBlock *net.IPNet + + // private16BitBlock contains the set of private IPv4 addresses within + // the 192.168.0.0/16 address space. + private16BitBlock *net.IPNet + + // ErrMultipleNAT is an error returned when multiple NATs have been + // detected. + ErrMultipleNAT = errors.New("multiple NATs detected") +) + +func init() { + _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8") + _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12") + _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16") +} + +// Traversal is an interface that brings together the different NAT traversal +// techniques. +type Traversal interface { + // ExternalIP returns the external IP address. + ExternalIP() (net.IP, error) + + // AddPortMapping adds a port mapping for the given port between the + // private and public addresses. + AddPortMapping(port uint16) error + + // DeletePortMapping deletes a port mapping for the given port between + // the private and public addresses. + DeletePortMapping(port uint16) error + + // ForwardedPorts returns the ports currently being forwarded using NAT + // traversal. + ForwardedPorts() []uint16 + + // Name returns the name of the specific NAT traversal technique used. + Name() string +} + +// isPrivateIP determines if the IP is private. +func isPrivateIP(ip net.IP) bool { + return private24BitBlock.Contains(ip) || + private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip) +} diff --git a/nat/upnp.go b/nat/upnp.go new file mode 100644 index 00000000..4022456a --- /dev/null +++ b/nat/upnp.go @@ -0,0 +1,112 @@ +package nat + +import ( + "context" + "fmt" + "net" + "sync" + + upnp "github.com/NebulousLabs/go-upnp" +) + +// Compile-time check to ensure UPnP implements the Traversal interface. +var _ Traversal = (*UPnP)(nil) + +// UPnP is a concrete implementation of the Traversal interface that uses the +// UPnP technique. +type UPnP struct { + device *upnp.IGD + + forwardedPortsMtx sync.Mutex + forwardedPorts map[uint16]struct{} +} + +// DiscoverUPnP scans the local network for a UPnP enabled device. +func DiscoverUPnP(ctx context.Context) (*UPnP, error) { + // Scan the local network for a UPnP-enabled device. + device, err := upnp.DiscoverCtx(ctx) + if err != nil { + return nil, err + } + + u := &UPnP{ + device: device, + forwardedPorts: make(map[uint16]struct{}), + } + + // We'll then attempt to retrieve the external IP address of this + // device to ensure it is not behind multiple NATs. + if _, err := u.ExternalIP(); err != nil { + return nil, err + } + + return u, nil +} + +// ExternalIP returns the external IP address of the UPnP enabled device. +func (u *UPnP) ExternalIP() (net.IP, error) { + ip, err := u.device.ExternalIP() + if err != nil { + return nil, err + } + + if isPrivateIP(net.ParseIP(ip)) { + return nil, ErrMultipleNAT + } + + return net.ParseIP(ip), nil +} + +// AddPortMapping enables port forwarding for the given port. +func (u *UPnP) AddPortMapping(port uint16) error { + u.forwardedPortsMtx.Lock() + defer u.forwardedPortsMtx.Unlock() + + if _, exists := u.forwardedPorts[port]; exists { + return nil + } + + if err := u.device.Forward(port, ""); err != nil { + return err + } + + u.forwardedPorts[port] = struct{}{} + + return nil +} + +// DeletePortMapping disables port forwarding for the given port. +func (u *UPnP) DeletePortMapping(port uint16) error { + u.forwardedPortsMtx.Lock() + defer u.forwardedPortsMtx.Unlock() + + if _, exists := u.forwardedPorts[port]; !exists { + return fmt.Errorf("port %d is not being forwarded", port) + } + + if err := u.device.Clear(port); err != nil { + return err + } + + delete(u.forwardedPorts, port) + + return nil +} + +// ForwardedPorts returns a list of ports currently being forwarded. +func (u *UPnP) ForwardedPorts() []uint16 { + u.forwardedPortsMtx.Lock() + defer u.forwardedPortsMtx.Unlock() + + ports := make([]uint16, 0, len(u.forwardedPorts)) + for port := range u.forwardedPorts { + ports = append(ports, port) + } + + return ports +} + +// Name returns the name of the specific NAT traversal technique used. +func (u *UPnP) Name() string { + return "UPnP" +}