atom-device/Device.go

178 lines
3.7 KiB
Go

package Atom_Device
import (
"errors"
"fmt"
"golang.zx2c4.com/wireguard/tun"
"net"
"os"
"sync"
"sync/atomic"
)
// Device wraps a TUN interface. Packets read from the interface are delivered
// on outboundChan (see OutboundChan); packets passed to WritePacket are
// written to the interface.
type Device struct {
	// tun is the underlying TUN interface, either created by NewDevice or
	// supplied by its caller.
	tun tun.Device
	// mtu is the interface MTU. tunEventsReader updates it with
	// atomic.SwapInt64, so concurrent readers should use sync/atomic.
	mtu int64
	// LocalIp is written as the destination address of packets passed to
	// WritePacket.
	LocalIp net.IP
	// IpNet is the network passed to NewDevice; its Mask is used when
	// NewDevice creates the TUN interface.
	IpNet *net.IPNet
	// outboundChan carries packets read from the TUN interface. It is closed
	// when tunPacketsReader returns.
	outboundChan chan *Packet
	// packetsPool recycles *Packet values (see GetTempPacket/PutTempPacket).
	packetsPool sync.Pool
}
// NewDevice wraps a TUN interface in a Device and starts its background
// goroutines.
//
// If existingTun is nil, a new TUN interface named ifName is created with
// InterfaceMTU, localIP and ipNet.Mask; in that case ipNet must be non-nil.
// Otherwise existingTun is used as-is and ifName is ignored.
//
// The MTU is read from the interface and capped at maxContentSize, matching
// the cap tunEventsReader applies to later MTU updates.
//
// Two goroutines are started: tunEventsReader and tunPacketsReader. The
// latter delivers packets on OutboundChan.
//
// On error, a TUN interface created by this call is closed before returning.
// A caller-supplied existingTun is never closed here.
func NewDevice(existingTun tun.Device, ifName string, localIP net.IP, ipNet *net.IPNet) (*Device, error) {
	tunDevice := existingTun
	ownsTun := false
	if tunDevice == nil {
		if ipNet == nil {
			return nil, fmt.Errorf("failed to create TUN device %q: nil ipNet", ifName)
		}
		created, err := NewTun(ifName, InterfaceMTU, localIP, ipNet.Mask)
		if err != nil {
			return nil, fmt.Errorf("failed to create TUN device: %w", err)
		}
		tunDevice = created
		ownsTun = true
	}

	realMtu, err := tunDevice.MTU()
	if err != nil {
		if ownsTun {
			// Best-effort cleanup so the interface we created does not leak;
			// the MTU error is the one worth reporting.
			_ = tunDevice.Close()
		}
		return nil, fmt.Errorf("failed to get TUN mtu: %w", err)
	}
	// Same cap tunEventsReader applies to MTU updates.
	if realMtu > maxContentSize {
		realMtu = maxContentSize
	}

	dev := &Device{
		tun:          tunDevice,
		mtu:          int64(realMtu),
		LocalIp:      localIP,
		IpNet:        ipNet,
		outboundChan: make(chan *Packet, outboundChCap),
		packetsPool: sync.Pool{
			New: func() interface{} {
				return new(Packet)
			},
		},
	}
	go dev.tunEventsReader()
	go dev.tunPacketsReader()
	return dev, nil
}
// GetTempPacket takes a *Packet from the device's pool. When the pool is
// empty, the pool's New func (installed by NewDevice) allocates a fresh one.
// Return the packet with PutTempPacket when finished with it.
func (d *Device) GetTempPacket() *Packet {
	pkt := d.packetsPool.Get().(*Packet)
	return pkt
}
// PutTempPacket resets data via its clear method and hands it back to the
// device's pool. The pool may give the same packet to a later
// GetTempPacket, so callers should not keep using data afterwards.
func (d *Device) PutTempPacket(data *Packet) {
	pool := &d.packetsPool
	data.clear()
	pool.Put(data)
}
// WritePacket writes data to the TUN interface. Before writing, it sets the
// IPv4 source address to senderIP and the destination to d.LocalIp, then
// recalculates the checksum.
//
// IPv6 packets are not supported yet: they are dropped and nil is returned.
// senderIP and d.LocalIp may be in either the 4-byte or the 16-byte
// (IPv4-mapped) net.IP form. An error is returned if either is not IPv4.
//
// TODO: batch write.
func (d *Device) WritePacket(data *Packet, senderIP net.IP) error {
	if data.IsIPv6 {
		// TODO: implement. We need to set Device.localIP ipv6 instead of ipv4
		return nil
	}

	// net.IP values are frequently 16 bytes long (e.g. from net.ParseIP).
	// Copying that form into an IPv4 header field would copy the leading
	// zero bytes of the IPv4-mapped prefix instead of the address, so
	// normalise to the 4-byte form first.
	// NOTE(review): assumes data.Src/data.Dst are the 4-byte IPv4 header
	// address fields — confirm against Packet.Parse.
	src := senderIP.To4()
	if src == nil {
		return fmt.Errorf("write packet to tun: sender %v is not an IPv4 address", senderIP)
	}
	dst := d.LocalIp.To4()
	if dst == nil {
		return fmt.Errorf("write packet to tun: local address %v is not an IPv4 address", d.LocalIp)
	}
	copy(data.Src, src)
	copy(data.Dst, dst)
	data.RecalculateChecksum()

	bufs := [][]byte{data.Buffer[:tunPacketOffset+len(data.Packet)]}
	packetsCount, err := d.tun.Write(bufs, tunPacketOffset)
	if err != nil {
		return fmt.Errorf("write packet to tun: %w", err)
	}
	if packetsCount < len(bufs) {
		fmt.Printf("wrote %d packets, len(bufs): %d\n", packetsCount, len(bufs))
	}
	return nil
}
// OutboundChan exposes, receive-only, the channel on which packets read from
// the TUN interface are delivered. The channel is closed once the packet
// reader goroutine returns.
func (d *Device) OutboundChan() <-chan *Packet {
	var recvOnly <-chan *Packet = d.outboundChan
	return recvOnly
}
// Close closes the underlying TUN interface and reports any error from doing
// so. The packet reader goroutine returns when its reads fail with
// os.ErrClosed.
func (d *Device) Close() error {
	if err := d.tun.Close(); err != nil {
		return err
	}
	return nil
}
// tunEventsReader consumes events from the TUN interface until its Events
// channel is closed.
//
// On tun.EventMTUUpdate it does the following:
//   - re-reads the MTU and ignores failures and negative values;
//   - caps the MTU at maxContentSize;
//   - stores it in d.mtu with atomic.SwapInt64, logging when the stored value
//     changes.
//
// tun.EventDown is currently only logged.
//
// Each bit of an event is handled independently. A failed MTU update no
// longer prevents the EventDown bit of the same event from being processed.
func (d *Device) tunEventsReader() {
	for event := range d.tun.Events() {
		if event&tun.EventMTUUpdate != 0 {
			mtu, err := d.tun.MTU()
			switch {
			case err != nil:
				fmt.Printf("Failed to load updated MTU of device: %v\n", err)
			case mtu < 0:
				fmt.Printf("MTU not updated to negative value: %v\n", mtu)
			default:
				var tooLarge string
				if mtu > maxContentSize {
					tooLarge = fmt.Sprintf(" (too large, capped at %v)", maxContentSize)
					mtu = maxContentSize
				}
				if old := atomic.SwapInt64(&d.mtu, int64(mtu)); old != int64(mtu) {
					fmt.Printf("MTU updated: %v%s\n", mtu, tooLarge)
				}
			}
		}
		// TODO: check for event&tun.EventUp
		if event&tun.EventDown != 0 {
			fmt.Println("Interface down requested")
			// TODO
		}
	}
}
// tunPacketsReader reads packets from the TUN interface in batches of
// d.tun.BatchSize() and sends every packet that Parse accepts on
// d.outboundChan.
//
// Once a *Packet is sent, this goroutine drops its reference and takes a new
// one from the pool. The receiver is presumably responsible for returning it
// via PutTempPacket — verify against callers. Packets that are empty, larger
// than maxContentSize, or rejected by Parse are kept and reused for the next
// read.
//
// It returns, closing d.outboundChan, when a read fails with os.ErrClosed or
// with any other error except tun.ErrTooManySegments.
func (d *Device) tunPacketsReader() {
	defer close(d.outboundChan)

	batchSize := d.tun.BatchSize()
	if batchSize < 1 {
		// Defensive: a zero batch size would leave no buffers to read into.
		batchSize = 1
	}
	packets := make([]*Packet, batchSize)
	bufs := make([][]byte, batchSize)
	sizes := make([]int, batchSize)

	for {
		// Refill empty slots from the pool; reset slots whose packet was not
		// handed off during the previous round.
		for i := range packets {
			if packets[i] == nil {
				packets[i] = d.GetTempPacket()
			} else {
				packets[i].clear()
			}
			bufs[i] = packets[i].Buffer[:]
			sizes[i] = 0
		}

		// Packets returned alongside an error are still processed below
		// before the error is inspected.
		packetsCount, err := d.tun.Read(bufs, sizes, tunPacketOffset)
		for i := 0; i < packetsCount; i++ {
			size := sizes[i]
			if size == 0 || size > maxContentSize {
				continue
			}
			data := packets[i]
			data.Packet = data.Buffer[tunPacketOffset : tunPacketOffset+size]
			if !data.Parse() {
				continue
			}
			d.outboundChan <- data
			// Ownership moved to the receiver; fetch a fresh packet next round.
			packets[i] = nil
		}

		switch {
		case err == nil:
		case errors.Is(err, tun.ErrTooManySegments):
			// Not fatal; keep reading.
		case errors.Is(err, os.ErrClosed):
			return
		default:
			fmt.Printf("Failed to read packets from TUN device: %v\n", err)
			return
		}
	}
}