- TunDevice初步封装完成准备作为独立包发布
This commit is contained in:
mori 2023-08-01 21:13:48 +08:00
commit a0929c353a
6 changed files with 485 additions and 0 deletions

178
Device.go Normal file
View File

@ -0,0 +1,178 @@
package Atom_Device
import (
"errors"
"fmt"
"golang.zx2c4.com/wireguard/tun"
"net"
"os"
"sync"
"sync/atomic"
)
type Device struct {
tun tun.Device
mtu int64
LocalIp net.IP
IpNet *net.IPNet
outboundChan chan *Packet
packetsPool sync.Pool
}
func NewDevice(existingTun tun.Device, ifName string, localIP net.IP, ipNet *net.IPNet) (*Device, error) {
var tunDevice tun.Device
var err error
if existingTun == nil {
tunDevice, err = NewTun(ifName, InterfaceMTU, localIP, ipNet.Mask)
if err != nil {
return nil, fmt.Errorf("failed to create TUN device: %v", err)
}
} else {
tunDevice = existingTun
}
realMtu, err := tunDevice.MTU()
if err != nil {
return nil, fmt.Errorf("failed to get TUN mtu: %v", err)
}
dev := &Device{
tun: tunDevice,
mtu: int64(realMtu),
LocalIp: localIP,
IpNet: ipNet,
outboundChan: make(chan *Packet, outboundChCap),
packetsPool: sync.Pool{
New: func() interface{} {
return new(Packet)
},
},
}
go dev.tunEventsReader()
go dev.tunPacketsReader()
return dev, nil
}
func (d *Device) GetTempPacket() *Packet {
return d.packetsPool.Get().(*Packet)
}
func (d *Device) PutTempPacket(data *Packet) {
data.clear()
d.packetsPool.Put(data)
}
// WritePacket TODO: batch write
func (d *Device) WritePacket(data *Packet, senderIP net.IP) error {
if data.IsIPv6 {
// TODO: implement. We need to set Device.localIP ipv6 instead of ipv4
return nil
} else {
copy(data.Src, senderIP)
copy(data.Dst, d.LocalIp)
}
data.RecalculateChecksum()
bufs := [][]byte{data.Buffer[:tunPacketOffset+len(data.Packet)]}
packetsCount, err := d.tun.Write(bufs, tunPacketOffset)
if err != nil {
return fmt.Errorf("write packet to tun: %v", err)
} else if packetsCount < len(bufs) {
fmt.Printf("wrote %d packets, len(bufs): %d", packetsCount, len(bufs))
}
return nil
}
func (d *Device) OutboundChan() <-chan *Packet {
return d.outboundChan
}
func (d *Device) Close() error {
return d.tun.Close()
}
func (d *Device) tunEventsReader() {
for event := range d.tun.Events() {
if event&tun.EventMTUUpdate != 0 {
mtu, err := d.tun.MTU()
if err != nil {
fmt.Printf("Failed to load updated MTU of device: %v", err)
continue
}
if mtu < 0 {
fmt.Printf("MTU not updated to negative value: %v", mtu)
continue
}
var tooLarge string
if mtu > maxContentSize {
tooLarge = fmt.Sprintf(" (too large, capped at %v)", maxContentSize)
mtu = maxContentSize
}
old := atomic.SwapInt64(&d.mtu, int64(mtu))
if int(old) != mtu {
fmt.Printf("MTU updated: %v%s", mtu, tooLarge)
}
}
// TODO: check for event&tun.EventUp
if event&tun.EventDown != 0 {
fmt.Printf("Interface down requested")
// TODO
}
}
}
func (d *Device) tunPacketsReader() {
defer close(d.outboundChan)
batchSize := d.tun.BatchSize()
packets := make([]*Packet, batchSize)
bufs := make([][]byte, batchSize)
sizes := make([]int, batchSize)
for {
for i := range packets {
if packets[i] == nil {
packets[i] = d.GetTempPacket()
} else {
packets[i].clear()
}
bufs[i] = packets[i].Buffer[:]
sizes[i] = 0
}
packetsCount, err := d.tun.Read(bufs, sizes, tunPacketOffset)
for i := 0; i < packetsCount; i++ {
size := sizes[i]
if size == 0 || size > maxContentSize {
continue
}
data := packets[i]
data.Packet = data.Buffer[tunPacketOffset : size+tunPacketOffset]
okay := data.Parse()
if !okay {
continue
}
d.outboundChan <- data
packets[i] = nil
}
if errors.Is(err, tun.ErrTooManySegments) {
continue
} else if errors.Is(err, os.ErrClosed) {
return
} else if err != nil {
fmt.Printf("Failed to read packets from TUN device: %v", err)
return
}
}
}

111
Packet.go Normal file
View File

@ -0,0 +1,111 @@
package Atom_Device
import (
"encoding/binary"
"git.sydch.com/Go-Module/Atom-Device/utils"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/device"
"io"
"net"
)
const (
InterfaceMTU = 3500
maxContentSize = InterfaceMTU * 2
outboundChCap = 50
tunPacketOffset = 14
ipv4offsetChecksum = 10
)
type Packet struct {
Buffer [maxContentSize]byte
Packet []byte
Src net.IP
Dst net.IP
IsIPv6 bool
}
func (data *Packet) clear() {
data.Packet = nil
data.Src = nil
data.Dst = nil
data.IsIPv6 = false
}
func (data *Packet) ReadFrom(stream io.Reader) (int64, error) {
var totalRead = tunPacketOffset
for {
n, err := stream.Read(data.Buffer[totalRead:])
totalRead += n
if err == io.EOF {
data.Packet = data.Buffer[tunPacketOffset:totalRead]
return int64(totalRead - tunPacketOffset), nil
} else if err != nil {
return int64(totalRead - tunPacketOffset), err
}
}
}
func (data *Packet) Parse() bool {
packet := data.Packet
switch version := packet[0] >> 4; version {
case ipv4.Version:
if len(packet) < ipv4.HeaderLen {
return false
}
data.Src = packet[device.IPv4offsetSrc : device.IPv4offsetSrc+net.IPv4len]
data.Dst = packet[device.IPv4offsetDst : device.IPv4offsetDst+net.IPv4len]
data.IsIPv6 = false
case ipv6.Version:
if len(packet) < ipv6.HeaderLen {
return false
}
data.Src = packet[device.IPv6offsetSrc : device.IPv6offsetSrc+net.IPv6len]
data.Dst = packet[device.IPv6offsetDst : device.IPv6offsetDst+net.IPv6len]
data.IsIPv6 = true
default:
return false
}
return true
}
func (data *Packet) ParseBuffer() {
var tmp [InterfaceMTU * 2]byte
for i := 0; i < len(data.Packet) && i+14 < len(tmp); i++ {
tmp[i+14] = data.Packet[i]
}
data.Buffer = tmp
}
func (data *Packet) RecalculateChecksum() {
const (
IPProtocolTCP = 6
IPProtocolUDP = 17
)
if data.IsIPv6 {
// TODO
} else {
ipHeaderLen := int(data.Packet[0]&0x0f) << 2
copy(data.Packet[ipv4offsetChecksum:], []byte{0, 0})
ipChecksum := utils.ChecksumIPv4Header(data.Packet[:ipHeaderLen])
binary.BigEndian.PutUint16(data.Packet[ipv4offsetChecksum:], ipChecksum)
switch protocol := data.Packet[9]; protocol {
case IPProtocolTCP:
tcpOffsetChecksum := ipHeaderLen + 16
copy(data.Packet[tcpOffsetChecksum:], []byte{0, 0})
checksum := utils.ChecksumIPv4TCPUDP(data.Packet[ipHeaderLen:], uint32(protocol), data.Src, data.Dst)
binary.BigEndian.PutUint16(data.Packet[tcpOffsetChecksum:], checksum)
case IPProtocolUDP:
udpOffsetChecksum := ipHeaderLen + 6
copy(data.Packet[udpOffsetChecksum:], []byte{0, 0})
checksum := utils.ChecksumIPv4TCPUDP(data.Packet[ipHeaderLen:], uint32(protocol), data.Src, data.Dst)
binary.BigEndian.PutUint16(data.Packet[udpOffsetChecksum:], checksum)
}
}
}

17
go.mod Normal file
View File

@ -0,0 +1,17 @@
module git.sydch.com/Go-Module/Atom-Device
go 1.20
require (
github.com/milosgajdos/tenus v0.0.3
golang.org/x/net v0.12.0
golang.org/x/sys v0.10.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/windows v0.5.3
)
require (
github.com/docker/libcontainer v2.2.1+incompatible // indirect
golang.org/x/crypto v0.11.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
)

49
iface_linux.go Normal file
View File

@ -0,0 +1,49 @@
//go:build linux && !android
// +build linux,!android
package Atom_Device
import (
"fmt"
"github.com/milosgajdos/tenus"
"golang.zx2c4.com/wireguard/tun"
"net"
)
func NewTun(ifName string, mtu int, localIP net.IP, ipMask net.IPMask) (tun.Device, error) {
ipNet := &net.IPNet{
IP: localIP.Mask(ipMask),
Mask: ipMask,
}
tunDevice, err := tun.CreateTUN(ifName, mtu)
if err != nil {
return nil, fmt.Errorf("create tun: %v", err)
}
link, err := tenus.NewLinkFrom(ifName)
if nil != err {
return nil, fmt.Errorf("unable to get interface info: %v", err)
}
err = link.SetLinkIp(localIP, ipNet)
if err != nil {
return nil, fmt.Errorf("unable to set IP (%s) to (%v on interface): %v", localIP, ipNet, err)
}
err = link.SetLinkUp()
if err != nil {
return nil, fmt.Errorf("unable to UP interface: %v", err)
}
return tunDevice, nil
}
func (d *Device) InterfaceName() (string, error) {
interfaceName, err := d.tun.Name()
if err != nil {
return "", err
}
return interfaceName, nil
}

70
iface_windows.go Normal file
View File

@ -0,0 +1,70 @@
//go:build windows
package Atom_Device
import (
"fmt"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/elevate"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"net"
"net/netip"
)
func init() {
var err error
tun.WintunTunnelType = "Atom"
if err != nil {
panic(err)
}
guid, err := windows.GUIDFromString("{E15D88C1-A15E-4CA1-9577-97FCF2CB995E}")
if err != nil {
panic(err)
}
tun.WintunStaticRequestedGUID = &guid
}
func NewTun(ifName string, mtu int, localIP net.IP, ipMask net.IPMask) (tun.Device, error) {
var tunDevice tun.Device
err := elevate.DoAsSystem(func() error {
var err error
tunDevice, err = tun.CreateTUN(ifName, mtu)
if err != nil {
return fmt.Errorf("create tun: %v", err)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("do as system: %v", err)
}
nativeTun := tunDevice.(*tun.NativeTun)
luid := winipcfg.LUID(nativeTun.LUID())
ones, _ := ipMask.Size()
netAddr := netip.MustParseAddr(localIP.String())
prefix := netip.PrefixFrom(netAddr, ones)
err = luid.SetIPAddresses([]netip.Prefix{prefix})
if err != nil {
return nil, fmt.Errorf("unable to setup interface IP: %v", err)
}
return tunDevice, nil
}
func (d *Device) InterfaceName() (string, error) {
nativeTun := d.tun.(*tun.NativeTun)
luid := winipcfg.LUID(nativeTun.LUID())
guid, err := luid.GUID()
if err != nil {
return "", err
}
return guid.String(), nil
}

60
utils/NetUtils.go Normal file
View File

@ -0,0 +1,60 @@
package utils
import (
"encoding/binary"
"net"
)
func ChecksumIPv4Header(buf []byte) uint16 {
var v uint32
for i := 0; i < len(buf)-1; i += 2 {
v += uint32(binary.BigEndian.Uint16(buf[i:]))
}
if len(buf)%2 == 1 {
v += uint32(buf[len(buf)-1]) << 8
}
for v > 0xffff {
v = (v >> 16) + (v & 0xffff)
}
return ^uint16(v)
}
func ChecksumIPv4TCPUDP(headerAndPayload []byte, protocol uint32, srcIP net.IP, dstIP net.IP) uint16 {
var csum uint32
csum += (uint32(srcIP[0]) + uint32(srcIP[2])) << 8
csum += uint32(srcIP[1]) + uint32(srcIP[3])
csum += (uint32(dstIP[0]) + uint32(dstIP[2])) << 8
csum += uint32(dstIP[1]) + uint32(dstIP[3])
totalLen := uint32(len(headerAndPayload))
csum += protocol
csum += totalLen & 0xffff
csum += totalLen >> 16
return tcpipChecksum(headerAndPayload, csum)
}
// Calculate the TCP/IP checksum defined in rfc1071. The passed-in csum is any
// initial checksum data that's already been computed.
// Borrowed from google/gopacket
func tcpipChecksum(data []byte, csum uint32) uint16 {
// to handle odd lengths, we loop to length - 1, incrementing by 2, then
// handle the last byte specifically by checking against the original
// length.
length := len(data) - 1
for i := 0; i < length; i += 2 {
// For our test packet, doing this manually is about 25% faster
// (740 ns vs. 1000ns) than doing it by calling binary.BigEndian.Uint16.
csum += uint32(data[i]) << 8
csum += uint32(data[i+1])
}
if len(data)%2 == 1 {
csum += uint32(data[length]) << 8
}
for csum > 0xffff {
csum = (csum >> 16) + (csum & 0xffff)
}
return ^uint16(csum)
}