// Xray-core/proxy/wireguard/client.go (Go, 490 lines, 13 KiB)

package wireguard
import (
"context"
"fmt"
gonet "net"
"net/netip"
reflect "reflect"
"strings"
"sync"
"golang.zx2c4.com/wireguard/tun"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/dice"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/net/cnc"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet"
"golang.zx2c4.com/wireguard/device"
)
// Handler is the WireGuard outbound handler. NewClient fills in the
// configuration, core features, optional counters and the TUN/network stack;
// the WireGuard device (dev) is created later by init, on first use in
// Process.
type Handler struct {
	// conf is the device configuration this handler was built from.
	conf *DeviceConfig
	// policyManager supplies session timeouts and the system stats policy.
	policyManager policy.Manager
	// dns is the core DNS client used by resolveLocal.
	dns dns.Client
	// streamSettings provides the socket settings and optional UDP mask
	// manager used when dialing the peer endpoint.
	streamSettings *internet.MemoryStreamConfig

	// uplinkCounter and downlinkCounter are nil unless NewClient registered
	// them; init wraps the peer connection when either one is set.
	uplinkCounter   stats.Counter
	downlinkCounter stats.Counter

	// tun and tnet are created together by NewClient.
	tun  tun.Device
	tnet *Net
	// dev stays nil until init succeeds and is reset to nil by Close.
	dev *device.Device

	// mu is held by init and Close while they read or replace dev and tun.
	mu sync.Mutex
}
// NewClient builds a WireGuard outbound Handler from conf.
//
// It reads the policy manager, DNS client, stream settings and handler tag
// from ctx, optionally registers uplink/downlink traffic counters, validates
// the peer list, parses the local addresses and finally creates either a
// kernel TUN or a gVisor TUN. The WireGuard device itself is not created here.
//
// An error is returned when conf has no peers, when a peer lacks a public key
// or an endpoint, when a local address is neither an IP nor a prefix, or when
// the TUN cannot be created.
func NewClient(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
	instance := core.MustFromContext(ctx)
	policyManager := instance.GetFeature(policy.ManagerType()).(policy.Manager)
	dnsClient := instance.GetFeature(dns.ClientType()).(dns.Client)
	streamSettings := session.StreamSettingsFromContext(ctx).(*internet.MemoryStreamConfig)
	tag := session.FullHandlerFromContext(ctx).Tag()

	// registerCounter looks up the stats manager and returns the counter
	// registered under name; the registration error is deliberately ignored,
	// in which case the returned counter may be nil.
	registerCounter := func(name string) stats.Counter {
		manager := instance.GetFeature(stats.ManagerType()).(stats.Manager)
		counter, _ := stats.GetOrRegisterCounter(manager, name)
		return counter
	}

	// Counters are only registered for tagged outbounds, and only for the
	// directions enabled in the system policy.
	var uplinkCounter, downlinkCounter stats.Counter
	if tag != "" {
		if policyManager.ForSystem().Stats.OutboundUplink {
			uplinkCounter = registerCounter("outbound>>>" + tag + ">>>traffic>>>uplink")
		}
		if policyManager.ForSystem().Stats.OutboundDownlink {
			downlinkCounter = registerCounter("outbound>>>" + tag + ">>>traffic>>>downlink")
		}
	}

	if len(conf.Peers) == 0 {
		return nil, errors.New("empty peers")
	}
	for _, peer := range conf.Peers {
		switch {
		case peer.PublicKey == "":
			return nil, errors.New("peer without publickey")
		case peer.Endpoint == "":
			return nil, errors.New("peer without endpoint")
		}
	}

	// Each configured local address may be a bare IP or a CIDR prefix; for a
	// prefix only the address part is kept.
	localAddresses := make([]netip.Addr, 0, len(conf.Endpoint))
	for _, raw := range conf.Endpoint {
		if addr, err := netip.ParseAddr(raw); err == nil {
			localAddresses = append(localAddresses, addr)
			continue
		}
		prefix, err := netip.ParsePrefix(raw)
		if err != nil {
			return nil, err
		}
		localAddresses = append(localAddresses, prefix.Addr())
	}

	// A failed capability check is only logged; the decision below then
	// relies on whatever value the check returned.
	kernelTunSupported, err := KernelTunSupported()
	if err != nil {
		errors.LogWarningInner(context.Background(), err, "Failed to check kernel TUN support")
	}

	// Fixed address list handed to both TUN constructors — presumably the
	// DNS servers of the tunnel network stack; confirm against
	// createKernelTun / CreateNetTUN.
	dnsServers := []netip.Addr{
		netip.MustParseAddr("1.1.1.1"),
		netip.MustParseAddr("1.0.0.1"),
		netip.MustParseAddr("2606:4700:4700::1111"),
		netip.MustParseAddr("2606:4700:4700::1001"),
	}
	mtu := int(conf.Mtu)

	var tunDevice tun.Device
	var netStack *Net
	if conf.NoKernelTun || !kernelTunSupported {
		errors.LogWarning(context.Background(), "Using gVisor TUN")
		tunDevice, netStack, _, err = CreateNetTUN(localAddresses, dnsServers, mtu, true)
	} else {
		errors.LogWarning(context.Background(), "Using kernel TUN")
		tunDevice, netStack, err = createKernelTun(localAddresses, dnsServers, mtu)
	}
	if err != nil {
		return nil, err
	}

	return &Handler{
		conf:            conf,
		policyManager:   policyManager,
		dns:             dnsClient,
		streamSettings:  streamSettings,
		uplinkCounter:   uplinkCounter,
		downlinkCounter: downlinkCounter,
		tun:             tunDevice,
		tnet:            netStack,
	}, nil
}
// Process implements proxy.Outbound.Process.
//
// It labels the last outbound stored in ctx as "wireguard", makes sure the
// WireGuard device exists and is up, turns the target into an IP:port pair
// (resolving domains through resolveRemote), dials it through h.tnet over TCP
// or UDP, and then runs both copy directions between link and that
// connection with task.Run. The dialer argument is not used.
//
// NOTE(review): h.dev is read here without holding h.mu, while init and Close
// write it under the lock — confirm callers never overlap Process and Close.
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
	outbounds := session.OutboundsFromContext(ctx)
	ob := outbounds[len(outbounds)-1]
	if !ob.Target.IsValid() {
		return errors.New("target not specified")
	}
	ob.Name = "wireguard"
	ob.CanSpliceCopy = 3

	// Lazily create the device on first use, then bring it up on every call.
	if h.dev == nil {
		if err := h.init(ctx); err != nil {
			return err
		}
	}
	if err := h.dev.Up(); err != nil {
		return err
	}

	// A failed AddrFromSlice conversion leaves targetAddr as the zero Addr,
	// which is rejected by the IsValid check on the AddrPort below.
	// NOTE(review): a 16-byte IPv4-mapped slice is kept as-is here (no
	// Unmap) — confirm h.tnet dials such addresses as intended.
	var targetAddr netip.Addr
	if !ob.Target.Address.Family().IsDomain() {
		targetAddr, _ = netip.AddrFromSlice(ob.Target.Address.IP())
	} else {
		resolved, err := h.resolveRemote(ob.Target.Address.String())
		if err != nil {
			return errors.New("failed to resolve domain").Base(err)
		}
		targetAddr, _ = netip.AddrFromSlice(resolved)
	}
	addrPort := netip.AddrPortFrom(targetAddr, ob.Target.Port.Value())
	if !addrPort.IsValid() {
		return errors.New("invalid target ", ob.Target)
	}

	// In "timeout only" mode the work runs on a context detached from the
	// caller's, so only the inactivity timer can cancel it.
	var detachedCtx context.Context
	var detachedCancel context.CancelFunc
	if session.TimeoutOnlyFromContext(ctx) {
		detachedCtx, detachedCancel = context.WithCancel(context.Background())
	}

	sessionPolicy := h.policyManager.ForLevel(0)
	idleCtx, cancel := context.WithCancel(ctx)
	timer := signal.CancelAfterInactivity(idleCtx, func() {
		cancel()
		if detachedCancel != nil {
			detachedCancel()
		}
	}, sessionPolicy.Timeouts.ConnectionIdle)

	// runCtx is the context used for dialing and copying from here on.
	runCtx := idleCtx
	if detachedCtx != nil {
		runCtx = detachedCtx
	}

	var reader buf.Reader
	var writer buf.Writer
	switch ob.Target.Network {
	case net.Network_TCP:
		var tcpConn net.Conn
		var err error
		if handshake := sessionPolicy.Timeouts.Handshake; handshake != 0 {
			// Bound only the dial itself; the timeout context is released
			// as soon as the dial returns.
			dialCtx, dialCancel := context.WithTimeout(runCtx, handshake)
			tcpConn, err = h.tnet.DialContextTCPAddrPort(dialCtx, addrPort)
			dialCancel()
		} else {
			tcpConn, err = h.tnet.DialContextTCPAddrPort(runCtx, addrPort)
		}
		if err != nil {
			return errors.New("failed to create TCP connection").Base(err)
		}
		defer tcpConn.Close()

		reader = buf.NewReader(tcpConn)
		writer = buf.NewWriter(tcpConn)

	case net.Network_UDP:
		udpConn, err := h.tnet.DialUDPAddrPort(netip.AddrPort{}, addrPort)
		if err != nil {
			return errors.New("failed to create UDP connection").Base(err)
		}
		defer udpConn.Close()

		// The same client value serves as both reader and writer; addrPort
		// is the destination for packets that carry none of their own.
		packetClient := &udpConnClient{
			PacketConn:  udpConn.(*internet.PacketConnWrapper).PacketConn,
			resolveFunc: h.resolveRemote,
			dest:        gonet.UDPAddrFromAddrPort(addrPort),
		}
		reader = packetClient
		writer = packetClient

	default:
		panic(ob.Target.Network)
	}

	// Once one direction ends, the timer switches to that direction's
	// "other side only" timeout.
	uplink := func() error {
		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
		return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
	}
	downlink := func() error {
		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
		return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer))
	}
	downlinkThenClose := task.OnSuccess(downlink, task.Close(link.Writer))

	err := task.Run(runCtx, uplink, downlinkThenClose)
	if err == nil {
		return nil
	}
	common.Interrupt(link.Reader)
	common.Interrupt(link.Writer)
	return errors.New("connection ends").Base(err)
}
// Close releases the handler's tunnel resources under h.mu and always returns
// nil. When a device exists it is closed and both dev and tun are cleared;
// otherwise, if only the TUN exists, the TUN is closed and cleared. Errors
// from the underlying Close calls are not reported.
func (h *Handler) Close() (err error) {
	h.mu.Lock()
	defer h.mu.Unlock()

	switch {
	case h.dev != nil:
		// The TUN is not closed separately here — presumably closing the
		// device takes care of it; confirm against device.Device.Close.
		h.dev.Close()
		h.dev = nil
		h.tun = nil
	case h.tun != nil:
		h.tun.Close()
		h.tun = nil
	}
	return nil
}
// init creates the WireGuard device on first use. It holds h.mu for its whole
// duration and returns immediately when a device already exists.
//
// The device is built on h.tun with a bind whose listen function dials the
// first peer's endpoint over UDP, a logger that forwards to the log package,
// and an IPC configuration assembled from h.conf. h.dev is only set after
// both IpcSet and Up have succeeded.
//
// NOTE(review): the listen function captures the ctx of the call that
// triggered init and reuses it for every later dial — confirm that context
// outlives the request. Also, when IpcSet or Up fails the freshly created
// device is returned from without being closed — confirm this is intended.
func (h *Handler) init(ctx context.Context) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.dev != nil {
		return nil
	}

	// listenFunc opens the packet connection towards the first peer's
	// endpoint and applies the optional UDP mask and traffic counters.
	listenFunc := func() (net.PacketConn, error) {
		endpoint, err := net.ParseDestination("udp:" + h.conf.Peers[0].Endpoint)
		if err != nil {
			return nil, err
		}
		rawConn, err := internet.DialSystem(ctx, endpoint, h.streamSettings.SocketSettings)
		if err != nil {
			return nil, err
		}

		// Only these two connection types are expected from DialSystem;
		// anything else aborts with a panic naming the concrete type.
		var packetConn net.PacketConn
		switch typed := rawConn.(type) {
		case *internet.PacketConnWrapper:
			packetConn = typed.PacketConn
		case *cnc.Connection:
			packetConn = &internet.FakePacketConn{Conn: typed}
		default:
			panic(reflect.TypeOf(typed))
		}

		if h.streamSettings.UdpmaskManager != nil {
			masked, err := h.streamSettings.UdpmaskManager.WrapPacketConnClient(packetConn)
			if err != nil {
				packetConn.Close()
				return nil, errors.New("mask err").Base(err)
			}
			packetConn = masked
		}

		// Reads feed the downlink counter, writes the uplink counter.
		if h.uplinkCounter != nil || h.downlinkCounter != nil {
			packetConn = &PacketCounterConnection{
				PacketConn:   packetConn,
				ReadCounter:  h.downlinkCounter,
				WriteCounter: h.uplinkCounter,
			}
		}
		return packetConn, nil
	}

	wgLogger := &device.Logger{
		Verbosef: func(format string, args ...any) {
			log.Record(&log.GeneralMessage{
				Severity: log.Severity_Debug,
				Content:  fmt.Sprintf(format, args...),
			})
		},
		Errorf: func(format string, args ...any) {
			log.Record(&log.GeneralMessage{
				Severity: log.Severity_Error,
				Content:  fmt.Sprintf(format, args...),
			})
		},
	}

	// The bind's callbacks are filled in after the device exists because
	// downFunc refers to the device itself.
	wgBind := &bind{}
	dev := device.NewDevice(h.tun, wgBind, wgLogger)
	wgBind.resolveFunc = h.resolveLocal
	wgBind.listenFunc = listenFunc
	wgBind.downFunc = dev.Down
	wgBind.reserved = h.conf.Reserved

	// Assemble the key=value configuration text handed to IpcSet.
	var ipc strings.Builder
	ipc.WriteString("private_key=" + h.conf.SecretKey + "\n")
	for _, peer := range h.conf.Peers {
		ipc.WriteString("public_key=" + peer.PublicKey + "\n")
		if peer.PreSharedKey != "" {
			ipc.WriteString("preshared_key=" + peer.PreSharedKey + "\n")
		}
		ipc.WriteString("endpoint=" + peer.Endpoint + "\n")
		for _, allowed := range peer.AllowedIps {
			ipc.WriteString("allowed_ip=" + allowed + "\n")
		}
		if peer.KeepAlive != "" {
			ipc.WriteString("persistent_keepalive_interval=" + peer.KeepAlive + "\n")
		}
	}

	if err := dev.IpcSet(ipc.String()); err != nil {
		return err
	}
	if err := dev.Up(); err != nil {
		return err
	}

	h.dev = dev
	return nil
}
// resolveLocal resolves host to a single IP using the handler's DNS client
// (h.dns), querying both IPv4 and IPv6, and lets resolveDomain pick the
// result according to the configured domain strategy.
func (h *Handler) resolveLocal(host string) (net.IP, error) {
	lookup := func(name string) ([]net.IP, error) {
		ips, _, err := h.dns.LookupIP(name, dns.IPOption{IPv4Enable: true, IPv6Enable: true})
		return ips, err
	}
	return resolveDomain(host, h.conf.DomainStrategy, lookup)
}
// resolveRemote resolves host to a single IP by looking it up through the
// tunnel network stack (h.tnet.LookupHost) and letting resolveDomain pick the
// result according to the configured domain strategy.
//
// Lookup results that do not parse as an IP address are skipped, so a nil IP
// is never handed to resolveDomain; if nothing parses, resolveDomain reports
// an empty response.
func (h *Handler) resolveRemote(host string) (net.IP, error) {
	return resolveDomain(host, h.conf.DomainStrategy, func(host string) ([]net.IP, error) {
		addrs, err := h.tnet.LookupHost(host)
		if err != nil {
			return nil, err
		}
		ips := make([]net.IP, 0, len(addrs))
		for _, addr := range addrs {
			// ParseIP yields nil for unparsable text; keeping such an entry
			// would let a nil IP be selected and returned without an error.
			if ip := net.ParseIP(addr); ip != nil {
				ips = append(ips, ip)
			}
		}
		return ips, nil
	})
}
// resolveDomain returns one IP address for host.
//
// When host already parses as an IP it is returned unchanged. Otherwise
// lookupIP is called and the results are split into IPv4 (To4 is non-nil) and
// everything else, treated as IPv6. strategy selects the candidate set:
//
//   - FORCE_IP:   all results
//   - FORCE_IP4:  IPv4 only
//   - FORCE_IP6:  IPv6 only
//   - FORCE_IP46: IPv4, falling back to IPv6 when there is no IPv4 result
//   - FORCE_IP64: IPv6, falling back to IPv4 when there is no IPv6 result
//
// One candidate is then chosen with dice.Roll. It returns the lookup error,
// dns.ErrEmptyResponse when no candidate remains, or an error for a strategy
// value outside the list above.
func resolveDomain(host string, strategy DeviceConfig_DomainStrategy, lookupIP func(host string) ([]net.IP, error)) (net.IP, error) {
	if ip := net.ParseIP(host); ip != nil {
		return ip, nil
	}

	ips, err := lookupIP(host)
	if err != nil {
		return nil, err
	}
	if len(ips) == 0 {
		return nil, dns.ErrEmptyResponse
	}

	var got4, got6 []net.IP
	for _, ip := range ips {
		if ip.To4() != nil {
			got4 = append(got4, ip)
		} else {
			got6 = append(got6, ip)
		}
	}

	var got []net.IP
	switch strategy {
	case DeviceConfig_FORCE_IP:
		// ips is known to be non-empty here, so the shared tail below
		// always picks from it.
		got = ips
	case DeviceConfig_FORCE_IP4:
		got = got4
	case DeviceConfig_FORCE_IP6:
		got = got6
	case DeviceConfig_FORCE_IP46:
		got = got4
		if len(got) == 0 {
			got = got6
		}
	case DeviceConfig_FORCE_IP64:
		got = got6
		if len(got) == 0 {
			got = got4
		}
	default:
		// Report an unrecognised strategy to the caller instead of
		// panicking the whole process.
		return nil, errors.New("unknown domain strategy: ", strategy)
	}

	if len(got) == 0 {
		return nil, dns.ErrEmptyResponse
	}
	return got[dice.Roll(len(got))], nil
}
// udpConnClient adapts a packet connection to the buffer reader/writer
// methods ReadMultiBuffer and WriteMultiBuffer.
type udpConnClient struct {
	// PacketConn is the underlying connection packets are read from and
	// written to.
	net.PacketConn

	// resolveFunc turns a domain destination carried by an outgoing buffer
	// into an IP address.
	resolveFunc func(host string) (net.IP, error)
	// dest is the destination used for buffers that carry no UDP
	// destination of their own.
	dest *net.UDPAddr
}
// ReadMultiBuffer reads one packet from the underlying connection into a
// fresh buffer and returns it as a single-element MultiBuffer. The buffer's
// UDP field is set to the packet's source address. On a read error the buffer
// is released and the error returned. The source address must be a
// *net.UDPAddr; any other type makes the type assertion panic.
func (c *udpConnClient) ReadMultiBuffer() (buf.MultiBuffer, error) {
	packet := buf.New()
	// Expose the full buffer capacity to ReadFrom, then shrink to what was
	// actually received.
	packet.Resize(0, buf.Size)
	received, from, err := c.PacketConn.ReadFrom(packet.Bytes())
	if err != nil {
		packet.Release()
		return nil, err
	}
	packet.Resize(0, int32(received))

	source := from.(*net.UDPAddr)
	packet.UDP = &net.Destination{
		Address: net.IPAddress(source.IP),
		Port:    net.Port(source.Port),
		Network: net.Network_UDP,
	}
	return buf.MultiBuffer{packet}, nil
}
// WriteMultiBuffer sends every buffer in mb as one packet and releases it.
//
// A buffer without a UDP destination goes to c.dest. A buffer with a domain
// destination is resolved through c.resolveFunc first; if that fails the
// packet is logged, dropped and the loop continues. The first write error
// releases the current and all remaining buffers and is returned.
func (c *udpConnClient) WriteMultiBuffer(mb buf.MultiBuffer) error {
	for idx, packet := range mb {
		target := c.dest
		switch {
		case packet.UDP == nil:
			// No per-packet destination: keep the default.
		case packet.UDP.Address.Family().IsDomain():
			ip, err := c.resolveFunc(packet.UDP.Address.String())
			if err != nil {
				errors.LogErrorInner(context.Background(), err, "drop packet to ", packet.UDP, " with size ", len(packet.Bytes()))
				packet.Release()
				continue
			}
			target = &net.UDPAddr{
				IP:   ip,
				Port: int(packet.UDP.Port),
			}
		default:
			target = packet.UDP.RawNetAddr().(*net.UDPAddr)
		}

		if _, err := c.PacketConn.WriteTo(packet.Bytes(), target); err != nil {
			// Earlier buffers were already released one by one; free the
			// current one and everything after it.
			buf.ReleaseMulti(mb[idx:])
			return err
		}
		packet.Release()
	}
	return nil
}
// PacketCounterConnection wraps a packet connection and adds the byte count
// of successful reads and writes to optional counters.
type PacketCounterConnection struct {
	net.PacketConn

	// ReadCounter receives the size of each successful ReadFrom; may be nil.
	ReadCounter stats.Counter
	// WriteCounter receives the size of each successful WriteTo; may be nil.
	WriteCounter stats.Counter
}
// ReadFrom reads a packet from the wrapped connection and, when the read
// succeeded and a ReadCounter is set, adds the number of bytes read to it.
// The wrapped connection's results are returned unchanged.
func (c *PacketCounterConnection) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	n, addr, err = c.PacketConn.ReadFrom(p)
	if err != nil || c.ReadCounter == nil {
		return n, addr, err
	}
	c.ReadCounter.Add(int64(n))
	return n, addr, err
}
// WriteTo writes a packet through the wrapped connection and, when the write
// succeeded and a WriteCounter is set, adds the number of bytes written to
// it. The wrapped connection's results are returned unchanged.
func (c *PacketCounterConnection) WriteTo(p []byte, addr net.Addr) (n int, err error) {
	n, err = c.PacketConn.WriteTo(p, addr)
	if err != nil || c.WriteCounter == nil {
		return n, err
	}
	c.WriteCounter.Add(int64(n))
	return n, err
}