commit
a0929c353a
|
@ -0,0 +1,178 @@
|
|||
package Atom_Device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type Device struct {
|
||||
tun tun.Device
|
||||
mtu int64
|
||||
LocalIp net.IP
|
||||
IpNet *net.IPNet
|
||||
outboundChan chan *Packet
|
||||
|
||||
packetsPool sync.Pool
|
||||
}
|
||||
|
||||
func NewDevice(existingTun tun.Device, ifName string, localIP net.IP, ipNet *net.IPNet) (*Device, error) {
|
||||
var tunDevice tun.Device
|
||||
var err error
|
||||
|
||||
if existingTun == nil {
|
||||
tunDevice, err = NewTun(ifName, InterfaceMTU, localIP, ipNet.Mask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TUN device: %v", err)
|
||||
}
|
||||
} else {
|
||||
tunDevice = existingTun
|
||||
}
|
||||
|
||||
realMtu, err := tunDevice.MTU()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get TUN mtu: %v", err)
|
||||
}
|
||||
|
||||
dev := &Device{
|
||||
tun: tunDevice,
|
||||
mtu: int64(realMtu),
|
||||
LocalIp: localIP,
|
||||
IpNet: ipNet,
|
||||
outboundChan: make(chan *Packet, outboundChCap),
|
||||
packetsPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(Packet)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go dev.tunEventsReader()
|
||||
go dev.tunPacketsReader()
|
||||
|
||||
return dev, nil
|
||||
|
||||
}
|
||||
|
||||
func (d *Device) GetTempPacket() *Packet {
|
||||
return d.packetsPool.Get().(*Packet)
|
||||
}
|
||||
|
||||
func (d *Device) PutTempPacket(data *Packet) {
|
||||
data.clear()
|
||||
d.packetsPool.Put(data)
|
||||
}
|
||||
|
||||
// WritePacket TODO: batch write
|
||||
func (d *Device) WritePacket(data *Packet, senderIP net.IP) error {
|
||||
if data.IsIPv6 {
|
||||
// TODO: implement. We need to set Device.localIP ipv6 instead of ipv4
|
||||
return nil
|
||||
} else {
|
||||
copy(data.Src, senderIP)
|
||||
copy(data.Dst, d.LocalIp)
|
||||
}
|
||||
data.RecalculateChecksum()
|
||||
|
||||
bufs := [][]byte{data.Buffer[:tunPacketOffset+len(data.Packet)]}
|
||||
packetsCount, err := d.tun.Write(bufs, tunPacketOffset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write packet to tun: %v", err)
|
||||
} else if packetsCount < len(bufs) {
|
||||
fmt.Printf("wrote %d packets, len(bufs): %d", packetsCount, len(bufs))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Device) OutboundChan() <-chan *Packet {
|
||||
return d.outboundChan
|
||||
}
|
||||
|
||||
func (d *Device) Close() error {
|
||||
return d.tun.Close()
|
||||
}
|
||||
|
||||
func (d *Device) tunEventsReader() {
|
||||
for event := range d.tun.Events() {
|
||||
if event&tun.EventMTUUpdate != 0 {
|
||||
mtu, err := d.tun.MTU()
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to load updated MTU of device: %v", err)
|
||||
continue
|
||||
}
|
||||
if mtu < 0 {
|
||||
fmt.Printf("MTU not updated to negative value: %v", mtu)
|
||||
continue
|
||||
}
|
||||
var tooLarge string
|
||||
if mtu > maxContentSize {
|
||||
tooLarge = fmt.Sprintf(" (too large, capped at %v)", maxContentSize)
|
||||
mtu = maxContentSize
|
||||
}
|
||||
old := atomic.SwapInt64(&d.mtu, int64(mtu))
|
||||
if int(old) != mtu {
|
||||
fmt.Printf("MTU updated: %v%s", mtu, tooLarge)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: check for event&tun.EventUp
|
||||
if event&tun.EventDown != 0 {
|
||||
fmt.Printf("Interface down requested")
|
||||
// TODO
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Device) tunPacketsReader() {
|
||||
defer close(d.outboundChan)
|
||||
|
||||
batchSize := d.tun.BatchSize()
|
||||
packets := make([]*Packet, batchSize)
|
||||
bufs := make([][]byte, batchSize)
|
||||
sizes := make([]int, batchSize)
|
||||
|
||||
for {
|
||||
for i := range packets {
|
||||
if packets[i] == nil {
|
||||
packets[i] = d.GetTempPacket()
|
||||
} else {
|
||||
packets[i].clear()
|
||||
}
|
||||
bufs[i] = packets[i].Buffer[:]
|
||||
sizes[i] = 0
|
||||
}
|
||||
|
||||
packetsCount, err := d.tun.Read(bufs, sizes, tunPacketOffset)
|
||||
for i := 0; i < packetsCount; i++ {
|
||||
size := sizes[i]
|
||||
if size == 0 || size > maxContentSize {
|
||||
continue
|
||||
}
|
||||
|
||||
data := packets[i]
|
||||
data.Packet = data.Buffer[tunPacketOffset : size+tunPacketOffset]
|
||||
okay := data.Parse()
|
||||
if !okay {
|
||||
continue
|
||||
}
|
||||
|
||||
d.outboundChan <- data
|
||||
packets[i] = nil
|
||||
}
|
||||
|
||||
if errors.Is(err, tun.ErrTooManySegments) {
|
||||
continue
|
||||
} else if errors.Is(err, os.ErrClosed) {
|
||||
return
|
||||
} else if err != nil {
|
||||
fmt.Printf("Failed to read packets from TUN device: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
package Atom_Device
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"git.sydch.com/Go-Module/Atom-Device/utils"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
InterfaceMTU = 3500
|
||||
maxContentSize = InterfaceMTU * 2
|
||||
outboundChCap = 50
|
||||
tunPacketOffset = 14
|
||||
ipv4offsetChecksum = 10
|
||||
)
|
||||
|
||||
type Packet struct {
|
||||
Buffer [maxContentSize]byte
|
||||
Packet []byte
|
||||
Src net.IP
|
||||
Dst net.IP
|
||||
IsIPv6 bool
|
||||
}
|
||||
|
||||
func (data *Packet) clear() {
|
||||
data.Packet = nil
|
||||
data.Src = nil
|
||||
data.Dst = nil
|
||||
data.IsIPv6 = false
|
||||
}
|
||||
|
||||
func (data *Packet) ReadFrom(stream io.Reader) (int64, error) {
|
||||
var totalRead = tunPacketOffset
|
||||
for {
|
||||
n, err := stream.Read(data.Buffer[totalRead:])
|
||||
totalRead += n
|
||||
if err == io.EOF {
|
||||
data.Packet = data.Buffer[tunPacketOffset:totalRead]
|
||||
return int64(totalRead - tunPacketOffset), nil
|
||||
} else if err != nil {
|
||||
return int64(totalRead - tunPacketOffset), err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (data *Packet) Parse() bool {
|
||||
packet := data.Packet
|
||||
switch version := packet[0] >> 4; version {
|
||||
case ipv4.Version:
|
||||
if len(packet) < ipv4.HeaderLen {
|
||||
return false
|
||||
}
|
||||
|
||||
data.Src = packet[device.IPv4offsetSrc : device.IPv4offsetSrc+net.IPv4len]
|
||||
data.Dst = packet[device.IPv4offsetDst : device.IPv4offsetDst+net.IPv4len]
|
||||
data.IsIPv6 = false
|
||||
case ipv6.Version:
|
||||
if len(packet) < ipv6.HeaderLen {
|
||||
return false
|
||||
}
|
||||
|
||||
data.Src = packet[device.IPv6offsetSrc : device.IPv6offsetSrc+net.IPv6len]
|
||||
data.Dst = packet[device.IPv6offsetDst : device.IPv6offsetDst+net.IPv6len]
|
||||
data.IsIPv6 = true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (data *Packet) ParseBuffer() {
|
||||
var tmp [InterfaceMTU * 2]byte
|
||||
for i := 0; i < len(data.Packet) && i+14 < len(tmp); i++ {
|
||||
tmp[i+14] = data.Packet[i]
|
||||
}
|
||||
data.Buffer = tmp
|
||||
}
|
||||
|
||||
func (data *Packet) RecalculateChecksum() {
|
||||
const (
|
||||
IPProtocolTCP = 6
|
||||
IPProtocolUDP = 17
|
||||
)
|
||||
|
||||
if data.IsIPv6 {
|
||||
// TODO
|
||||
} else {
|
||||
ipHeaderLen := int(data.Packet[0]&0x0f) << 2
|
||||
copy(data.Packet[ipv4offsetChecksum:], []byte{0, 0})
|
||||
ipChecksum := utils.ChecksumIPv4Header(data.Packet[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(data.Packet[ipv4offsetChecksum:], ipChecksum)
|
||||
|
||||
switch protocol := data.Packet[9]; protocol {
|
||||
case IPProtocolTCP:
|
||||
tcpOffsetChecksum := ipHeaderLen + 16
|
||||
copy(data.Packet[tcpOffsetChecksum:], []byte{0, 0})
|
||||
checksum := utils.ChecksumIPv4TCPUDP(data.Packet[ipHeaderLen:], uint32(protocol), data.Src, data.Dst)
|
||||
binary.BigEndian.PutUint16(data.Packet[tcpOffsetChecksum:], checksum)
|
||||
case IPProtocolUDP:
|
||||
udpOffsetChecksum := ipHeaderLen + 6
|
||||
copy(data.Packet[udpOffsetChecksum:], []byte{0, 0})
|
||||
checksum := utils.ChecksumIPv4TCPUDP(data.Packet[ipHeaderLen:], uint32(protocol), data.Src, data.Dst)
|
||||
binary.BigEndian.PutUint16(data.Packet[udpOffsetChecksum:], checksum)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
module git.sydch.com/Go-Module/Atom-Device
|
||||
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/milosgajdos/tenus v0.0.3
|
||||
golang.org/x/net v0.12.0
|
||||
golang.org/x/sys v0.10.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/docker/libcontainer v2.2.1+incompatible // indirect
|
||||
golang.org/x/crypto v0.11.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
)
|
|
@ -0,0 +1,49 @@
|
|||
//go:build linux && !android
|
||||
// +build linux,!android
|
||||
|
||||
package Atom_Device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/milosgajdos/tenus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"net"
|
||||
)
|
||||
|
||||
func NewTun(ifName string, mtu int, localIP net.IP, ipMask net.IPMask) (tun.Device, error) {
|
||||
ipNet := &net.IPNet{
|
||||
IP: localIP.Mask(ipMask),
|
||||
Mask: ipMask,
|
||||
}
|
||||
|
||||
tunDevice, err := tun.CreateTUN(ifName, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create tun: %v", err)
|
||||
}
|
||||
|
||||
link, err := tenus.NewLinkFrom(ifName)
|
||||
if nil != err {
|
||||
return nil, fmt.Errorf("unable to get interface info: %v", err)
|
||||
}
|
||||
|
||||
err = link.SetLinkIp(localIP, ipNet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to set IP (%s) to (%v on interface): %v", localIP, ipNet, err)
|
||||
}
|
||||
|
||||
err = link.SetLinkUp()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to UP interface: %v", err)
|
||||
}
|
||||
|
||||
return tunDevice, nil
|
||||
}
|
||||
|
||||
func (d *Device) InterfaceName() (string, error) {
|
||||
interfaceName, err := d.tun.Name()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return interfaceName, nil
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
//go:build windows
|
||||
|
||||
package Atom_Device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/windows/elevate"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
|
||||
tun.WintunTunnelType = "Atom"
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
guid, err := windows.GUIDFromString("{E15D88C1-A15E-4CA1-9577-97FCF2CB995E}")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tun.WintunStaticRequestedGUID = &guid
|
||||
}
|
||||
|
||||
func NewTun(ifName string, mtu int, localIP net.IP, ipMask net.IPMask) (tun.Device, error) {
|
||||
var tunDevice tun.Device
|
||||
|
||||
err := elevate.DoAsSystem(func() error {
|
||||
var err error
|
||||
tunDevice, err = tun.CreateTUN(ifName, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create tun: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do as system: %v", err)
|
||||
}
|
||||
|
||||
nativeTun := tunDevice.(*tun.NativeTun)
|
||||
luid := winipcfg.LUID(nativeTun.LUID())
|
||||
|
||||
ones, _ := ipMask.Size()
|
||||
netAddr := netip.MustParseAddr(localIP.String())
|
||||
prefix := netip.PrefixFrom(netAddr, ones)
|
||||
|
||||
err = luid.SetIPAddresses([]netip.Prefix{prefix})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to setup interface IP: %v", err)
|
||||
}
|
||||
|
||||
return tunDevice, nil
|
||||
}
|
||||
|
||||
func (d *Device) InterfaceName() (string, error) {
|
||||
nativeTun := d.tun.(*tun.NativeTun)
|
||||
luid := winipcfg.LUID(nativeTun.LUID())
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return guid.String(), nil
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
)
|
||||
|
||||
func ChecksumIPv4Header(buf []byte) uint16 {
|
||||
var v uint32
|
||||
for i := 0; i < len(buf)-1; i += 2 {
|
||||
v += uint32(binary.BigEndian.Uint16(buf[i:]))
|
||||
}
|
||||
if len(buf)%2 == 1 {
|
||||
v += uint32(buf[len(buf)-1]) << 8
|
||||
}
|
||||
for v > 0xffff {
|
||||
v = (v >> 16) + (v & 0xffff)
|
||||
}
|
||||
|
||||
return ^uint16(v)
|
||||
}
|
||||
|
||||
func ChecksumIPv4TCPUDP(headerAndPayload []byte, protocol uint32, srcIP net.IP, dstIP net.IP) uint16 {
|
||||
var csum uint32
|
||||
csum += (uint32(srcIP[0]) + uint32(srcIP[2])) << 8
|
||||
csum += uint32(srcIP[1]) + uint32(srcIP[3])
|
||||
csum += (uint32(dstIP[0]) + uint32(dstIP[2])) << 8
|
||||
csum += uint32(dstIP[1]) + uint32(dstIP[3])
|
||||
|
||||
totalLen := uint32(len(headerAndPayload))
|
||||
|
||||
csum += protocol
|
||||
csum += totalLen & 0xffff
|
||||
csum += totalLen >> 16
|
||||
|
||||
return tcpipChecksum(headerAndPayload, csum)
|
||||
}
|
||||
|
||||
// Calculate the TCP/IP checksum defined in rfc1071. The passed-in csum is any
|
||||
// initial checksum data that's already been computed.
|
||||
// Borrowed from google/gopacket
|
||||
func tcpipChecksum(data []byte, csum uint32) uint16 {
|
||||
// to handle odd lengths, we loop to length - 1, incrementing by 2, then
|
||||
// handle the last byte specifically by checking against the original
|
||||
// length.
|
||||
length := len(data) - 1
|
||||
for i := 0; i < length; i += 2 {
|
||||
// For our test packet, doing this manually is about 25% faster
|
||||
// (740 ns vs. 1000ns) than doing it by calling binary.BigEndian.Uint16.
|
||||
csum += uint32(data[i]) << 8
|
||||
csum += uint32(data[i+1])
|
||||
}
|
||||
if len(data)%2 == 1 {
|
||||
csum += uint32(data[length]) << 8
|
||||
}
|
||||
for csum > 0xffff {
|
||||
csum = (csum >> 16) + (csum & 0xffff)
|
||||
}
|
||||
return ^uint16(csum)
|
||||
}
|
Loading…
Reference in New Issue