feat: Add success function to vhost

This commit is contained in:
Guy Lewin 2020-03-03 17:42:42 -05:00
parent 6fc7135cbd
commit a537bbd585
3 changed files with 16 additions and 7 deletions

View File

@ -31,7 +31,7 @@ type HttpConnectTcpMuxer struct {
} }
func NewHttpConnectTcpMuxer(listener net.Listener, timeout time.Duration) (*HttpConnectTcpMuxer, error) { func NewHttpConnectTcpMuxer(listener net.Listener, timeout time.Duration) (*HttpConnectTcpMuxer, error) {
mux, err := vhost.NewVhostMuxer(listener, getHostFromHttpConnect, nil, sendHttpOk, timeout) mux, err := vhost.NewVhostMuxer(listener, getHostFromHttpConnect, nil, sendHttpOk, nil, timeout)
return &HttpConnectTcpMuxer{mux}, err return &HttpConnectTcpMuxer{mux}, err
} }
@ -52,10 +52,8 @@ func readHttpConnectRequest(rd io.Reader) (host string, err error) {
return return
} }
func sendHttpOk(c net.Conn, _ string) (_ net.Conn, err error) { func sendHttpOk(c net.Conn) error {
okResp := util.OkResponse() return util.OkResponse().Write(c)
err = okResp.Write(c)
return c, err
} }
func getHostFromHttpConnect(c net.Conn) (_ net.Conn, _ map[string]string, err error) { func getHostFromHttpConnect(c net.Conn) (_ net.Conn, _ map[string]string, err error) {

View File

@ -48,7 +48,7 @@ type HttpsMuxer struct {
} }
func NewHttpsMuxer(listener net.Listener, timeout time.Duration) (*HttpsMuxer, error) { func NewHttpsMuxer(listener net.Listener, timeout time.Duration) (*HttpsMuxer, error) {
mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, nil, timeout) mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, nil, nil, timeout)
return &HttpsMuxer{mux}, err return &HttpsMuxer{mux}, err
} }

View File

@ -29,22 +29,25 @@ import (
type muxFunc func(net.Conn) (net.Conn, map[string]string, error) type muxFunc func(net.Conn) (net.Conn, map[string]string, error)
type httpAuthFunc func(net.Conn, string, string, string) (bool, error) type httpAuthFunc func(net.Conn, string, string, string) (bool, error)
type hostRewriteFunc func(net.Conn, string) (net.Conn, error) type hostRewriteFunc func(net.Conn, string) (net.Conn, error)
type successFunc func(net.Conn) error
type VhostMuxer struct { type VhostMuxer struct {
listener net.Listener listener net.Listener
timeout time.Duration timeout time.Duration
vhostFunc muxFunc vhostFunc muxFunc
authFunc httpAuthFunc authFunc httpAuthFunc
successFunc successFunc
rewriteFunc hostRewriteFunc rewriteFunc hostRewriteFunc
registryRouter *VhostRouters registryRouter *VhostRouters
} }
func NewVhostMuxer(listener net.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) { func NewVhostMuxer(listener net.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, successFunc successFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
mux = &VhostMuxer{ mux = &VhostMuxer{
listener: listener, listener: listener,
timeout: timeout, timeout: timeout,
vhostFunc: vhostFunc, vhostFunc: vhostFunc,
authFunc: authFunc, authFunc: authFunc,
successFunc: successFunc,
rewriteFunc: rewriteFunc, rewriteFunc: rewriteFunc,
registryRouter: NewVhostRouters(), registryRouter: NewVhostRouters(),
} }
@ -149,7 +152,15 @@ func (v *VhostMuxer) handle(c net.Conn) {
c.Close() c.Close()
return return
} }
xl := xlog.FromContextSafe(l.ctx) xl := xlog.FromContextSafe(l.ctx)
if v.successFunc != nil {
if err := v.successFunc(c); err != nil {
xl.Info("success func failure on vhost connection: %v", err)
c.Close()
return
}
}
// if authFunc is exist and userName/password is set // if authFunc is exist and userName/password is set
// then verify user access // then verify user access