This commit is contained in:
arraykeys
2021-05-11 13:40:18 +08:00
commit 94fffdbd84
159 changed files with 223308 additions and 0 deletions

363
utils/functions.go Executable file
View File

@@ -0,0 +1,363 @@
package utils
import (
"bufio"
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"os/exec"
"sync"
"runtime/debug"
"strconv"
"strings"
"time"
)
func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPreSec float64) {
var one = &sync.Once{}
go func() {
defer func() {
if e := recover(); e != nil {
log.Printf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
var err error
var isSrcErr bool
if bytesPreSec > 0 {
newreader := NewReader(src)
newreader.SetRateLimit(bytesPreSec)
_, isSrcErr, err = ioCopy(dst, newreader, func(c int) {
cfn(c, false)
})
} else {
_, isSrcErr, err = ioCopy(dst, src, func(c int) {
cfn(c, false)
})
}
if err != nil {
one.Do(func() {
fn(isSrcErr, err)
})
}
}()
go func() {
defer func() {
if e := recover(); e != nil {
log.Printf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
var err error
var isSrcErr bool
if bytesPreSec > 0 {
newReader := NewReader(dst)
newReader.SetRateLimit(bytesPreSec)
_, isSrcErr, err = ioCopy(src, newReader, func(c int) {
cfn(c, true)
})
} else {
_, isSrcErr, err = ioCopy(src, dst, func(c int) {
cfn(c, true)
})
}
if err != nil {
one.Do(func() {
fn(isSrcErr, err)
})
}
}()
}
func ioCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, isSrcErr bool, err error) {
buf := make([]byte, 32*1024)
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
if len(fn) == 1 {
fn[0](nw)
}
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
err = er
isSrcErr = true
break
}
}
return written, isSrcErr, err
}
func TlsConnectHost(host string, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) {
h := strings.Split(host, ":")
port, _ := strconv.Atoi(h[1])
return TlsConnect(h[0], port, timeout, certBytes, keyBytes)
}
func TlsConnect(host string, port, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) {
conf, err := getRequestTlsConfig(certBytes, keyBytes)
if err != nil {
return
}
_conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), time.Duration(timeout)*time.Millisecond)
if err != nil {
return
}
return *tls.Client(_conn, conf), err
}
func getRequestTlsConfig(certBytes, keyBytes []byte) (conf *tls.Config, err error) {
var cert tls.Certificate
cert, err = tls.X509KeyPair(certBytes, keyBytes)
if err != nil {
return
}
serverCertPool := x509.NewCertPool()
ok := serverCertPool.AppendCertsFromPEM(certBytes)
if !ok {
err = errors.New("failed to parse root certificate")
}
conf = &tls.Config{
RootCAs: serverCertPool,
Certificates: []tls.Certificate{cert},
ServerName: "proxy",
InsecureSkipVerify: false,
}
return
}
func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
conn, err = net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond)
return
}
func ListenTls(ip string, port int, certBytes, keyBytes []byte) (ln *net.Listener, err error) {
var cert tls.Certificate
cert, err = tls.X509KeyPair(certBytes, keyBytes)
if err != nil {
return
}
clientCertPool := x509.NewCertPool()
ok := clientCertPool.AppendCertsFromPEM(certBytes)
if !ok {
err = errors.New("failed to parse root certificate")
}
config := &tls.Config{
ClientCAs: clientCertPool,
ServerName: "proxy",
Certificates: []tls.Certificate{cert},
ClientAuth: tls.RequireAndVerifyClientCert,
}
_ln, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", ip, port), config)
if err == nil {
ln = &_ln
}
return
}
func PathExists(_path string) bool {
_, err := os.Stat(_path)
if err != nil && os.IsNotExist(err) {
return false
}
return true
}
func HTTPGet(URL string, timeout int) (err error) {
tr := &http.Transport{}
var resp *http.Response
var client *http.Client
defer func() {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
tr.CloseIdleConnections()
}()
client = &http.Client{Timeout: time.Millisecond * time.Duration(timeout), Transport: tr}
resp, err = client.Get(URL)
if err != nil {
return
}
return
}
func CloseConn(conn *net.Conn) {
if conn != nil && *conn != nil {
(*conn).SetDeadline(time.Now().Add(time.Millisecond))
(*conn).Close()
}
}
func Keygen() (err error) {
cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048")
out, err := cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
cmd = exec.Command("sh", "-c", `openssl req -new -key proxy.key -x509 -days 3650 -out proxy.crt -subj /C=CN/ST=BJ/O="Localhost Ltd"/CN=proxy`)
out, err = cmd.CombinedOutput()
if err != nil {
log.Printf("err:%s", err)
return
}
fmt.Println(string(out))
return
}
func GetAllInterfaceAddr() ([]net.IP, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
addresses := []net.IP{}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
// if iface.Flags&net.FlagLoopback != 0 {
// continue // loopback interface
// }
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
// if ip == nil || ip.IsLoopback() {
// continue
// }
ip = ip.To4()
if ip == nil {
continue // not an ipv4 address
}
addresses = append(addresses, ip)
}
}
if len(addresses) == 0 {
return nil, fmt.Errorf("no address Found, net.InterfaceAddrs: %v", addresses)
}
//only need first
return addresses, nil
}
func UDPPacket(srcAddr string, packet []byte) []byte {
addrBytes := []byte(srcAddr)
addrLength := uint16(len(addrBytes))
bodyLength := uint16(len(packet))
pkg := new(bytes.Buffer)
binary.Write(pkg, binary.LittleEndian, addrLength)
binary.Write(pkg, binary.LittleEndian, addrBytes)
binary.Write(pkg, binary.LittleEndian, bodyLength)
binary.Write(pkg, binary.LittleEndian, packet)
return pkg.Bytes()
}
func ReadUDPPacket(conn *net.Conn) (srcAddr string, packet []byte, err error) {
reader := bufio.NewReader(*conn)
var addrLength uint16
var bodyLength uint16
err = binary.Read(reader, binary.LittleEndian, &addrLength)
if err != nil {
return
}
_srcAddr := make([]byte, addrLength)
n, err := reader.Read(_srcAddr)
if err != nil {
return
}
if n != int(addrLength) {
return
}
srcAddr = string(_srcAddr)
err = binary.Read(reader, binary.LittleEndian, &bodyLength)
if err != nil {
return
}
packet = make([]byte, bodyLength)
n, err = reader.Read(packet)
if err != nil {
return
}
if n != int(bodyLength) {
return
}
return
}
// type sockaddr struct {
// family uint16
// data [14]byte
// }
// const SO_ORIGINAL_DST = 80
// realServerAddress returns an intercepted connection's original destination.
// func realServerAddress(conn *net.Conn) (string, error) {
// tcpConn, ok := (*conn).(*net.TCPConn)
// if !ok {
// return "", errors.New("not a TCPConn")
// }
// file, err := tcpConn.File()
// if err != nil {
// return "", err
// }
// // To avoid potential problems from making the socket non-blocking.
// tcpConn.Close()
// *conn, err = net.FileConn(file)
// if err != nil {
// return "", err
// }
// defer file.Close()
// fd := file.Fd()
// var addr sockaddr
// size := uint32(unsafe.Sizeof(addr))
// err = getsockopt(int(fd), syscall.SOL_IP, SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&addr)), &size)
// if err != nil {
// return "", err
// }
// var ip net.IP
// switch addr.family {
// case syscall.AF_INET:
// ip = addr.data[2:6]
// default:
// return "", errors.New("unrecognized address family")
// }
// port := int(addr.data[0])<<8 + int(addr.data[1])
// return net.JoinHostPort(ip.String(), strconv.Itoa(port)), nil
// }
// func getsockopt(s int, level int, name int, val uintptr, vallen *uint32) (err error) {
// _, _, e1 := syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(s), uintptr(level), uintptr(name), uintptr(val), uintptr(unsafe.Pointer(vallen)), 0)
// if e1 != 0 {
// err = e1
// }
// return
// }

97
utils/io-limiter.go Normal file
View File

@@ -0,0 +1,97 @@
package utils
import (
"context"
"io"
"time"
"golang.org/x/time/rate"
)
const burstLimit = 1000 * 1000 * 1000
type Reader struct {
r io.Reader
limiter *rate.Limiter
ctx context.Context
}
type Writer struct {
w io.Writer
limiter *rate.Limiter
ctx context.Context
}
// NewReader returns a reader that implements io.Reader with rate limiting.
func NewReader(r io.Reader) *Reader {
return &Reader{
r: r,
ctx: context.Background(),
}
}
// NewReaderWithContext returns a reader that implements io.Reader with rate limiting.
func NewReaderWithContext(r io.Reader, ctx context.Context) *Reader {
return &Reader{
r: r,
ctx: ctx,
}
}
// NewWriter returns a writer that implements io.Writer with rate limiting.
func NewWriter(w io.Writer) *Writer {
return &Writer{
w: w,
ctx: context.Background(),
}
}
// NewWriterWithContext returns a writer that implements io.Writer with rate limiting.
func NewWriterWithContext(w io.Writer, ctx context.Context) *Writer {
return &Writer{
w: w,
ctx: ctx,
}
}
// SetRateLimit sets rate limit (bytes/sec) to the reader.
func (s *Reader) SetRateLimit(bytesPerSec float64) {
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
}
// Read reads bytes into p.
func (s *Reader) Read(p []byte) (int, error) {
if s.limiter == nil {
return s.r.Read(p)
}
n, err := s.r.Read(p)
if err != nil {
return n, err
}
if err := s.limiter.WaitN(s.ctx, n); err != nil {
return n, err
}
return n, nil
}
// SetRateLimit sets rate limit (bytes/sec) to the writer.
func (s *Writer) SetRateLimit(bytesPerSec float64) {
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
}
// Write writes bytes from p.
func (s *Writer) Write(p []byte) (int, error) {
if s.limiter == nil {
return s.w.Write(p)
}
n, err := s.w.Write(p)
if err != nil {
return n, err
}
if err := s.limiter.WaitN(s.ctx, n); err != nil {
return n, err
}
return n, err
}

315
utils/map.go Normal file
View File

@@ -0,0 +1,315 @@
package utils
import (
"encoding/json"
"sync"
)
var SHARD_COUNT = 32
// A "thread" safe map of type string:Anything.
// To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards.
type ConcurrentMap []*ConcurrentMapShared
// A "thread" safe string to anything map.
type ConcurrentMapShared struct {
items map[string]interface{}
sync.RWMutex // Read Write mutex, guards access to internal map.
}
// Creates a new concurrent map.
func NewConcurrentMap() ConcurrentMap {
m := make(ConcurrentMap, SHARD_COUNT)
for i := 0; i < SHARD_COUNT; i++ {
m[i] = &ConcurrentMapShared{items: make(map[string]interface{})}
}
return m
}
// GetShard returns shard under given key
func (m ConcurrentMap) GetShard(key string) *ConcurrentMapShared {
return m[uint(fnv32(key))%uint(SHARD_COUNT)]
}
func (m ConcurrentMap) MSet(data map[string]interface{}) {
for key, value := range data {
shard := m.GetShard(key)
shard.Lock()
shard.items[key] = value
shard.Unlock()
}
}
// Sets the given value under the specified key.
func (m ConcurrentMap) Set(key string, value interface{}) {
// Get map shard.
shard := m.GetShard(key)
shard.Lock()
shard.items[key] = value
shard.Unlock()
}
// Callback to return new element to be inserted into the map
// It is called while lock is held, therefore it MUST NOT
// try to access other keys in same map, as it can lead to deadlock since
// Go sync.RWLock is not reentrant
type UpsertCb func(exist bool, valueInMap interface{}, newValue interface{}) interface{}
// Insert or Update - updates existing element or inserts a new one using UpsertCb
func (m ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res interface{}) {
shard := m.GetShard(key)
shard.Lock()
v, ok := shard.items[key]
res = cb(ok, v, value)
shard.items[key] = res
shard.Unlock()
return res
}
// Sets the given value under the specified key if no value was associated with it.
func (m ConcurrentMap) SetIfAbsent(key string, value interface{}) bool {
// Get map shard.
shard := m.GetShard(key)
shard.Lock()
_, ok := shard.items[key]
if !ok {
shard.items[key] = value
}
shard.Unlock()
return !ok
}
// Get retrieves an element from map under given key.
func (m ConcurrentMap) Get(key string) (interface{}, bool) {
// Get shard
shard := m.GetShard(key)
shard.RLock()
// Get item from shard.
val, ok := shard.items[key]
shard.RUnlock()
return val, ok
}
// Count returns the number of elements within the map.
func (m ConcurrentMap) Count() int {
count := 0
for i := 0; i < SHARD_COUNT; i++ {
shard := m[i]
shard.RLock()
count += len(shard.items)
shard.RUnlock()
}
return count
}
// Looks up an item under specified key
func (m ConcurrentMap) Has(key string) bool {
// Get shard
shard := m.GetShard(key)
shard.RLock()
// See if element is within shard.
_, ok := shard.items[key]
shard.RUnlock()
return ok
}
// Remove removes an element from the map.
func (m ConcurrentMap) Remove(key string) {
// Try to get shard.
shard := m.GetShard(key)
shard.Lock()
delete(shard.items, key)
shard.Unlock()
}
// Pop removes an element from the map and returns it
func (m ConcurrentMap) Pop(key string) (v interface{}, exists bool) {
// Try to get shard.
shard := m.GetShard(key)
shard.Lock()
v, exists = shard.items[key]
delete(shard.items, key)
shard.Unlock()
return v, exists
}
// IsEmpty checks if map is empty.
func (m ConcurrentMap) IsEmpty() bool {
return m.Count() == 0
}
// Used by the Iter & IterBuffered functions to wrap two variables together over a channel,
type Tuple struct {
Key string
Val interface{}
}
// Iter returns an iterator which could be used in a for range loop.
//
// Deprecated: using IterBuffered() will get a better performence
func (m ConcurrentMap) Iter() <-chan Tuple {
chans := snapshot(m)
ch := make(chan Tuple)
go fanIn(chans, ch)
return ch
}
// IterBuffered returns a buffered iterator which could be used in a for range loop.
func (m ConcurrentMap) IterBuffered() <-chan Tuple {
chans := snapshot(m)
total := 0
for _, c := range chans {
total += cap(c)
}
ch := make(chan Tuple, total)
go fanIn(chans, ch)
return ch
}
// Returns a array of channels that contains elements in each shard,
// which likely takes a snapshot of `m`.
// It returns once the size of each buffered channel is determined,
// before all the channels are populated using goroutines.
func snapshot(m ConcurrentMap) (chans []chan Tuple) {
chans = make([]chan Tuple, SHARD_COUNT)
wg := sync.WaitGroup{}
wg.Add(SHARD_COUNT)
// Foreach shard.
for index, shard := range m {
go func(index int, shard *ConcurrentMapShared) {
// Foreach key, value pair.
shard.RLock()
chans[index] = make(chan Tuple, len(shard.items))
wg.Done()
for key, val := range shard.items {
chans[index] <- Tuple{key, val}
}
shard.RUnlock()
close(chans[index])
}(index, shard)
}
wg.Wait()
return chans
}
// fanIn reads elements from channels `chans` into channel `out`
func fanIn(chans []chan Tuple, out chan Tuple) {
wg := sync.WaitGroup{}
wg.Add(len(chans))
for _, ch := range chans {
go func(ch chan Tuple) {
for t := range ch {
out <- t
}
wg.Done()
}(ch)
}
wg.Wait()
close(out)
}
// Items returns all items as map[string]interface{}
func (m ConcurrentMap) Items() map[string]interface{} {
tmp := make(map[string]interface{})
// Insert items to temporary map.
for item := range m.IterBuffered() {
tmp[item.Key] = item.Val
}
return tmp
}
// Iterator callback,called for every key,value found in
// maps. RLock is held for all calls for a given shard
// therefore callback sess consistent view of a shard,
// but not across the shards
type IterCb func(key string, v interface{})
// Callback based iterator, cheapest way to read
// all elements in a map.
func (m ConcurrentMap) IterCb(fn IterCb) {
for idx := range m {
shard := (m)[idx]
shard.RLock()
for key, value := range shard.items {
fn(key, value)
}
shard.RUnlock()
}
}
// Keys returns all keys as []string
func (m ConcurrentMap) Keys() []string {
count := m.Count()
ch := make(chan string, count)
go func() {
// Foreach shard.
wg := sync.WaitGroup{}
wg.Add(SHARD_COUNT)
for _, shard := range m {
go func(shard *ConcurrentMapShared) {
// Foreach key, value pair.
shard.RLock()
for key := range shard.items {
ch <- key
}
shard.RUnlock()
wg.Done()
}(shard)
}
wg.Wait()
close(ch)
}()
// Generate keys
keys := make([]string, 0, count)
for k := range ch {
keys = append(keys, k)
}
return keys
}
//Reviles ConcurrentMap "private" variables to json marshal.
func (m ConcurrentMap) MarshalJSON() ([]byte, error) {
// Create a temporary map, which will hold all item spread across shards.
tmp := make(map[string]interface{})
// Insert items to temporary map.
for item := range m.IterBuffered() {
tmp[item.Key] = item.Val
}
return json.Marshal(tmp)
}
func fnv32(key string) uint32 {
hash := uint32(2166136261)
const prime32 = uint32(16777619)
for i := 0; i < len(key); i++ {
hash *= prime32
hash ^= uint32(key[i])
}
return hash
}
// Concurrent map uses Interface{} as its value, therefor JSON Unmarshal
// will probably won't know which to type to unmarshal into, in such case
// we'll end up with a value of type map[string]interface{}, In most cases this isn't
// out value type, this is why we've decided to remove this functionality.
// func (m *ConcurrentMap) UnmarshalJSON(b []byte) (err error) {
// // Reverse process of Marshal.
// tmp := make(map[string]interface{})
// // Unmarshal into a single map.
// if err := json.Unmarshal(b, &tmp); err != nil {
// return nil
// }
// // foreach key,value pair in temporary map insert into our concurrent map.
// for key, val := range tmp {
// m.Set(key, val)
// }
// return nil
// }

145
utils/pool.go Executable file
View File

@@ -0,0 +1,145 @@
package utils
import (
"log"
"sync"
"time"
)
//ConnPool to use
type ConnPool interface {
Get() (conn interface{}, err error)
Put(conn interface{})
ReleaseAll()
Len() (length int)
}
type poolConfig struct {
Factory func() (interface{}, error)
IsActive func(interface{}) bool
Release func(interface{})
InitialCap int
MaxCap int
}
func NewConnPool(poolConfig poolConfig) (pool ConnPool, err error) {
p := netPool{
config: poolConfig,
conns: make(chan interface{}, poolConfig.MaxCap),
lock: &sync.Mutex{},
}
//log.Printf("pool MaxCap:%d", poolConfig.MaxCap)
if poolConfig.MaxCap > 0 {
err = p.initAutoFill(false)
if err == nil {
p.initAutoFill(true)
}
}
return &p, nil
}
type netPool struct {
conns chan interface{}
lock *sync.Mutex
config poolConfig
}
func (p *netPool) initAutoFill(async bool) (err error) {
var worker = func() (err error) {
for {
//log.Printf("pool fill: %v , len: %d", p.Len() <= p.config.InitialCap/2, p.Len())
if p.Len() <= p.config.InitialCap/2 {
p.lock.Lock()
errN := 0
for i := 0; i < p.config.InitialCap; i++ {
c, err := p.config.Factory()
if err != nil {
errN++
if async {
continue
} else {
p.lock.Unlock()
return err
}
}
select {
case p.conns <- c:
default:
p.config.Release(c)
break
}
if p.Len() >= p.config.InitialCap {
break
}
}
if errN > 0 {
log.Printf("fill conn pool fail , ERRN:%d", errN)
}
p.lock.Unlock()
}
if !async {
return
}
time.Sleep(time.Second * 2)
}
}
if async {
go worker()
} else {
err = worker()
}
return
}
func (p *netPool) Get() (conn interface{}, err error) {
// defer func() {
// log.Printf("pool len : %d", p.Len())
// }()
p.lock.Lock()
defer p.lock.Unlock()
// for {
select {
case conn = <-p.conns:
if p.config.IsActive(conn) {
return
}
p.config.Release(conn)
default:
conn, err = p.config.Factory()
if err != nil {
return nil, err
}
return conn, nil
}
// }
return
}
func (p *netPool) Put(conn interface{}) {
if conn == nil {
return
}
p.lock.Lock()
defer p.lock.Unlock()
if !p.config.IsActive(conn) {
p.config.Release(conn)
}
select {
case p.conns <- conn:
default:
p.config.Release(conn)
}
}
func (p *netPool) ReleaseAll() {
p.lock.Lock()
defer p.lock.Unlock()
close(p.conns)
for c := range p.conns {
p.config.Release(c)
}
p.conns = make(chan interface{}, p.config.InitialCap)
}
func (p *netPool) Len() (length int) {
return len(p.conns)
}

126
utils/serve-channel.go Normal file
View File

@@ -0,0 +1,126 @@
package utils
import (
"fmt"
"log"
"net"
"runtime/debug"
)
type ServerChannel struct {
ip string
port int
Listener *net.Listener
UDPListener *net.UDPConn
errAcceptHandler func(err error)
}
func NewServerChannel(ip string, port int) ServerChannel {
return ServerChannel{
ip: ip,
port: port,
errAcceptHandler: func(err error) {
fmt.Printf("accept error , ERR:%s", err)
},
}
}
func (sc *ServerChannel) SetErrAcceptHandler(fn func(err error)) {
sc.errAcceptHandler = fn
}
func (sc *ServerChannel) ListenTls(certBytes, keyBytes []byte, fn func(conn net.Conn)) (err error) {
sc.Listener, err = ListenTls(sc.ip, sc.port, certBytes, keyBytes)
if err == nil {
go func() {
defer func() {
if e := recover(); e != nil {
log.Printf("ListenTls crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
for {
var conn net.Conn
conn, err = (*sc.Listener).Accept()
if err == nil {
go func() {
defer func() {
if e := recover(); e != nil {
log.Printf("connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
fn(conn)
}()
} else {
sc.errAcceptHandler(err)
(*sc.Listener).Close()
break
}
}
}()
}
return
}
func (sc *ServerChannel) ListenTCP(fn func(conn net.Conn)) (err error) {
var l net.Listener
l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sc.ip, sc.port))
if err == nil {
sc.Listener = &l
go func() {
defer func() {
if e := recover(); e != nil {
log.Printf("ListenTCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
for {
var conn net.Conn
conn, err = (*sc.Listener).Accept()
if err == nil {
go func() {
defer func() {
if e := recover(); e != nil {
log.Printf("connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
fn(conn)
}()
} else {
sc.errAcceptHandler(err)
break
}
}
}()
}
return
}
func (sc *ServerChannel) ListenUDP(fn func(packet []byte, localAddr, srcAddr *net.UDPAddr)) (err error) {
addr := &net.UDPAddr{IP: net.ParseIP(sc.ip), Port: sc.port}
l, err := net.ListenUDP("udp", addr)
if err == nil {
sc.UDPListener = l
go func() {
defer func() {
if e := recover(); e != nil {
log.Printf("ListenUDP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
for {
var buf = make([]byte, 2048)
n, srcAddr, err := (*sc.UDPListener).ReadFromUDP(buf)
if err == nil {
packet := buf[0:n]
go func() {
defer func() {
if e := recover(); e != nil {
log.Printf("udp data handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
fn(packet, addr, srcAddr)
}()
} else {
sc.errAcceptHandler(err)
break
}
}
}()
}
return
}

462
utils/structs.go Normal file
View File

@@ -0,0 +1,462 @@
package utils
import (
"bytes"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/url"
"strings"
"time"
)
type Checker struct {
data ConcurrentMap
blockedMap ConcurrentMap
directMap ConcurrentMap
interval int64
timeout int
}
type CheckerItem struct {
IsHTTPS bool
Method string
URL string
Domain string
Host string
Data []byte
SuccessCount uint
FailCount uint
}
//NewChecker args:
//timeout : tcp timeout milliseconds ,connect to host
//interval: recheck domain interval seconds
func NewChecker(timeout int, interval int64, blockedFile, directFile string) Checker {
ch := Checker{
data: NewConcurrentMap(),
interval: interval,
timeout: timeout,
}
ch.blockedMap = ch.loadMap(blockedFile)
ch.directMap = ch.loadMap(directFile)
if !ch.blockedMap.IsEmpty() {
log.Printf("blocked file loaded , domains : %d", ch.blockedMap.Count())
}
if !ch.directMap.IsEmpty() {
log.Printf("direct file loaded , domains : %d", ch.directMap.Count())
}
ch.start()
return ch
}
func (c *Checker) loadMap(f string) (dataMap ConcurrentMap) {
dataMap = NewConcurrentMap()
if PathExists(f) {
_contents, err := ioutil.ReadFile(f)
if err != nil {
log.Printf("load file err:%s", err)
return
}
for _, line := range strings.Split(string(_contents), "\n") {
line = strings.Trim(line, "\r \t")
if line != "" {
dataMap.Set(line, true)
}
}
}
return
}
func (c *Checker) start() {
go func() {
for {
for _, v := range c.data.Items() {
go func(item CheckerItem) {
if c.isNeedCheck(item) {
//log.Printf("check %s", item.Domain)
var conn net.Conn
var err error
if item.IsHTTPS {
conn, err = ConnectHost(item.Host, c.timeout)
if err == nil {
conn.SetDeadline(time.Now().Add(time.Millisecond))
conn.Close()
}
} else {
err = HTTPGet(item.URL, c.timeout)
}
if err != nil {
item.FailCount = item.FailCount + 1
} else {
item.SuccessCount = item.SuccessCount + 1
}
c.data.Set(item.Host, item)
}
}(v.(CheckerItem))
}
time.Sleep(time.Second * time.Duration(c.interval))
}
}()
}
func (c *Checker) isNeedCheck(item CheckerItem) bool {
var minCount uint = 5
if (item.SuccessCount >= minCount && item.SuccessCount > item.FailCount) ||
(item.FailCount >= minCount && item.SuccessCount > item.FailCount) ||
c.domainIsInMap(item.Host, false) ||
c.domainIsInMap(item.Host, true) {
return false
}
return true
}
func (c *Checker) IsBlocked(address string) (blocked bool, failN, successN uint) {
if c.domainIsInMap(address, true) {
//log.Printf("%s in blocked ? true", address)
return true, 0, 0
}
if c.domainIsInMap(address, false) {
//log.Printf("%s in direct ? true", address)
return false, 0, 0
}
_item, ok := c.data.Get(address)
if !ok {
//log.Printf("%s not in map, blocked true", address)
return true, 0, 0
}
item := _item.(CheckerItem)
return item.FailCount >= item.SuccessCount, item.FailCount, item.SuccessCount
}
func (c *Checker) domainIsInMap(address string, blockedMap bool) bool {
u, err := url.Parse("http://" + address)
if err != nil {
log.Printf("blocked check , url parse err:%s", err)
return true
}
domainSlice := strings.Split(u.Hostname(), ".")
if len(domainSlice) > 1 {
subSlice := domainSlice[:len(domainSlice)-1]
topDomain := strings.Join(domainSlice[len(domainSlice)-1:], ".")
checkDomain := topDomain
for i := len(subSlice) - 1; i >= 0; i-- {
checkDomain = subSlice[i] + "." + checkDomain
if !blockedMap && c.directMap.Has(checkDomain) {
return true
}
if blockedMap && c.blockedMap.Has(checkDomain) {
return true
}
}
}
return false
}
func (c *Checker) Add(address string, isHTTPS bool, method, URL string, data []byte) {
if c.domainIsInMap(address, false) || c.domainIsInMap(address, true) {
return
}
if !isHTTPS && strings.ToLower(method) != "get" {
return
}
var item CheckerItem
u := strings.Split(address, ":")
item = CheckerItem{
URL: URL,
Domain: u[0],
Host: address,
Data: data,
IsHTTPS: isHTTPS,
Method: method,
}
c.data.SetIfAbsent(item.Host, item)
}
type BasicAuth struct {
data ConcurrentMap
}
func NewBasicAuth() BasicAuth {
return BasicAuth{
data: NewConcurrentMap(),
}
}
func (ba *BasicAuth) AddFromFile(file string) (n int, err error) {
_content, err := ioutil.ReadFile(file)
if err != nil {
return
}
userpassArr := strings.Split(strings.Replace(string(_content), "\r", "", -1), "\n")
for _, userpass := range userpassArr {
if strings.HasPrefix("#", userpass) {
continue
}
u := strings.Split(strings.Trim(userpass, " "), ":")
if len(u) == 2 {
ba.data.Set(u[0], u[1])
n++
}
}
return
}
func (ba *BasicAuth) Add(userpassArr []string) (n int) {
for _, userpass := range userpassArr {
u := strings.Split(userpass, ":")
if len(u) == 2 {
ba.data.Set(u[0], u[1])
n++
}
}
return
}
func (ba *BasicAuth) Check(userpass string) (ok bool) {
u := strings.Split(strings.Trim(userpass, " "), ":")
if len(u) == 2 {
if p, _ok := ba.data.Get(u[0]); _ok {
return p.(string) == u[1]
}
}
return
}
func (ba *BasicAuth) Total() (n int) {
n = ba.data.Count()
return
}
type HTTPRequest struct {
HeadBuf []byte
conn *net.Conn
Host string
Method string
URL string
hostOrURL string
isBasicAuth bool
basicAuth *BasicAuth
}
func NewHTTPRequest(inConn *net.Conn, bufSize int, isBasicAuth bool, basicAuth *BasicAuth) (req HTTPRequest, err error) {
buf := make([]byte, bufSize)
len := 0
req = HTTPRequest{
conn: inConn,
}
len, err = (*inConn).Read(buf[:])
if err != nil {
if err != io.EOF {
err = fmt.Errorf("http decoder read err:%s", err)
}
CloseConn(inConn)
return
}
req.HeadBuf = buf[:len]
index := bytes.IndexByte(req.HeadBuf, '\n')
if index == -1 {
err = fmt.Errorf("http decoder data line err:%s", string(req.HeadBuf)[:50])
CloseConn(inConn)
return
}
fmt.Sscanf(string(req.HeadBuf[:index]), "%s%s", &req.Method, &req.hostOrURL)
if req.Method == "" || req.hostOrURL == "" {
err = fmt.Errorf("http decoder data err:%s", string(req.HeadBuf)[:50])
CloseConn(inConn)
return
}
req.Method = strings.ToUpper(req.Method)
req.isBasicAuth = isBasicAuth
req.basicAuth = basicAuth
log.Printf("%s:%s", req.Method, req.hostOrURL)
if req.IsHTTPS() {
err = req.HTTPS()
} else {
err = req.HTTP()
}
return
}
func (req *HTTPRequest) HTTP() (err error) {
if req.isBasicAuth {
err = req.BasicAuth()
if err != nil {
return
}
}
req.URL, err = req.getHTTPURL()
if err == nil {
u, _ := url.Parse(req.URL)
req.Host = u.Host
req.addPortIfNot()
}
return
}
func (req *HTTPRequest) HTTPS() (err error) {
req.Host = req.hostOrURL
req.addPortIfNot()
//_, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n")
return
}
func (req *HTTPRequest) HTTPSReply() (err error) {
_, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n")
return
}
func (req *HTTPRequest) IsHTTPS() bool {
return req.Method == "CONNECT"
}
func (req *HTTPRequest) BasicAuth() (err error) {
//log.Printf("request :%s", string(b[:n]))
authorization, err := req.getHeader("Authorization")
if err != nil {
fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized")
CloseConn(req.conn)
return
}
//log.Printf("Authorization:%s", authorization)
basic := strings.Fields(authorization)
if len(basic) != 2 {
err = fmt.Errorf("authorization data error,ERR:%s", authorization)
CloseConn(req.conn)
return
}
user, err := base64.StdEncoding.DecodeString(basic[1])
if err != nil {
err = fmt.Errorf("authorization data parse error,ERR:%s", err)
CloseConn(req.conn)
return
}
authOk := (*req.basicAuth).Check(string(user))
//log.Printf("auth %s,%v", string(user), authOk)
if !authOk {
fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized")
CloseConn(req.conn)
err = fmt.Errorf("basic auth fail")
return
}
return
}
func (req *HTTPRequest) getHTTPURL() (URL string, err error) {
if !strings.HasPrefix(req.hostOrURL, "/") {
return req.hostOrURL, nil
}
_host, err := req.getHeader("host")
if err != nil {
return
}
URL = fmt.Sprintf("http://%s%s", _host, req.hostOrURL)
return
}
func (req *HTTPRequest) getHeader(key string) (val string, err error) {
key = strings.ToUpper(key)
lines := strings.Split(string(req.HeadBuf), "\r\n")
for _, line := range lines {
line := strings.SplitN(strings.Trim(line, "\r\n "), ":", 2)
if len(line) == 2 {
k := strings.ToUpper(strings.Trim(line[0], " "))
v := strings.Trim(line[1], " ")
if key == k {
val = v
return
}
}
}
err = fmt.Errorf("can not find HOST header")
return
}
func (req *HTTPRequest) addPortIfNot() (newHost string) {
//newHost = req.Host
port := "80"
if req.IsHTTPS() {
port = "443"
}
if (!strings.HasPrefix(req.Host, "[") && strings.Index(req.Host, ":") == -1) || (strings.HasPrefix(req.Host, "[") && strings.HasSuffix(req.Host, "]")) {
//newHost = req.Host + ":" + port
//req.headBuf = []byte(strings.Replace(string(req.headBuf), req.Host, newHost, 1))
req.Host = req.Host + ":" + port
}
return
}
type OutPool struct {
Pool ConnPool
dur int
isTLS bool
certBytes []byte
keyBytes []byte
address string
timeout int
}
func NewOutPool(dur int, isTLS bool, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) (op OutPool) {
op = OutPool{
dur: dur,
isTLS: isTLS,
certBytes: certBytes,
keyBytes: keyBytes,
address: address,
timeout: timeout,
}
var err error
op.Pool, err = NewConnPool(poolConfig{
IsActive: func(conn interface{}) bool { return true },
Release: func(conn interface{}) {
if conn != nil {
conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond))
conn.(net.Conn).Close()
// log.Println("conn released")
}
},
InitialCap: InitialCap,
MaxCap: MaxCap,
Factory: func() (conn interface{}, err error) {
conn, err = op.getConn()
return
},
})
if err != nil {
log.Fatalf("init conn pool fail ,%s", err)
} else {
if InitialCap > 0 {
log.Printf("init conn pool success")
op.initPoolDeamon()
} else {
log.Printf("conn pool closed")
}
}
return
}
func (op *OutPool) getConn() (conn interface{}, err error) {
if op.isTLS {
var _conn tls.Conn
_conn, err = TlsConnectHost(op.address, op.timeout, op.certBytes, op.keyBytes)
if err == nil {
conn = net.Conn(&_conn)
}
} else {
conn, err = ConnectHost(op.address, op.timeout)
}
return
}
func (op *OutPool) initPoolDeamon() {
go func() {
if op.dur <= 0 {
return
}
log.Printf("pool deamon started")
for {
time.Sleep(time.Second * time.Duration(op.dur))
conn, err := op.getConn()
if err != nil {
log.Printf("pool deamon err %s , release pool", err)
op.Pool.ReleaseAll()
} else {
conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond))
conn.(net.Conn).Close()
}
}
}()
}