diff --git a/utils/tcpmux/httpconnect.go b/utils/tcpmux/httpconnect.go index 92d3eb3e..af0a39f9 100644 --- a/utils/tcpmux/httpconnect.go +++ b/utils/tcpmux/httpconnect.go @@ -31,7 +31,7 @@ type HttpConnectTcpMuxer struct { } func NewHttpConnectTcpMuxer(listener net.Listener, timeout time.Duration) (*HttpConnectTcpMuxer, error) { - mux, err := vhost.NewVhostMuxer(listener, getHostFromHttpConnect, nil, sendHttpOk, timeout) + mux, err := vhost.NewVhostMuxer(listener, getHostFromHttpConnect, nil, sendHttpOk, nil, timeout) return &HttpConnectTcpMuxer{mux}, err } @@ -52,10 +52,8 @@ func readHttpConnectRequest(rd io.Reader) (host string, err error) { return } -func sendHttpOk(c net.Conn, _ string) (_ net.Conn, err error) { - okResp := util.OkResponse() - err = okResp.Write(c) - return c, err +func sendHttpOk(c net.Conn) error { + return util.OkResponse().Write(c) } func getHostFromHttpConnect(c net.Conn) (_ net.Conn, _ map[string]string, err error) { diff --git a/utils/vhost/https.go b/utils/vhost/https.go index 53177019..41bd01ce 100644 --- a/utils/vhost/https.go +++ b/utils/vhost/https.go @@ -48,7 +48,7 @@ type HttpsMuxer struct { } func NewHttpsMuxer(listener net.Listener, timeout time.Duration) (*HttpsMuxer, error) { - mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, nil, timeout) + mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, nil, nil, timeout) return &HttpsMuxer{mux}, err } diff --git a/utils/vhost/vhost.go b/utils/vhost/vhost.go index 57f82394..ad322be6 100644 --- a/utils/vhost/vhost.go +++ b/utils/vhost/vhost.go @@ -29,22 +29,25 @@ import ( type muxFunc func(net.Conn) (net.Conn, map[string]string, error) type httpAuthFunc func(net.Conn, string, string, string) (bool, error) type hostRewriteFunc func(net.Conn, string) (net.Conn, error) +type successFunc func(net.Conn) error type VhostMuxer struct { listener net.Listener timeout time.Duration vhostFunc muxFunc authFunc httpAuthFunc + successFunc successFunc rewriteFunc hostRewriteFunc registryRouter *VhostRouters } -func NewVhostMuxer(listener net.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) { +func NewVhostMuxer(listener net.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, successFunc successFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) { mux = &VhostMuxer{ listener: listener, timeout: timeout, vhostFunc: vhostFunc, authFunc: authFunc, + successFunc: successFunc, rewriteFunc: rewriteFunc, registryRouter: NewVhostRouters(), } @@ -149,7 +152,15 @@ func (v *VhostMuxer) handle(c net.Conn) { c.Close() return } + xl := xlog.FromContextSafe(l.ctx) + if v.successFunc != nil { + if err := v.successFunc(c); err != nil { + xl.Info("success func failure on vhost connection: %v", err) + c.Close() + return + } + } // if authFunc is exist and userName/password is set // then verify user access