178 lines
3.7 KiB
Go
178 lines
3.7 KiB
Go
package Atom_Device
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"golang.zx2c4.com/wireguard/tun"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"sync/atomic"
|
|
)
|
|
|
|
type Device struct {
|
|
tun tun.Device
|
|
mtu int64
|
|
LocalIp net.IP
|
|
IpNet *net.IPNet
|
|
outboundChan chan *Packet
|
|
|
|
packetsPool sync.Pool
|
|
}
|
|
|
|
func NewDevice(existingTun tun.Device, ifName string, localIP net.IP, ipNet *net.IPNet) (*Device, error) {
|
|
var tunDevice tun.Device
|
|
var err error
|
|
|
|
if existingTun == nil {
|
|
tunDevice, err = NewTun(ifName, InterfaceMTU, localIP, ipNet.Mask)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create TUN device: %v", err)
|
|
}
|
|
} else {
|
|
tunDevice = existingTun
|
|
}
|
|
|
|
realMtu, err := tunDevice.MTU()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get TUN mtu: %v", err)
|
|
}
|
|
|
|
dev := &Device{
|
|
tun: tunDevice,
|
|
mtu: int64(realMtu),
|
|
LocalIp: localIP,
|
|
IpNet: ipNet,
|
|
outboundChan: make(chan *Packet, outboundChCap),
|
|
packetsPool: sync.Pool{
|
|
New: func() interface{} {
|
|
return new(Packet)
|
|
},
|
|
},
|
|
}
|
|
|
|
go dev.tunEventsReader()
|
|
go dev.tunPacketsReader()
|
|
|
|
return dev, nil
|
|
|
|
}
|
|
|
|
func (d *Device) GetTempPacket() *Packet {
|
|
return d.packetsPool.Get().(*Packet)
|
|
}
|
|
|
|
func (d *Device) PutTempPacket(data *Packet) {
|
|
data.clear()
|
|
d.packetsPool.Put(data)
|
|
}
|
|
|
|
// WritePacket TODO: batch write
|
|
func (d *Device) WritePacket(data *Packet, senderIP net.IP) error {
|
|
if data.IsIPv6 {
|
|
// TODO: implement. We need to set Device.localIP ipv6 instead of ipv4
|
|
return nil
|
|
} else {
|
|
copy(data.Src, senderIP)
|
|
copy(data.Dst, d.LocalIp)
|
|
}
|
|
data.RecalculateChecksum()
|
|
|
|
bufs := [][]byte{data.Buffer[:tunPacketOffset+len(data.Packet)]}
|
|
packetsCount, err := d.tun.Write(bufs, tunPacketOffset)
|
|
if err != nil {
|
|
return fmt.Errorf("write packet to tun: %v", err)
|
|
} else if packetsCount < len(bufs) {
|
|
fmt.Printf("wrote %d packets, len(bufs): %d", packetsCount, len(bufs))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *Device) OutboundChan() <-chan *Packet {
|
|
return d.outboundChan
|
|
}
|
|
|
|
func (d *Device) Close() error {
|
|
return d.tun.Close()
|
|
}
|
|
|
|
func (d *Device) tunEventsReader() {
|
|
for event := range d.tun.Events() {
|
|
if event&tun.EventMTUUpdate != 0 {
|
|
mtu, err := d.tun.MTU()
|
|
if err != nil {
|
|
fmt.Printf("Failed to load updated MTU of device: %v", err)
|
|
continue
|
|
}
|
|
if mtu < 0 {
|
|
fmt.Printf("MTU not updated to negative value: %v", mtu)
|
|
continue
|
|
}
|
|
var tooLarge string
|
|
if mtu > maxContentSize {
|
|
tooLarge = fmt.Sprintf(" (too large, capped at %v)", maxContentSize)
|
|
mtu = maxContentSize
|
|
}
|
|
old := atomic.SwapInt64(&d.mtu, int64(mtu))
|
|
if int(old) != mtu {
|
|
fmt.Printf("MTU updated: %v%s", mtu, tooLarge)
|
|
}
|
|
}
|
|
|
|
// TODO: check for event&tun.EventUp
|
|
if event&tun.EventDown != 0 {
|
|
fmt.Printf("Interface down requested")
|
|
// TODO
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *Device) tunPacketsReader() {
|
|
defer close(d.outboundChan)
|
|
|
|
batchSize := d.tun.BatchSize()
|
|
packets := make([]*Packet, batchSize)
|
|
bufs := make([][]byte, batchSize)
|
|
sizes := make([]int, batchSize)
|
|
|
|
for {
|
|
for i := range packets {
|
|
if packets[i] == nil {
|
|
packets[i] = d.GetTempPacket()
|
|
} else {
|
|
packets[i].clear()
|
|
}
|
|
bufs[i] = packets[i].Buffer[:]
|
|
sizes[i] = 0
|
|
}
|
|
|
|
packetsCount, err := d.tun.Read(bufs, sizes, tunPacketOffset)
|
|
for i := 0; i < packetsCount; i++ {
|
|
size := sizes[i]
|
|
if size == 0 || size > maxContentSize {
|
|
continue
|
|
}
|
|
|
|
data := packets[i]
|
|
data.Packet = data.Buffer[tunPacketOffset : size+tunPacketOffset]
|
|
okay := data.Parse()
|
|
if !okay {
|
|
continue
|
|
}
|
|
|
|
d.outboundChan <- data
|
|
packets[i] = nil
|
|
}
|
|
|
|
if errors.Is(err, tun.ErrTooManySegments) {
|
|
continue
|
|
} else if errors.Is(err, os.ErrClosed) {
|
|
return
|
|
} else if err != nil {
|
|
fmt.Printf("Failed to read packets from TUN device: %v", err)
|
|
return
|
|
}
|
|
}
|
|
}
|