mirror of
https://github.com/snail007/goproxy.git
synced 2026-04-02 02:38:18 +00:00
update
This commit is contained in:
363
utils/functions.go
Executable file
363
utils/functions.go
Executable file
@@ -0,0 +1,363 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPreSec float64) {
|
||||
var one = &sync.Once{}
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
log.Printf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
var err error
|
||||
var isSrcErr bool
|
||||
if bytesPreSec > 0 {
|
||||
newreader := NewReader(src)
|
||||
newreader.SetRateLimit(bytesPreSec)
|
||||
_, isSrcErr, err = ioCopy(dst, newreader, func(c int) {
|
||||
cfn(c, false)
|
||||
})
|
||||
|
||||
} else {
|
||||
_, isSrcErr, err = ioCopy(dst, src, func(c int) {
|
||||
cfn(c, false)
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
one.Do(func() {
|
||||
fn(isSrcErr, err)
|
||||
})
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
log.Printf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
var err error
|
||||
var isSrcErr bool
|
||||
if bytesPreSec > 0 {
|
||||
newReader := NewReader(dst)
|
||||
newReader.SetRateLimit(bytesPreSec)
|
||||
_, isSrcErr, err = ioCopy(src, newReader, func(c int) {
|
||||
cfn(c, true)
|
||||
})
|
||||
} else {
|
||||
_, isSrcErr, err = ioCopy(src, dst, func(c int) {
|
||||
cfn(c, true)
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
one.Do(func() {
|
||||
fn(isSrcErr, err)
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
func ioCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, isSrcErr bool, err error) {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
nr, er := src.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := dst.Write(buf[0:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
if len(fn) == 1 {
|
||||
fn[0](nw)
|
||||
}
|
||||
}
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
if nr != nw {
|
||||
err = io.ErrShortWrite
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
err = er
|
||||
isSrcErr = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return written, isSrcErr, err
|
||||
}
|
||||
func TlsConnectHost(host string, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) {
|
||||
h := strings.Split(host, ":")
|
||||
port, _ := strconv.Atoi(h[1])
|
||||
return TlsConnect(h[0], port, timeout, certBytes, keyBytes)
|
||||
}
|
||||
|
||||
func TlsConnect(host string, port, timeout int, certBytes, keyBytes []byte) (conn tls.Conn, err error) {
|
||||
conf, err := getRequestTlsConfig(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), time.Duration(timeout)*time.Millisecond)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return *tls.Client(_conn, conf), err
|
||||
}
|
||||
func getRequestTlsConfig(certBytes, keyBytes []byte) (conf *tls.Config, err error) {
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
serverCertPool := x509.NewCertPool()
|
||||
ok := serverCertPool.AppendCertsFromPEM(certBytes)
|
||||
if !ok {
|
||||
err = errors.New("failed to parse root certificate")
|
||||
}
|
||||
conf = &tls.Config{
|
||||
RootCAs: serverCertPool,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
ServerName: "proxy",
|
||||
InsecureSkipVerify: false,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
|
||||
conn, err = net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond)
|
||||
return
|
||||
}
|
||||
func ListenTls(ip string, port int, certBytes, keyBytes []byte) (ln *net.Listener, err error) {
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
clientCertPool := x509.NewCertPool()
|
||||
ok := clientCertPool.AppendCertsFromPEM(certBytes)
|
||||
if !ok {
|
||||
err = errors.New("failed to parse root certificate")
|
||||
}
|
||||
config := &tls.Config{
|
||||
ClientCAs: clientCertPool,
|
||||
ServerName: "proxy",
|
||||
Certificates: []tls.Certificate{cert},
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
}
|
||||
_ln, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", ip, port), config)
|
||||
if err == nil {
|
||||
ln = &_ln
|
||||
}
|
||||
return
|
||||
}
|
||||
func PathExists(_path string) bool {
|
||||
_, err := os.Stat(_path)
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
func HTTPGet(URL string, timeout int) (err error) {
|
||||
tr := &http.Transport{}
|
||||
var resp *http.Response
|
||||
var client *http.Client
|
||||
defer func() {
|
||||
if resp != nil && resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
tr.CloseIdleConnections()
|
||||
}()
|
||||
client = &http.Client{Timeout: time.Millisecond * time.Duration(timeout), Transport: tr}
|
||||
resp, err = client.Get(URL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func CloseConn(conn *net.Conn) {
|
||||
if conn != nil && *conn != nil {
|
||||
(*conn).SetDeadline(time.Now().Add(time.Millisecond))
|
||||
(*conn).Close()
|
||||
}
|
||||
}
|
||||
func Keygen() (err error) {
|
||||
cmd := exec.Command("sh", "-c", "openssl genrsa -out proxy.key 2048")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Printf("err:%s", err)
|
||||
return
|
||||
}
|
||||
fmt.Println(string(out))
|
||||
cmd = exec.Command("sh", "-c", `openssl req -new -key proxy.key -x509 -days 3650 -out proxy.crt -subj /C=CN/ST=BJ/O="Localhost Ltd"/CN=proxy`)
|
||||
out, err = cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Printf("err:%s", err)
|
||||
return
|
||||
}
|
||||
fmt.Println(string(out))
|
||||
return
|
||||
}
|
||||
func GetAllInterfaceAddr() ([]net.IP, error) {
|
||||
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addresses := []net.IP{}
|
||||
for _, iface := range ifaces {
|
||||
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue // interface down
|
||||
}
|
||||
// if iface.Flags&net.FlagLoopback != 0 {
|
||||
// continue // loopback interface
|
||||
// }
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
}
|
||||
// if ip == nil || ip.IsLoopback() {
|
||||
// continue
|
||||
// }
|
||||
ip = ip.To4()
|
||||
if ip == nil {
|
||||
continue // not an ipv4 address
|
||||
}
|
||||
addresses = append(addresses, ip)
|
||||
}
|
||||
}
|
||||
if len(addresses) == 0 {
|
||||
return nil, fmt.Errorf("no address Found, net.InterfaceAddrs: %v", addresses)
|
||||
}
|
||||
//only need first
|
||||
return addresses, nil
|
||||
}
|
||||
func UDPPacket(srcAddr string, packet []byte) []byte {
|
||||
addrBytes := []byte(srcAddr)
|
||||
addrLength := uint16(len(addrBytes))
|
||||
bodyLength := uint16(len(packet))
|
||||
pkg := new(bytes.Buffer)
|
||||
binary.Write(pkg, binary.LittleEndian, addrLength)
|
||||
binary.Write(pkg, binary.LittleEndian, addrBytes)
|
||||
binary.Write(pkg, binary.LittleEndian, bodyLength)
|
||||
binary.Write(pkg, binary.LittleEndian, packet)
|
||||
return pkg.Bytes()
|
||||
}
|
||||
func ReadUDPPacket(conn *net.Conn) (srcAddr string, packet []byte, err error) {
|
||||
reader := bufio.NewReader(*conn)
|
||||
var addrLength uint16
|
||||
var bodyLength uint16
|
||||
err = binary.Read(reader, binary.LittleEndian, &addrLength)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_srcAddr := make([]byte, addrLength)
|
||||
n, err := reader.Read(_srcAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n != int(addrLength) {
|
||||
return
|
||||
}
|
||||
srcAddr = string(_srcAddr)
|
||||
|
||||
err = binary.Read(reader, binary.LittleEndian, &bodyLength)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
packet = make([]byte, bodyLength)
|
||||
n, err = reader.Read(packet)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n != int(bodyLength) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// type sockaddr struct {
|
||||
// family uint16
|
||||
// data [14]byte
|
||||
// }
|
||||
|
||||
// const SO_ORIGINAL_DST = 80
|
||||
|
||||
// realServerAddress returns an intercepted connection's original destination.
|
||||
// func realServerAddress(conn *net.Conn) (string, error) {
|
||||
// tcpConn, ok := (*conn).(*net.TCPConn)
|
||||
// if !ok {
|
||||
// return "", errors.New("not a TCPConn")
|
||||
// }
|
||||
|
||||
// file, err := tcpConn.File()
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
|
||||
// // To avoid potential problems from making the socket non-blocking.
|
||||
// tcpConn.Close()
|
||||
// *conn, err = net.FileConn(file)
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
|
||||
// defer file.Close()
|
||||
// fd := file.Fd()
|
||||
|
||||
// var addr sockaddr
|
||||
// size := uint32(unsafe.Sizeof(addr))
|
||||
// err = getsockopt(int(fd), syscall.SOL_IP, SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&addr)), &size)
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
|
||||
// var ip net.IP
|
||||
// switch addr.family {
|
||||
// case syscall.AF_INET:
|
||||
// ip = addr.data[2:6]
|
||||
// default:
|
||||
// return "", errors.New("unrecognized address family")
|
||||
// }
|
||||
|
||||
// port := int(addr.data[0])<<8 + int(addr.data[1])
|
||||
|
||||
// return net.JoinHostPort(ip.String(), strconv.Itoa(port)), nil
|
||||
// }
|
||||
|
||||
// func getsockopt(s int, level int, name int, val uintptr, vallen *uint32) (err error) {
|
||||
// _, _, e1 := syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(s), uintptr(level), uintptr(name), uintptr(val), uintptr(unsafe.Pointer(vallen)), 0)
|
||||
// if e1 != 0 {
|
||||
// err = e1
|
||||
// }
|
||||
// return
|
||||
// }
|
||||
97
utils/io-limiter.go
Normal file
97
utils/io-limiter.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const burstLimit = 1000 * 1000 * 1000
|
||||
|
||||
type Reader struct {
|
||||
r io.Reader
|
||||
limiter *rate.Limiter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
type Writer struct {
|
||||
w io.Writer
|
||||
limiter *rate.Limiter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewReader returns a reader that implements io.Reader with rate limiting.
|
||||
func NewReader(r io.Reader) *Reader {
|
||||
return &Reader{
|
||||
r: r,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewReaderWithContext returns a reader that implements io.Reader with rate limiting.
|
||||
func NewReaderWithContext(r io.Reader, ctx context.Context) *Reader {
|
||||
return &Reader{
|
||||
r: r,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// NewWriter returns a writer that implements io.Writer with rate limiting.
|
||||
func NewWriter(w io.Writer) *Writer {
|
||||
return &Writer{
|
||||
w: w,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewWriterWithContext returns a writer that implements io.Writer with rate limiting.
|
||||
func NewWriterWithContext(w io.Writer, ctx context.Context) *Writer {
|
||||
return &Writer{
|
||||
w: w,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// SetRateLimit sets rate limit (bytes/sec) to the reader.
|
||||
func (s *Reader) SetRateLimit(bytesPerSec float64) {
|
||||
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
||||
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
||||
}
|
||||
|
||||
// Read reads bytes into p.
|
||||
func (s *Reader) Read(p []byte) (int, error) {
|
||||
if s.limiter == nil {
|
||||
return s.r.Read(p)
|
||||
}
|
||||
n, err := s.r.Read(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if err := s.limiter.WaitN(s.ctx, n); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// SetRateLimit sets rate limit (bytes/sec) to the writer.
|
||||
func (s *Writer) SetRateLimit(bytesPerSec float64) {
|
||||
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
|
||||
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
|
||||
}
|
||||
|
||||
// Write writes bytes from p.
|
||||
func (s *Writer) Write(p []byte) (int, error) {
|
||||
if s.limiter == nil {
|
||||
return s.w.Write(p)
|
||||
}
|
||||
n, err := s.w.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if err := s.limiter.WaitN(s.ctx, n); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
315
utils/map.go
Normal file
315
utils/map.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var SHARD_COUNT = 32
|
||||
|
||||
// A "thread" safe map of type string:Anything.
|
||||
// To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards.
|
||||
type ConcurrentMap []*ConcurrentMapShared
|
||||
|
||||
// A "thread" safe string to anything map.
|
||||
type ConcurrentMapShared struct {
|
||||
items map[string]interface{}
|
||||
sync.RWMutex // Read Write mutex, guards access to internal map.
|
||||
}
|
||||
|
||||
// Creates a new concurrent map.
|
||||
func NewConcurrentMap() ConcurrentMap {
|
||||
m := make(ConcurrentMap, SHARD_COUNT)
|
||||
for i := 0; i < SHARD_COUNT; i++ {
|
||||
m[i] = &ConcurrentMapShared{items: make(map[string]interface{})}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// GetShard returns shard under given key
|
||||
func (m ConcurrentMap) GetShard(key string) *ConcurrentMapShared {
|
||||
return m[uint(fnv32(key))%uint(SHARD_COUNT)]
|
||||
}
|
||||
|
||||
func (m ConcurrentMap) MSet(data map[string]interface{}) {
|
||||
for key, value := range data {
|
||||
shard := m.GetShard(key)
|
||||
shard.Lock()
|
||||
shard.items[key] = value
|
||||
shard.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Sets the given value under the specified key.
|
||||
func (m ConcurrentMap) Set(key string, value interface{}) {
|
||||
// Get map shard.
|
||||
shard := m.GetShard(key)
|
||||
shard.Lock()
|
||||
shard.items[key] = value
|
||||
shard.Unlock()
|
||||
}
|
||||
|
||||
// Callback to return new element to be inserted into the map
|
||||
// It is called while lock is held, therefore it MUST NOT
|
||||
// try to access other keys in same map, as it can lead to deadlock since
|
||||
// Go sync.RWLock is not reentrant
|
||||
type UpsertCb func(exist bool, valueInMap interface{}, newValue interface{}) interface{}
|
||||
|
||||
// Insert or Update - updates existing element or inserts a new one using UpsertCb
|
||||
func (m ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res interface{}) {
|
||||
shard := m.GetShard(key)
|
||||
shard.Lock()
|
||||
v, ok := shard.items[key]
|
||||
res = cb(ok, v, value)
|
||||
shard.items[key] = res
|
||||
shard.Unlock()
|
||||
return res
|
||||
}
|
||||
|
||||
// Sets the given value under the specified key if no value was associated with it.
|
||||
func (m ConcurrentMap) SetIfAbsent(key string, value interface{}) bool {
|
||||
// Get map shard.
|
||||
shard := m.GetShard(key)
|
||||
shard.Lock()
|
||||
_, ok := shard.items[key]
|
||||
if !ok {
|
||||
shard.items[key] = value
|
||||
}
|
||||
shard.Unlock()
|
||||
return !ok
|
||||
}
|
||||
|
||||
// Get retrieves an element from map under given key.
|
||||
func (m ConcurrentMap) Get(key string) (interface{}, bool) {
|
||||
// Get shard
|
||||
shard := m.GetShard(key)
|
||||
shard.RLock()
|
||||
// Get item from shard.
|
||||
val, ok := shard.items[key]
|
||||
shard.RUnlock()
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Count returns the number of elements within the map.
|
||||
func (m ConcurrentMap) Count() int {
|
||||
count := 0
|
||||
for i := 0; i < SHARD_COUNT; i++ {
|
||||
shard := m[i]
|
||||
shard.RLock()
|
||||
count += len(shard.items)
|
||||
shard.RUnlock()
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Looks up an item under specified key
|
||||
func (m ConcurrentMap) Has(key string) bool {
|
||||
// Get shard
|
||||
shard := m.GetShard(key)
|
||||
shard.RLock()
|
||||
// See if element is within shard.
|
||||
_, ok := shard.items[key]
|
||||
shard.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
// Remove removes an element from the map.
|
||||
func (m ConcurrentMap) Remove(key string) {
|
||||
// Try to get shard.
|
||||
shard := m.GetShard(key)
|
||||
shard.Lock()
|
||||
delete(shard.items, key)
|
||||
shard.Unlock()
|
||||
}
|
||||
|
||||
// Pop removes an element from the map and returns it
|
||||
func (m ConcurrentMap) Pop(key string) (v interface{}, exists bool) {
|
||||
// Try to get shard.
|
||||
shard := m.GetShard(key)
|
||||
shard.Lock()
|
||||
v, exists = shard.items[key]
|
||||
delete(shard.items, key)
|
||||
shard.Unlock()
|
||||
return v, exists
|
||||
}
|
||||
|
||||
// IsEmpty checks if map is empty.
|
||||
func (m ConcurrentMap) IsEmpty() bool {
|
||||
return m.Count() == 0
|
||||
}
|
||||
|
||||
// Used by the Iter & IterBuffered functions to wrap two variables together over a channel,
|
||||
type Tuple struct {
|
||||
Key string
|
||||
Val interface{}
|
||||
}
|
||||
|
||||
// Iter returns an iterator which could be used in a for range loop.
|
||||
//
|
||||
// Deprecated: using IterBuffered() will get a better performence
|
||||
func (m ConcurrentMap) Iter() <-chan Tuple {
|
||||
chans := snapshot(m)
|
||||
ch := make(chan Tuple)
|
||||
go fanIn(chans, ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
// IterBuffered returns a buffered iterator which could be used in a for range loop.
|
||||
func (m ConcurrentMap) IterBuffered() <-chan Tuple {
|
||||
chans := snapshot(m)
|
||||
total := 0
|
||||
for _, c := range chans {
|
||||
total += cap(c)
|
||||
}
|
||||
ch := make(chan Tuple, total)
|
||||
go fanIn(chans, ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
// Returns a array of channels that contains elements in each shard,
|
||||
// which likely takes a snapshot of `m`.
|
||||
// It returns once the size of each buffered channel is determined,
|
||||
// before all the channels are populated using goroutines.
|
||||
func snapshot(m ConcurrentMap) (chans []chan Tuple) {
|
||||
chans = make([]chan Tuple, SHARD_COUNT)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(SHARD_COUNT)
|
||||
// Foreach shard.
|
||||
for index, shard := range m {
|
||||
go func(index int, shard *ConcurrentMapShared) {
|
||||
// Foreach key, value pair.
|
||||
shard.RLock()
|
||||
chans[index] = make(chan Tuple, len(shard.items))
|
||||
wg.Done()
|
||||
for key, val := range shard.items {
|
||||
chans[index] <- Tuple{key, val}
|
||||
}
|
||||
shard.RUnlock()
|
||||
close(chans[index])
|
||||
}(index, shard)
|
||||
}
|
||||
wg.Wait()
|
||||
return chans
|
||||
}
|
||||
|
||||
// fanIn reads elements from channels `chans` into channel `out`
|
||||
func fanIn(chans []chan Tuple, out chan Tuple) {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(chans))
|
||||
for _, ch := range chans {
|
||||
go func(ch chan Tuple) {
|
||||
for t := range ch {
|
||||
out <- t
|
||||
}
|
||||
wg.Done()
|
||||
}(ch)
|
||||
}
|
||||
wg.Wait()
|
||||
close(out)
|
||||
}
|
||||
|
||||
// Items returns all items as map[string]interface{}
|
||||
func (m ConcurrentMap) Items() map[string]interface{} {
|
||||
tmp := make(map[string]interface{})
|
||||
|
||||
// Insert items to temporary map.
|
||||
for item := range m.IterBuffered() {
|
||||
tmp[item.Key] = item.Val
|
||||
}
|
||||
|
||||
return tmp
|
||||
}
|
||||
|
||||
// Iterator callback,called for every key,value found in
|
||||
// maps. RLock is held for all calls for a given shard
|
||||
// therefore callback sess consistent view of a shard,
|
||||
// but not across the shards
|
||||
type IterCb func(key string, v interface{})
|
||||
|
||||
// Callback based iterator, cheapest way to read
|
||||
// all elements in a map.
|
||||
func (m ConcurrentMap) IterCb(fn IterCb) {
|
||||
for idx := range m {
|
||||
shard := (m)[idx]
|
||||
shard.RLock()
|
||||
for key, value := range shard.items {
|
||||
fn(key, value)
|
||||
}
|
||||
shard.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Keys returns all keys as []string
|
||||
func (m ConcurrentMap) Keys() []string {
|
||||
count := m.Count()
|
||||
ch := make(chan string, count)
|
||||
go func() {
|
||||
// Foreach shard.
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(SHARD_COUNT)
|
||||
for _, shard := range m {
|
||||
go func(shard *ConcurrentMapShared) {
|
||||
// Foreach key, value pair.
|
||||
shard.RLock()
|
||||
for key := range shard.items {
|
||||
ch <- key
|
||||
}
|
||||
shard.RUnlock()
|
||||
wg.Done()
|
||||
}(shard)
|
||||
}
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
// Generate keys
|
||||
keys := make([]string, 0, count)
|
||||
for k := range ch {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
//Reviles ConcurrentMap "private" variables to json marshal.
|
||||
func (m ConcurrentMap) MarshalJSON() ([]byte, error) {
|
||||
// Create a temporary map, which will hold all item spread across shards.
|
||||
tmp := make(map[string]interface{})
|
||||
|
||||
// Insert items to temporary map.
|
||||
for item := range m.IterBuffered() {
|
||||
tmp[item.Key] = item.Val
|
||||
}
|
||||
return json.Marshal(tmp)
|
||||
}
|
||||
|
||||
func fnv32(key string) uint32 {
|
||||
hash := uint32(2166136261)
|
||||
const prime32 = uint32(16777619)
|
||||
for i := 0; i < len(key); i++ {
|
||||
hash *= prime32
|
||||
hash ^= uint32(key[i])
|
||||
}
|
||||
return hash
|
||||
}
|
||||
|
||||
// Concurrent map uses Interface{} as its value, therefor JSON Unmarshal
|
||||
// will probably won't know which to type to unmarshal into, in such case
|
||||
// we'll end up with a value of type map[string]interface{}, In most cases this isn't
|
||||
// out value type, this is why we've decided to remove this functionality.
|
||||
|
||||
// func (m *ConcurrentMap) UnmarshalJSON(b []byte) (err error) {
|
||||
// // Reverse process of Marshal.
|
||||
|
||||
// tmp := make(map[string]interface{})
|
||||
|
||||
// // Unmarshal into a single map.
|
||||
// if err := json.Unmarshal(b, &tmp); err != nil {
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// // foreach key,value pair in temporary map insert into our concurrent map.
|
||||
// for key, val := range tmp {
|
||||
// m.Set(key, val)
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
145
utils/pool.go
Executable file
145
utils/pool.go
Executable file
@@ -0,0 +1,145 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
//ConnPool to use
|
||||
type ConnPool interface {
|
||||
Get() (conn interface{}, err error)
|
||||
Put(conn interface{})
|
||||
ReleaseAll()
|
||||
Len() (length int)
|
||||
}
|
||||
type poolConfig struct {
|
||||
Factory func() (interface{}, error)
|
||||
IsActive func(interface{}) bool
|
||||
Release func(interface{})
|
||||
InitialCap int
|
||||
MaxCap int
|
||||
}
|
||||
|
||||
func NewConnPool(poolConfig poolConfig) (pool ConnPool, err error) {
|
||||
p := netPool{
|
||||
config: poolConfig,
|
||||
conns: make(chan interface{}, poolConfig.MaxCap),
|
||||
lock: &sync.Mutex{},
|
||||
}
|
||||
//log.Printf("pool MaxCap:%d", poolConfig.MaxCap)
|
||||
if poolConfig.MaxCap > 0 {
|
||||
err = p.initAutoFill(false)
|
||||
if err == nil {
|
||||
p.initAutoFill(true)
|
||||
}
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
type netPool struct {
|
||||
conns chan interface{}
|
||||
lock *sync.Mutex
|
||||
config poolConfig
|
||||
}
|
||||
|
||||
func (p *netPool) initAutoFill(async bool) (err error) {
|
||||
var worker = func() (err error) {
|
||||
for {
|
||||
//log.Printf("pool fill: %v , len: %d", p.Len() <= p.config.InitialCap/2, p.Len())
|
||||
if p.Len() <= p.config.InitialCap/2 {
|
||||
p.lock.Lock()
|
||||
errN := 0
|
||||
for i := 0; i < p.config.InitialCap; i++ {
|
||||
c, err := p.config.Factory()
|
||||
if err != nil {
|
||||
errN++
|
||||
if async {
|
||||
continue
|
||||
} else {
|
||||
p.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
}
|
||||
select {
|
||||
case p.conns <- c:
|
||||
default:
|
||||
p.config.Release(c)
|
||||
break
|
||||
}
|
||||
if p.Len() >= p.config.InitialCap {
|
||||
break
|
||||
}
|
||||
}
|
||||
if errN > 0 {
|
||||
log.Printf("fill conn pool fail , ERRN:%d", errN)
|
||||
}
|
||||
p.lock.Unlock()
|
||||
}
|
||||
if !async {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Second * 2)
|
||||
}
|
||||
}
|
||||
if async {
|
||||
go worker()
|
||||
} else {
|
||||
err = worker()
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func (p *netPool) Get() (conn interface{}, err error) {
|
||||
// defer func() {
|
||||
// log.Printf("pool len : %d", p.Len())
|
||||
// }()
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
// for {
|
||||
select {
|
||||
case conn = <-p.conns:
|
||||
if p.config.IsActive(conn) {
|
||||
return
|
||||
}
|
||||
p.config.Release(conn)
|
||||
default:
|
||||
conn, err = p.config.Factory()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
// }
|
||||
return
|
||||
}
|
||||
|
||||
func (p *netPool) Put(conn interface{}) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
if !p.config.IsActive(conn) {
|
||||
p.config.Release(conn)
|
||||
}
|
||||
select {
|
||||
case p.conns <- conn:
|
||||
default:
|
||||
p.config.Release(conn)
|
||||
}
|
||||
}
|
||||
func (p *netPool) ReleaseAll() {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
close(p.conns)
|
||||
for c := range p.conns {
|
||||
p.config.Release(c)
|
||||
}
|
||||
p.conns = make(chan interface{}, p.config.InitialCap)
|
||||
|
||||
}
|
||||
func (p *netPool) Len() (length int) {
|
||||
return len(p.conns)
|
||||
}
|
||||
126
utils/serve-channel.go
Normal file
126
utils/serve-channel.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
type ServerChannel struct {
|
||||
ip string
|
||||
port int
|
||||
Listener *net.Listener
|
||||
UDPListener *net.UDPConn
|
||||
errAcceptHandler func(err error)
|
||||
}
|
||||
|
||||
func NewServerChannel(ip string, port int) ServerChannel {
|
||||
return ServerChannel{
|
||||
ip: ip,
|
||||
port: port,
|
||||
errAcceptHandler: func(err error) {
|
||||
fmt.Printf("accept error , ERR:%s", err)
|
||||
},
|
||||
}
|
||||
}
|
||||
func (sc *ServerChannel) SetErrAcceptHandler(fn func(err error)) {
|
||||
sc.errAcceptHandler = fn
|
||||
}
|
||||
func (sc *ServerChannel) ListenTls(certBytes, keyBytes []byte, fn func(conn net.Conn)) (err error) {
|
||||
sc.Listener, err = ListenTls(sc.ip, sc.port, certBytes, keyBytes)
|
||||
if err == nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
log.Printf("ListenTls crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
var conn net.Conn
|
||||
conn, err = (*sc.Listener).Accept()
|
||||
if err == nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
log.Printf("connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
fn(conn)
|
||||
}()
|
||||
} else {
|
||||
sc.errAcceptHandler(err)
|
||||
(*sc.Listener).Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (sc *ServerChannel) ListenTCP(fn func(conn net.Conn)) (err error) {
|
||||
var l net.Listener
|
||||
l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sc.ip, sc.port))
|
||||
if err == nil {
|
||||
sc.Listener = &l
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
log.Printf("ListenTCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
var conn net.Conn
|
||||
conn, err = (*sc.Listener).Accept()
|
||||
if err == nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
log.Printf("connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
fn(conn)
|
||||
}()
|
||||
} else {
|
||||
sc.errAcceptHandler(err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return
|
||||
}
|
||||
func (sc *ServerChannel) ListenUDP(fn func(packet []byte, localAddr, srcAddr *net.UDPAddr)) (err error) {
|
||||
addr := &net.UDPAddr{IP: net.ParseIP(sc.ip), Port: sc.port}
|
||||
l, err := net.ListenUDP("udp", addr)
|
||||
if err == nil {
|
||||
sc.UDPListener = l
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
log.Printf("ListenUDP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
var buf = make([]byte, 2048)
|
||||
n, srcAddr, err := (*sc.UDPListener).ReadFromUDP(buf)
|
||||
if err == nil {
|
||||
packet := buf[0:n]
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
log.Printf("udp data handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
fn(packet, addr, srcAddr)
|
||||
}()
|
||||
} else {
|
||||
sc.errAcceptHandler(err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return
|
||||
}
|
||||
462
utils/structs.go
Normal file
462
utils/structs.go
Normal file
@@ -0,0 +1,462 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Checker struct {
|
||||
data ConcurrentMap
|
||||
blockedMap ConcurrentMap
|
||||
directMap ConcurrentMap
|
||||
interval int64
|
||||
timeout int
|
||||
}
|
||||
type CheckerItem struct {
|
||||
IsHTTPS bool
|
||||
Method string
|
||||
URL string
|
||||
Domain string
|
||||
Host string
|
||||
Data []byte
|
||||
SuccessCount uint
|
||||
FailCount uint
|
||||
}
|
||||
|
||||
//NewChecker args:
|
||||
//timeout : tcp timeout milliseconds ,connect to host
|
||||
//interval: recheck domain interval seconds
|
||||
func NewChecker(timeout int, interval int64, blockedFile, directFile string) Checker {
|
||||
ch := Checker{
|
||||
data: NewConcurrentMap(),
|
||||
interval: interval,
|
||||
timeout: timeout,
|
||||
}
|
||||
ch.blockedMap = ch.loadMap(blockedFile)
|
||||
ch.directMap = ch.loadMap(directFile)
|
||||
if !ch.blockedMap.IsEmpty() {
|
||||
log.Printf("blocked file loaded , domains : %d", ch.blockedMap.Count())
|
||||
}
|
||||
if !ch.directMap.IsEmpty() {
|
||||
log.Printf("direct file loaded , domains : %d", ch.directMap.Count())
|
||||
}
|
||||
ch.start()
|
||||
return ch
|
||||
}
|
||||
|
||||
func (c *Checker) loadMap(f string) (dataMap ConcurrentMap) {
|
||||
dataMap = NewConcurrentMap()
|
||||
if PathExists(f) {
|
||||
_contents, err := ioutil.ReadFile(f)
|
||||
if err != nil {
|
||||
log.Printf("load file err:%s", err)
|
||||
return
|
||||
}
|
||||
for _, line := range strings.Split(string(_contents), "\n") {
|
||||
line = strings.Trim(line, "\r \t")
|
||||
if line != "" {
|
||||
dataMap.Set(line, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func (c *Checker) start() {
|
||||
go func() {
|
||||
for {
|
||||
for _, v := range c.data.Items() {
|
||||
go func(item CheckerItem) {
|
||||
if c.isNeedCheck(item) {
|
||||
//log.Printf("check %s", item.Domain)
|
||||
var conn net.Conn
|
||||
var err error
|
||||
if item.IsHTTPS {
|
||||
conn, err = ConnectHost(item.Host, c.timeout)
|
||||
if err == nil {
|
||||
conn.SetDeadline(time.Now().Add(time.Millisecond))
|
||||
conn.Close()
|
||||
}
|
||||
} else {
|
||||
err = HTTPGet(item.URL, c.timeout)
|
||||
}
|
||||
if err != nil {
|
||||
item.FailCount = item.FailCount + 1
|
||||
} else {
|
||||
item.SuccessCount = item.SuccessCount + 1
|
||||
}
|
||||
c.data.Set(item.Host, item)
|
||||
}
|
||||
}(v.(CheckerItem))
|
||||
}
|
||||
time.Sleep(time.Second * time.Duration(c.interval))
|
||||
}
|
||||
}()
|
||||
}
|
||||
func (c *Checker) isNeedCheck(item CheckerItem) bool {
|
||||
var minCount uint = 5
|
||||
if (item.SuccessCount >= minCount && item.SuccessCount > item.FailCount) ||
|
||||
(item.FailCount >= minCount && item.SuccessCount > item.FailCount) ||
|
||||
c.domainIsInMap(item.Host, false) ||
|
||||
c.domainIsInMap(item.Host, true) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
func (c *Checker) IsBlocked(address string) (blocked bool, failN, successN uint) {
|
||||
if c.domainIsInMap(address, true) {
|
||||
//log.Printf("%s in blocked ? true", address)
|
||||
return true, 0, 0
|
||||
}
|
||||
if c.domainIsInMap(address, false) {
|
||||
//log.Printf("%s in direct ? true", address)
|
||||
return false, 0, 0
|
||||
}
|
||||
|
||||
_item, ok := c.data.Get(address)
|
||||
if !ok {
|
||||
//log.Printf("%s not in map, blocked true", address)
|
||||
return true, 0, 0
|
||||
}
|
||||
item := _item.(CheckerItem)
|
||||
|
||||
return item.FailCount >= item.SuccessCount, item.FailCount, item.SuccessCount
|
||||
}
|
||||
func (c *Checker) domainIsInMap(address string, blockedMap bool) bool {
|
||||
u, err := url.Parse("http://" + address)
|
||||
if err != nil {
|
||||
log.Printf("blocked check , url parse err:%s", err)
|
||||
return true
|
||||
}
|
||||
domainSlice := strings.Split(u.Hostname(), ".")
|
||||
if len(domainSlice) > 1 {
|
||||
subSlice := domainSlice[:len(domainSlice)-1]
|
||||
topDomain := strings.Join(domainSlice[len(domainSlice)-1:], ".")
|
||||
checkDomain := topDomain
|
||||
for i := len(subSlice) - 1; i >= 0; i-- {
|
||||
checkDomain = subSlice[i] + "." + checkDomain
|
||||
if !blockedMap && c.directMap.Has(checkDomain) {
|
||||
return true
|
||||
}
|
||||
if blockedMap && c.blockedMap.Has(checkDomain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (c *Checker) Add(address string, isHTTPS bool, method, URL string, data []byte) {
|
||||
if c.domainIsInMap(address, false) || c.domainIsInMap(address, true) {
|
||||
return
|
||||
}
|
||||
if !isHTTPS && strings.ToLower(method) != "get" {
|
||||
return
|
||||
}
|
||||
var item CheckerItem
|
||||
u := strings.Split(address, ":")
|
||||
item = CheckerItem{
|
||||
URL: URL,
|
||||
Domain: u[0],
|
||||
Host: address,
|
||||
Data: data,
|
||||
IsHTTPS: isHTTPS,
|
||||
Method: method,
|
||||
}
|
||||
c.data.SetIfAbsent(item.Host, item)
|
||||
}
|
||||
|
||||
type BasicAuth struct {
|
||||
data ConcurrentMap
|
||||
}
|
||||
|
||||
func NewBasicAuth() BasicAuth {
|
||||
return BasicAuth{
|
||||
data: NewConcurrentMap(),
|
||||
}
|
||||
}
|
||||
func (ba *BasicAuth) AddFromFile(file string) (n int, err error) {
|
||||
_content, err := ioutil.ReadFile(file)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
userpassArr := strings.Split(strings.Replace(string(_content), "\r", "", -1), "\n")
|
||||
for _, userpass := range userpassArr {
|
||||
if strings.HasPrefix("#", userpass) {
|
||||
continue
|
||||
}
|
||||
u := strings.Split(strings.Trim(userpass, " "), ":")
|
||||
if len(u) == 2 {
|
||||
ba.data.Set(u[0], u[1])
|
||||
n++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ba *BasicAuth) Add(userpassArr []string) (n int) {
|
||||
for _, userpass := range userpassArr {
|
||||
u := strings.Split(userpass, ":")
|
||||
if len(u) == 2 {
|
||||
ba.data.Set(u[0], u[1])
|
||||
n++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ba *BasicAuth) Check(userpass string) (ok bool) {
|
||||
u := strings.Split(strings.Trim(userpass, " "), ":")
|
||||
if len(u) == 2 {
|
||||
if p, _ok := ba.data.Get(u[0]); _ok {
|
||||
return p.(string) == u[1]
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func (ba *BasicAuth) Total() (n int) {
|
||||
n = ba.data.Count()
|
||||
return
|
||||
}
|
||||
|
||||
type HTTPRequest struct {
|
||||
HeadBuf []byte
|
||||
conn *net.Conn
|
||||
Host string
|
||||
Method string
|
||||
URL string
|
||||
hostOrURL string
|
||||
isBasicAuth bool
|
||||
basicAuth *BasicAuth
|
||||
}
|
||||
|
||||
func NewHTTPRequest(inConn *net.Conn, bufSize int, isBasicAuth bool, basicAuth *BasicAuth) (req HTTPRequest, err error) {
|
||||
buf := make([]byte, bufSize)
|
||||
len := 0
|
||||
req = HTTPRequest{
|
||||
conn: inConn,
|
||||
}
|
||||
len, err = (*inConn).Read(buf[:])
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
err = fmt.Errorf("http decoder read err:%s", err)
|
||||
}
|
||||
CloseConn(inConn)
|
||||
return
|
||||
}
|
||||
req.HeadBuf = buf[:len]
|
||||
index := bytes.IndexByte(req.HeadBuf, '\n')
|
||||
if index == -1 {
|
||||
err = fmt.Errorf("http decoder data line err:%s", string(req.HeadBuf)[:50])
|
||||
CloseConn(inConn)
|
||||
return
|
||||
}
|
||||
fmt.Sscanf(string(req.HeadBuf[:index]), "%s%s", &req.Method, &req.hostOrURL)
|
||||
if req.Method == "" || req.hostOrURL == "" {
|
||||
err = fmt.Errorf("http decoder data err:%s", string(req.HeadBuf)[:50])
|
||||
CloseConn(inConn)
|
||||
return
|
||||
}
|
||||
req.Method = strings.ToUpper(req.Method)
|
||||
req.isBasicAuth = isBasicAuth
|
||||
req.basicAuth = basicAuth
|
||||
log.Printf("%s:%s", req.Method, req.hostOrURL)
|
||||
|
||||
if req.IsHTTPS() {
|
||||
err = req.HTTPS()
|
||||
} else {
|
||||
err = req.HTTP()
|
||||
}
|
||||
return
|
||||
}
|
||||
func (req *HTTPRequest) HTTP() (err error) {
|
||||
if req.isBasicAuth {
|
||||
err = req.BasicAuth()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
req.URL, err = req.getHTTPURL()
|
||||
if err == nil {
|
||||
u, _ := url.Parse(req.URL)
|
||||
req.Host = u.Host
|
||||
req.addPortIfNot()
|
||||
}
|
||||
return
|
||||
}
|
||||
func (req *HTTPRequest) HTTPS() (err error) {
|
||||
req.Host = req.hostOrURL
|
||||
req.addPortIfNot()
|
||||
//_, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n")
|
||||
return
|
||||
}
|
||||
func (req *HTTPRequest) HTTPSReply() (err error) {
|
||||
_, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n")
|
||||
return
|
||||
}
|
||||
func (req *HTTPRequest) IsHTTPS() bool {
|
||||
return req.Method == "CONNECT"
|
||||
}
|
||||
|
||||
func (req *HTTPRequest) BasicAuth() (err error) {
|
||||
|
||||
//log.Printf("request :%s", string(b[:n]))
|
||||
authorization, err := req.getHeader("Authorization")
|
||||
if err != nil {
|
||||
fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized")
|
||||
CloseConn(req.conn)
|
||||
return
|
||||
}
|
||||
//log.Printf("Authorization:%s", authorization)
|
||||
basic := strings.Fields(authorization)
|
||||
if len(basic) != 2 {
|
||||
err = fmt.Errorf("authorization data error,ERR:%s", authorization)
|
||||
CloseConn(req.conn)
|
||||
return
|
||||
}
|
||||
user, err := base64.StdEncoding.DecodeString(basic[1])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("authorization data parse error,ERR:%s", err)
|
||||
CloseConn(req.conn)
|
||||
return
|
||||
}
|
||||
authOk := (*req.basicAuth).Check(string(user))
|
||||
//log.Printf("auth %s,%v", string(user), authOk)
|
||||
if !authOk {
|
||||
fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized")
|
||||
CloseConn(req.conn)
|
||||
err = fmt.Errorf("basic auth fail")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
func (req *HTTPRequest) getHTTPURL() (URL string, err error) {
|
||||
if !strings.HasPrefix(req.hostOrURL, "/") {
|
||||
return req.hostOrURL, nil
|
||||
}
|
||||
_host, err := req.getHeader("host")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
URL = fmt.Sprintf("http://%s%s", _host, req.hostOrURL)
|
||||
return
|
||||
}
|
||||
func (req *HTTPRequest) getHeader(key string) (val string, err error) {
|
||||
key = strings.ToUpper(key)
|
||||
lines := strings.Split(string(req.HeadBuf), "\r\n")
|
||||
for _, line := range lines {
|
||||
line := strings.SplitN(strings.Trim(line, "\r\n "), ":", 2)
|
||||
if len(line) == 2 {
|
||||
k := strings.ToUpper(strings.Trim(line[0], " "))
|
||||
v := strings.Trim(line[1], " ")
|
||||
if key == k {
|
||||
val = v
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
err = fmt.Errorf("can not find HOST header")
|
||||
return
|
||||
}
|
||||
|
||||
func (req *HTTPRequest) addPortIfNot() (newHost string) {
|
||||
//newHost = req.Host
|
||||
port := "80"
|
||||
if req.IsHTTPS() {
|
||||
port = "443"
|
||||
}
|
||||
if (!strings.HasPrefix(req.Host, "[") && strings.Index(req.Host, ":") == -1) || (strings.HasPrefix(req.Host, "[") && strings.HasSuffix(req.Host, "]")) {
|
||||
//newHost = req.Host + ":" + port
|
||||
//req.headBuf = []byte(strings.Replace(string(req.headBuf), req.Host, newHost, 1))
|
||||
req.Host = req.Host + ":" + port
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type OutPool struct {
|
||||
Pool ConnPool
|
||||
dur int
|
||||
isTLS bool
|
||||
certBytes []byte
|
||||
keyBytes []byte
|
||||
address string
|
||||
timeout int
|
||||
}
|
||||
|
||||
func NewOutPool(dur int, isTLS bool, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) (op OutPool) {
|
||||
op = OutPool{
|
||||
dur: dur,
|
||||
isTLS: isTLS,
|
||||
certBytes: certBytes,
|
||||
keyBytes: keyBytes,
|
||||
address: address,
|
||||
timeout: timeout,
|
||||
}
|
||||
var err error
|
||||
op.Pool, err = NewConnPool(poolConfig{
|
||||
IsActive: func(conn interface{}) bool { return true },
|
||||
Release: func(conn interface{}) {
|
||||
if conn != nil {
|
||||
conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond))
|
||||
conn.(net.Conn).Close()
|
||||
// log.Println("conn released")
|
||||
}
|
||||
},
|
||||
InitialCap: InitialCap,
|
||||
MaxCap: MaxCap,
|
||||
Factory: func() (conn interface{}, err error) {
|
||||
conn, err = op.getConn()
|
||||
return
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("init conn pool fail ,%s", err)
|
||||
} else {
|
||||
if InitialCap > 0 {
|
||||
log.Printf("init conn pool success")
|
||||
op.initPoolDeamon()
|
||||
} else {
|
||||
log.Printf("conn pool closed")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func (op *OutPool) getConn() (conn interface{}, err error) {
|
||||
if op.isTLS {
|
||||
var _conn tls.Conn
|
||||
_conn, err = TlsConnectHost(op.address, op.timeout, op.certBytes, op.keyBytes)
|
||||
if err == nil {
|
||||
conn = net.Conn(&_conn)
|
||||
}
|
||||
} else {
|
||||
conn, err = ConnectHost(op.address, op.timeout)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (op *OutPool) initPoolDeamon() {
|
||||
go func() {
|
||||
if op.dur <= 0 {
|
||||
return
|
||||
}
|
||||
log.Printf("pool deamon started")
|
||||
for {
|
||||
time.Sleep(time.Second * time.Duration(op.dur))
|
||||
conn, err := op.getConn()
|
||||
if err != nil {
|
||||
log.Printf("pool deamon err %s , release pool", err)
|
||||
op.Pool.ReleaseAll()
|
||||
} else {
|
||||
conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond))
|
||||
conn.(net.Conn).Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
Reference in New Issue
Block a user