nat: introduce new NAT traversal package with UPnP/NAT-PMP implementations
In this commit, we introduce a new package `nat`. This package is responsible for handling the different techniques for NAT traversals. Specifically, we have implemented UPnP and NAT-PMP support. This will allow users to easily advertise their nodes to the network as long as their devices are behind a single NAT. Devices behind multiple NATs are not supported.
This commit is contained in:
parent
8651b4a422
commit
377e770db4
117
nat/pmp.go
Normal file
117
nat/pmp.go
Normal file
@ -0,0 +1,117 @@
|
||||
package nat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackpal/gateway"
|
||||
natpmp "github.com/jackpal/go-nat-pmp"
|
||||
)
|
||||
|
||||
// Compile-time check to ensure PMP implements the Traversal interface.
|
||||
var _ Traversal = (*PMP)(nil)
|
||||
|
||||
// PMP is a concrete implementation of the Traversal interface that uses the
|
||||
// NAT-PMP technique.
|
||||
type PMP struct {
|
||||
client *natpmp.Client
|
||||
|
||||
forwardedPortsMtx sync.Mutex
|
||||
forwardedPorts map[uint16]struct{}
|
||||
}
|
||||
|
||||
// DiscoverPMP attempts to scan the local network for a NAT-PMP enabled device
|
||||
// within the given timeout.
|
||||
func DiscoverPMP(timeout time.Duration) (*PMP, error) {
|
||||
// Retrieve the gateway IP address of the local network.
|
||||
gatewayIP, err := gateway.DiscoverGateway()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pmp := &PMP{
|
||||
client: natpmp.NewClientWithTimeout(gatewayIP, timeout),
|
||||
forwardedPorts: make(map[uint16]struct{}),
|
||||
}
|
||||
|
||||
// We'll then attempt to retrieve the external IP address of this
|
||||
// device to ensure it is not behind multiple NATs.
|
||||
if _, err := pmp.ExternalIP(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pmp, nil
|
||||
}
|
||||
|
||||
// ExternalIP returns the external IP address of the NAT-PMP enabled device.
|
||||
func (p *PMP) ExternalIP() (net.IP, error) {
|
||||
res, err := p.client.GetExternalAddress()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip := net.IP(res.ExternalIPAddress[:])
|
||||
if isPrivateIP(ip) {
|
||||
return nil, ErrMultipleNAT
|
||||
}
|
||||
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
// AddPortMapping enables port forwarding for the given port.
|
||||
func (p *PMP) AddPortMapping(port uint16) error {
|
||||
p.forwardedPortsMtx.Lock()
|
||||
defer p.forwardedPortsMtx.Unlock()
|
||||
|
||||
if _, exists := p.forwardedPorts[port]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := p.client.AddPortMapping("tcp", int(port), int(port), 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.forwardedPorts[port] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePortMapping disables port forwarding for the given port.
|
||||
func (p *PMP) DeletePortMapping(port uint16) error {
|
||||
p.forwardedPortsMtx.Lock()
|
||||
defer p.forwardedPortsMtx.Unlock()
|
||||
|
||||
if _, exists := p.forwardedPorts[port]; !exists {
|
||||
return fmt.Errorf("port %d is not being forwarded", port)
|
||||
}
|
||||
|
||||
_, err := p.client.AddPortMapping("tcp", int(port), 0, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(p.forwardedPorts, port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForwardedPorts returns a list of ports currently being forwarded.
|
||||
func (p *PMP) ForwardedPorts() []uint16 {
|
||||
p.forwardedPortsMtx.Lock()
|
||||
defer p.forwardedPortsMtx.Unlock()
|
||||
|
||||
ports := make([]uint16, 0, len(p.forwardedPorts))
|
||||
for port := range p.forwardedPorts {
|
||||
ports = append(ports, port)
|
||||
}
|
||||
|
||||
return ports
|
||||
}
|
||||
|
||||
// Name returns the name of the specific NAT traversal technique used.
|
||||
func (p *PMP) Name() string {
|
||||
return "NAT-PMP"
|
||||
}
|
58
nat/traversal.go
Normal file
58
nat/traversal.go
Normal file
@ -0,0 +1,58 @@
|
||||
package nat
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
var (
|
||||
// private24BitBlock contains the set of private IPv4 addresses within
|
||||
// the 10.0.0.0/8 adddress space.
|
||||
private24BitBlock *net.IPNet
|
||||
|
||||
// private20BitBlock contains the set of private IPv4 addresses within
|
||||
// the 172.16.0.0/12 address space.
|
||||
private20BitBlock *net.IPNet
|
||||
|
||||
// private16BitBlock contains the set of private IPv4 addresses within
|
||||
// the 192.168.0.0/16 address space.
|
||||
private16BitBlock *net.IPNet
|
||||
|
||||
// ErrMultipleNAT is an error returned when multiple NATs have been
|
||||
// detected.
|
||||
ErrMultipleNAT = errors.New("multiple NATs detected")
|
||||
)
|
||||
|
||||
func init() {
|
||||
_, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
|
||||
_, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
|
||||
_, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
|
||||
}
|
||||
|
||||
// Traversal is an interface that brings together the different NAT traversal
|
||||
// techniques.
|
||||
type Traversal interface {
|
||||
// ExternalIP returns the external IP address.
|
||||
ExternalIP() (net.IP, error)
|
||||
|
||||
// AddPortMapping adds a port mapping for the given port between the
|
||||
// private and public addresses.
|
||||
AddPortMapping(port uint16) error
|
||||
|
||||
// DeletePortMapping deletes a port mapping for the given port between
|
||||
// the private and public addresses.
|
||||
DeletePortMapping(port uint16) error
|
||||
|
||||
// ForwardedPorts returns the ports currently being forwarded using NAT
|
||||
// traversal.
|
||||
ForwardedPorts() []uint16
|
||||
|
||||
// Name returns the name of the specific NAT traversal technique used.
|
||||
Name() string
|
||||
}
|
||||
|
||||
// isPrivateIP determines if the IP is private.
|
||||
func isPrivateIP(ip net.IP) bool {
|
||||
return private24BitBlock.Contains(ip) ||
|
||||
private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
|
||||
}
|
112
nat/upnp.go
Normal file
112
nat/upnp.go
Normal file
@ -0,0 +1,112 @@
|
||||
package nat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
upnp "github.com/NebulousLabs/go-upnp"
|
||||
)
|
||||
|
||||
// Compile-time check to ensure UPnP implements the Traversal interface.
|
||||
var _ Traversal = (*UPnP)(nil)
|
||||
|
||||
// UPnP is a concrete implementation of the Traversal interface that uses the
|
||||
// UPnP technique.
|
||||
type UPnP struct {
|
||||
device *upnp.IGD
|
||||
|
||||
forwardedPortsMtx sync.Mutex
|
||||
forwardedPorts map[uint16]struct{}
|
||||
}
|
||||
|
||||
// DiscoverUPnP scans the local network for a UPnP enabled device.
|
||||
func DiscoverUPnP(ctx context.Context) (*UPnP, error) {
|
||||
// Scan the local network for a UPnP-enabled device.
|
||||
device, err := upnp.DiscoverCtx(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u := &UPnP{
|
||||
device: device,
|
||||
forwardedPorts: make(map[uint16]struct{}),
|
||||
}
|
||||
|
||||
// We'll then attempt to retrieve the external IP address of this
|
||||
// device to ensure it is not behind multiple NATs.
|
||||
if _, err := u.ExternalIP(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// ExternalIP returns the external IP address of the UPnP enabled device.
|
||||
func (u *UPnP) ExternalIP() (net.IP, error) {
|
||||
ip, err := u.device.ExternalIP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isPrivateIP(net.ParseIP(ip)) {
|
||||
return nil, ErrMultipleNAT
|
||||
}
|
||||
|
||||
return net.ParseIP(ip), nil
|
||||
}
|
||||
|
||||
// AddPortMapping enables port forwarding for the given port.
|
||||
func (u *UPnP) AddPortMapping(port uint16) error {
|
||||
u.forwardedPortsMtx.Lock()
|
||||
defer u.forwardedPortsMtx.Unlock()
|
||||
|
||||
if _, exists := u.forwardedPorts[port]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := u.device.Forward(port, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.forwardedPorts[port] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePortMapping disables port forwarding for the given port.
|
||||
func (u *UPnP) DeletePortMapping(port uint16) error {
|
||||
u.forwardedPortsMtx.Lock()
|
||||
defer u.forwardedPortsMtx.Unlock()
|
||||
|
||||
if _, exists := u.forwardedPorts[port]; !exists {
|
||||
return fmt.Errorf("port %d is not being forwarded", port)
|
||||
}
|
||||
|
||||
if err := u.device.Clear(port); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(u.forwardedPorts, port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForwardedPorts returns a list of ports currently being forwarded.
|
||||
func (u *UPnP) ForwardedPorts() []uint16 {
|
||||
u.forwardedPortsMtx.Lock()
|
||||
defer u.forwardedPortsMtx.Unlock()
|
||||
|
||||
ports := make([]uint16, 0, len(u.forwardedPorts))
|
||||
for port := range u.forwardedPorts {
|
||||
ports = append(ports, port)
|
||||
}
|
||||
|
||||
return ports
|
||||
}
|
||||
|
||||
// Name returns the name of the specific NAT traversal technique used.
|
||||
func (u *UPnP) Name() string {
|
||||
return "UPnP"
|
||||
}
|
Loading…
Reference in New Issue
Block a user