From 5c8b4e3ab9718781119968a361642cf51f3f3a69 Mon Sep 17 00:00:00 2001 From: gulewin Date: Tue, 14 Apr 2020 19:34:15 -0700 Subject: [PATCH] feat: allow multiple duplicate proxies registered with tcpmux for load balancing --- server/service.go | 2 +- utils/tcpmux/httpconnect.go | 2 +- utils/vhost/http.go | 11 +++--- utils/vhost/https.go | 2 +- utils/vhost/router.go | 74 +++++++++++++++++++++++++++++++------ utils/vhost/vhost.go | 10 ++--- 6 files changed, 76 insertions(+), 25 deletions(-) diff --git a/server/service.go b/server/service.go index d3c31699..0a9724c7 100644 --- a/server/service.go +++ b/server/service.go @@ -108,7 +108,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { TcpPortManager: ports.NewPortManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts), UdpPortManager: ports.NewPortManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts), }, - httpVhostRouter: vhost.NewVhostRouters(), + httpVhostRouter: vhost.NewVhostRouters(false), authVerifier: auth.NewAuthVerifier(cfg.AuthServerConfig), tlsConfig: generateTLSConfig(), cfg: cfg, diff --git a/utils/tcpmux/httpconnect.go b/utils/tcpmux/httpconnect.go index af0a39f9..1acc4b3c 100644 --- a/utils/tcpmux/httpconnect.go +++ b/utils/tcpmux/httpconnect.go @@ -31,7 +31,7 @@ type HttpConnectTcpMuxer struct { } func NewHttpConnectTcpMuxer(listener net.Listener, timeout time.Duration) (*HttpConnectTcpMuxer, error) { - mux, err := vhost.NewVhostMuxer(listener, getHostFromHttpConnect, nil, sendHttpOk, nil, timeout) + mux, err := vhost.NewVhostMuxer(listener, getHostFromHttpConnect, nil, sendHttpOk, nil, timeout, true) return &HttpConnectTcpMuxer{mux}, err } diff --git a/utils/vhost/http.go b/utils/vhost/http.go index 1bf6cc30..c46f3fe6 100644 --- a/utils/vhost/http.go +++ b/utils/vhost/http.go @@ -110,7 +110,7 @@ func (rp *HttpReverseProxy) UnRegister(domain string, location string) { func (rp *HttpReverseProxy) GetRealHost(domain string, location string) (host string) { vr, ok := rp.getVhost(domain, location) if ok { - host = vr.payload.(*VhostRouteConfig).RewriteHost + host = vr.getPayload().(*VhostRouteConfig).RewriteHost } return } @@ -118,7 +118,7 @@ func (rp *HttpReverseProxy) GetRealHost(domain string, location string) (host st func (rp *HttpReverseProxy) GetHeaders(domain string, location string) (headers map[string]string) { vr, ok := rp.getVhost(domain, location) if ok { - headers = vr.payload.(*VhostRouteConfig).Headers + headers = vr.getPayload().(*VhostRouteConfig).Headers } return } @@ -127,7 +127,7 @@ func (rp *HttpReverseProxy) GetHeaders(domain string, location string) (headers func (rp *HttpReverseProxy) CreateConnection(domain string, location string, remoteAddr string) (net.Conn, error) { vr, ok := rp.getVhost(domain, location) if ok { - fn := vr.payload.(*VhostRouteConfig).CreateConnFn + fn := vr.getPayload().(*VhostRouteConfig).CreateConnFn if fn != nil { return fn(remoteAddr) } @@ -138,8 +138,9 @@ func (rp *HttpReverseProxy) CreateConnection(domain string, location string, rem func (rp *HttpReverseProxy) CheckAuth(domain, location, user, passwd string) bool { vr, ok := rp.getVhost(domain, location) if ok { - checkUser := vr.payload.(*VhostRouteConfig).Username - checkPasswd := vr.payload.(*VhostRouteConfig).Password + routeCfg := vr.getPayload().(*VhostRouteConfig) + checkUser := routeCfg.Username + checkPasswd := routeCfg.Password if (checkUser != "" || checkPasswd != "") && (checkUser != user || checkPasswd != passwd) { return false } diff --git a/utils/vhost/https.go b/utils/vhost/https.go index 41bd01ce..46b7b525 100644 --- a/utils/vhost/https.go +++ b/utils/vhost/https.go @@ -48,7 +48,7 @@ type HttpsMuxer struct { } func NewHttpsMuxer(listener net.Listener, timeout time.Duration) (*HttpsMuxer, error) { - mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, nil, nil, timeout) + mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, nil, nil, timeout, false) return &HttpsMuxer{mux}, err } diff --git a/utils/vhost/router.go b/utils/vhost/router.go index bfdcb50b..90beedd2 100644 --- a/utils/vhost/router.go +++ b/utils/vhost/router.go @@ -2,6 +2,7 @@ package vhost import ( "errors" + "math/rand" "sort" "strings" "sync" @@ -12,20 +13,30 @@ var ( ) type VhostRouters struct { - RouterByDomain map[string][]*VhostRouter - mutex sync.RWMutex + RouterByDomain map[string][]*VhostRouter + allowDuplicates bool + mutex sync.RWMutex } type VhostRouter struct { domain string location string - payload interface{} + allowDuplicates bool + payloads []interface{} } -func NewVhostRouters() *VhostRouters { +func (vr *VhostRouter) getPayload() interface{} { + if !vr.allowDuplicates { + return vr.payloads[0] + } + return vr.payloads[rand.Intn(len(vr.payloads))] +} + +func NewVhostRouters(allowDuplicates bool) *VhostRouters { return &VhostRouters{ - RouterByDomain: make(map[string][]*VhostRouter), + allowDuplicates: allowDuplicates, + RouterByDomain: make(map[string][]*VhostRouter), } } @@ -33,8 +44,12 @@ func (r *VhostRouters) Add(domain, location string, payload interface{}) error { r.mutex.Lock() defer r.mutex.Unlock() - if _, exist := r.exist(domain, location); exist { - return ErrRouterConfigConflict + if vr, exist := r.exist(domain, location); exist { + if !r.allowDuplicates { + return ErrRouterConfigConflict + } + vr.payloads = append(vr.payloads, payload) + return nil } vrs, found := r.RouterByDomain[domain] @@ -43,10 +58,12 @@ func (r *VhostRouters) Add(domain, location string, payload interface{}) error { } vr := &VhostRouter{ - domain: domain, - location: location, - payload: payload, + domain: domain, + location: location, + allowDuplicates: r.allowDuplicates, + payloads: make([]interface{}, 1), } + vr.payloads[0] = payload vrs = append(vrs, vr) sort.Sort(sort.Reverse(ByLocation(vrs))) @@ -68,7 +85,41 @@ func (r *VhostRouters) Del(domain, location string) { newVrs = append(newVrs, vr) } } - r.RouterByDomain[domain] = newVrs + if len(newVrs) == 0 { + delete(r.RouterByDomain, domain) + } else { + r.RouterByDomain[domain] = newVrs + } +} + +func (r *VhostRouters) DelPayloadFromLocation(domain, location string, payload interface{}) { + r.mutex.Lock() + + vrs, found := r.RouterByDomain[domain] + if !found { + r.mutex.Unlock() + return + } + newPayloadsLen := -1 + for _, vr := range vrs { + if vr.location == location { + newPayloads := make([]interface{}, 0) + for _, payloadIter := range vr.payloads { + if payloadIter != payload { + newPayloads = append(newPayloads, payloadIter) + } + } + vr.payloads = newPayloads + newPayloadsLen = len(newPayloads) + break + } + } + + r.mutex.Unlock() + + if newPayloadsLen == 0 { + r.Del(domain, location) + } } func (r *VhostRouters) Get(host, path string) (vr *VhostRouter, exist bool) { @@ -80,7 +131,6 @@ func (r *VhostRouters) Get(host, path string) (vr *VhostRouter, exist bool) { return } - // can't support load balance, will to do for _, vr = range vrs { if strings.HasPrefix(path, vr.location) { return vr, true diff --git a/utils/vhost/vhost.go b/utils/vhost/vhost.go index fec3f525..1216a073 100644 --- a/utils/vhost/vhost.go +++ b/utils/vhost/vhost.go @@ -41,7 +41,7 @@ type VhostMuxer struct { registryRouter *VhostRouters } -func NewVhostMuxer(listener net.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, successFunc successFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) { +func NewVhostMuxer(listener net.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, successFunc successFunc, rewriteFunc hostRewriteFunc, timeout time.Duration, allowDuplicates bool) (mux *VhostMuxer, err error) { mux = &VhostMuxer{ listener: listener, timeout: timeout, @@ -49,7 +49,7 @@ func NewVhostMuxer(listener net.Listener, vhostFunc muxFunc, authFunc httpAuthFu authFunc: authFunc, successFunc: successFunc, rewriteFunc: rewriteFunc, - registryRouter: NewVhostRouters(), + registryRouter: NewVhostRouters(allowDuplicates), } go mux.run() return mux, nil @@ -94,7 +94,7 @@ func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) { // if not exist, then check the wildcard_domain such as *.example.com vr, found := v.registryRouter.Get(name, path) if found { - return vr.payload.(*Listener), true + return vr.getPayload().(*Listener), true } domainSplit := strings.Split(name, ".") @@ -112,7 +112,7 @@ func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) { vr, found = v.registryRouter.Get(name, path) if found { - return vr.payload.(*Listener), true + return vr.getPayload().(*Listener), true } domainSplit = domainSplit[1:] } @@ -223,7 +223,7 @@ func (l *Listener) Accept() (net.Conn, error) { } func (l *Listener) Close() error { - l.mux.registryRouter.Del(l.name, l.location) + l.mux.registryRouter.DelPayloadFromLocation(l.name, l.location, l) close(l.accept) return nil }