diff --git a/client/control.go b/client/control.go index 067fe37f..d51897f8 100644 --- a/client/control.go +++ b/client/control.go @@ -234,9 +234,15 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) { } } - address := net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort)) - conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig, ctl.clientCfg.DisableCustomTLSFirstByte) + opts := []frpNet.DialOption{ + frpNet.WithProxy(ctl.clientCfg.HTTPProxy), + frpNet.WithProtocol(ctl.clientCfg.Protocol), + frpNet.WithRemoteAddress(net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort))), + frpNet.WithTLSConfig(tlsConfig), + frpNet.WithDisableCustomTLSHeadByte(ctl.clientCfg.DisableCustomTLSFirstByte), + } + conn, err = frpNet.Dial(opts...) if err != nil { xl.Warn("start new connection to server error: %v", err) return diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index 47ab03ca..1d6585f4 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -790,7 +790,11 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf return } - localConn, err := frpNet.ConnectServer("tcp", fmt.Sprintf("%s:%d", localInfo.LocalIP, localInfo.LocalPort)) + opts := []frpNet.DialOption{ + frpNet.WithRemoteAddress(fmt.Sprintf("%s:%d", localInfo.LocalIP, localInfo.LocalPort)), + } + + localConn, err := frpNet.Dial(opts...) if err != nil { workConn.Close() xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err) diff --git a/client/service.go b/client/service.go index 8b880034..445abac1 100644 --- a/client/service.go +++ b/client/service.go @@ -228,8 +228,16 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) { } } - address := net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort)) - conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig, svr.cfg.DisableCustomTLSFirstByte) + opts := []frpNet.DialOption{ + frpNet.WithProxy(svr.cfg.HTTPProxy), + frpNet.WithProtocol(svr.cfg.Protocol), + frpNet.WithBindAddress(net.JoinHostPort(svr.cfg.BindAddr, strconv.Itoa(svr.cfg.BindPort))), + frpNet.WithRemoteAddress(net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort))), + frpNet.WithTLSConfig(tlsConfig), + frpNet.WithDisableCustomTLSHeadByte(svr.cfg.DisableCustomTLSFirstByte), + } + + conn, err = frpNet.Dial(opts...) if err != nil { return } diff --git a/pkg/config/client.go b/pkg/config/client.go index b2efb79a..e3355e00 100644 --- a/pkg/config/client.go +++ b/pkg/config/client.go @@ -38,6 +38,10 @@ type ClientCommonConf struct { // ServerPort specifies the port to connect to the server on. By default, // this value is 7000. ServerPort int `ini:"server_port" json:"server_port"` + + BindAddr string `ini:"bind_addr" json:"bind_addr"` + BindPort int `ini:"bind_port" json:"bind_port"` + // HTTPProxy specifies a proxy address to connect to the server through. If // this value is "", the server will be connected to directly. By default, // this value is read from the "http_proxy" environment variable. diff --git a/pkg/util/net/conn.go b/pkg/util/net/conn.go index ccb199e5..f3d8caee 100644 --- a/pkg/util/net/conn.go +++ b/pkg/util/net/conn.go @@ -16,18 +16,13 @@ package net import ( "context" - "crypto/tls" "errors" - "fmt" "io" "net" "sync/atomic" "time" "github.com/fatedier/frp/pkg/util/xlog" - - gnet "github.com/fatedier/golib/net" - kcp "github.com/fatedier/kcp-go" ) type ContextGetter interface { @@ -188,56 +183,3 @@ func (statsConn *StatsConn) Close() (err error) { } return } - -func ConnectServer(protocol string, addr string) (c net.Conn, err error) { - switch protocol { - case "tcp": - return net.Dial("tcp", addr) - case "kcp": - kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3) - if errRet != nil { - err = errRet - return - } - kcpConn.SetStreamMode(true) - kcpConn.SetWriteDelay(true) - kcpConn.SetNoDelay(1, 20, 2, 1) - kcpConn.SetWindowSize(128, 512) - kcpConn.SetMtu(1350) - kcpConn.SetACKNoDelay(false) - kcpConn.SetReadBuffer(4194304) - kcpConn.SetWriteBuffer(4194304) - c = kcpConn - return - default: - return nil, fmt.Errorf("unsupport protocol: %s", protocol) - } -} - -func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.Conn, err error) { - switch protocol { - case "tcp": - return gnet.DialTcpByProxy(proxyURL, addr) - case "kcp": - // http proxy is not supported for kcp - return ConnectServer(protocol, addr) - case "websocket": - return ConnectWebsocketServer(addr) - default: - return nil, fmt.Errorf("unsupport protocol: %s", protocol) - } -} - -func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string, tlsConfig *tls.Config, disableCustomTLSHeadByte bool) (c net.Conn, err error) { - c, err = ConnectServerByProxy(proxyURL, protocol, addr) - if err != nil { - return - } - - if tlsConfig == nil { - return - } - - c = WrapTLSClientConn(c, tlsConfig, disableCustomTLSHeadByte) - return -} diff --git a/pkg/util/net/dial.go b/pkg/util/net/dial.go new file mode 100644 index 00000000..eefce37a --- /dev/null +++ b/pkg/util/net/dial.go @@ -0,0 +1,164 @@ +package net + +import ( + "crypto/tls" + "fmt" + "net" + "net/url" + "time" + + gnet "github.com/fatedier/golib/net" + kcp "github.com/fatedier/kcp-go" + "golang.org/x/net/websocket" +) + +type dialOptions struct { + proxyURL string + protocol string + laddr string + addr string + tlsConfig *tls.Config + disableCustomTLSHeadByte bool +} + +// DialOption configures how we set up the connection. +type DialOption interface { + apply(*dialOptions) +} + +type EmptyDialOption struct{} + +func (EmptyDialOption) apply(*dialOptions) {} + +type funcDialOption struct { + f func(*dialOptions) +} + +func (fdo *funcDialOption) apply(do *dialOptions) { + fdo.f(do) +} + +func newFuncDialOption(f func(*dialOptions)) *funcDialOption { + return &funcDialOption{ + f: f, + } +} + +func DefaultDialOptions() dialOptions { + return dialOptions{ + protocol: "tcp", + } +} + +func WithProxy(proxyURL string) DialOption { + return newFuncDialOption(func(do *dialOptions) { + do.proxyURL = proxyURL + }) +} + +func WithBindAddress(laddr string) DialOption { + return newFuncDialOption(func(do *dialOptions) { + do.laddr = laddr + }) +} + +func WithTLSConfig(tlsConfig *tls.Config) DialOption { + return newFuncDialOption(func(do *dialOptions) { + do.tlsConfig = tlsConfig + }) +} + +func WithRemoteAddress(addr string) DialOption { + return newFuncDialOption(func(do *dialOptions) { + do.addr = addr + }) +} + +func WithDisableCustomTLSHeadByte(disableCustomTLSHeadByte bool) DialOption { + return newFuncDialOption(func(do *dialOptions) { + do.disableCustomTLSHeadByte = disableCustomTLSHeadByte + }) +} + +func WithProtocol(protocol string) DialOption { + return newFuncDialOption(func(do *dialOptions) { + do.protocol = protocol + }) +} + +func Dial(opts ...DialOption) (c net.Conn, err error) { + op := DefaultDialOptions() + + for _, opt := range opts { + opt.apply(&op) + } + + c, err = dialServer(op.proxyURL, op.protocol, op.addr) + if err != nil { + return + } + + if op.tlsConfig == nil { + return + } + + c = WrapTLSClientConn(c, op.tlsConfig, op.disableCustomTLSHeadByte) + return +} + +func dialServer(proxyURL string, protocol string, addr string) (c net.Conn, err error) { + var d = &net.Dialer{} + + switch protocol { + case "tcp": + return gnet.DialTcpByProxy(d, proxyURL, addr) + case "kcp": + return DialKCPServer(addr) + case "websocket": + return DialWebsocketServer(d, addr) + default: + return nil, fmt.Errorf("unsupport protocol: %s", protocol) + } +} + +func DialKCPServer(addr string) (net.Conn, error) { + // http proxy is not supported for kcp + kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3) + if errRet != nil { + return nil, errRet + } + kcpConn.SetStreamMode(true) + kcpConn.SetWriteDelay(true) + kcpConn.SetNoDelay(1, 20, 2, 1) + kcpConn.SetWindowSize(128, 512) + kcpConn.SetMtu(1350) + kcpConn.SetACKNoDelay(false) + kcpConn.SetReadBuffer(4194304) + kcpConn.SetWriteBuffer(4194304) + + return kcpConn, nil +} + +// addr: domain:port +func DialWebsocketServer(d *net.Dialer, addr string) (net.Conn, error) { + addr = "ws://" + addr + FrpWebsocketPath + uri, err := url.Parse(addr) + if err != nil { + return nil, err + } + + origin := "http://" + uri.Host + cfg, err := websocket.NewConfig(addr, origin) + if err != nil { + return nil, err + } + + cfg.Dialer = d + cfg.Dialer.Timeout = 10 * time.Second + + conn, err := websocket.DialConfig(cfg) + if err != nil { + return nil, err + } + return conn, nil +} diff --git a/pkg/util/net/websocket.go b/pkg/util/net/websocket.go index 36b6440c..7030787e 100644 --- a/pkg/util/net/websocket.go +++ b/pkg/util/net/websocket.go @@ -5,8 +5,6 @@ import ( "fmt" "net" "net/http" - "net/url" - "time" "golang.org/x/net/websocket" ) @@ -77,27 +75,3 @@ func (p *WebsocketListener) Close() error { func (p *WebsocketListener) Addr() net.Addr { return p.ln.Addr() } - -// addr: domain:port -func ConnectWebsocketServer(addr string) (net.Conn, error) { - addr = "ws://" + addr + FrpWebsocketPath - uri, err := url.Parse(addr) - if err != nil { - return nil, err - } - - origin := "http://" + uri.Host - cfg, err := websocket.NewConfig(addr, origin) - if err != nil { - return nil, err - } - cfg.Dialer = &net.Dialer{ - Timeout: 10 * time.Second, - } - - conn, err := websocket.DialConfig(cfg) - if err != nil { - return nil, err - } - return conn, nil -} diff --git a/test/e2e/pkg/request/request.go b/test/e2e/pkg/request/request.go index 5792d530..a0bc3834 100644 --- a/test/e2e/pkg/request/request.go +++ b/test/e2e/pkg/request/request.go @@ -13,7 +13,8 @@ import ( "time" "github.com/fatedier/frp/test/e2e/pkg/rpc" - libnet "github.com/fatedier/golib/net" + + frpNet "github.com/fatedier/frp/pkg/util/net" ) type Request struct { @@ -141,7 +142,13 @@ func (r *Request) Do() (*Response, error) { if r.protocol != "tcp" { return nil, fmt.Errorf("only tcp protocol is allowed for proxy") } - conn, err = libnet.DialTcpByProxy(r.proxyURL, addr) + + opts := []frpNet.DialOption{ + frpNet.WithProxy(r.proxyURL), + frpNet.WithRemoteAddress(addr), + } + + conn, err = frpNet.Dial(opts...) if err != nil { return nil, err }