diff --git a/client/control.go b/client/control.go index 5589817f..441ae68a 100644 --- a/client/control.go +++ b/client/control.go @@ -25,6 +25,7 @@ import ( "time" "github.com/fatedier/frp/client/proxy" + "github.com/fatedier/frp/models/auth" "github.com/fatedier/frp/models/config" "github.com/fatedier/frp/models/msg" frpNet "github.com/fatedier/frp/utils/net" @@ -82,13 +83,17 @@ type Control struct { // service context ctx context.Context + + // sets authentication based on selected method + authSetter auth.Setter } func NewControl(ctx context.Context, runId string, conn net.Conn, session *fmux.Session, clientCfg config.ClientCommonConf, pxyCfgs map[string]config.ProxyConf, visitorCfgs map[string]config.VisitorConf, - serverUDPPort int) *Control { + serverUDPPort int, + authSetter auth.Setter) *Control { // new xlog instance ctl := &Control{ @@ -107,6 +112,7 @@ func NewControl(ctx context.Context, runId string, conn net.Conn, session *fmux. serverUDPPort: serverUDPPort, xl: xlog.FromContextSafe(ctx), ctx: ctx, + authSetter: authSetter, } ctl.pm = proxy.NewProxyManager(ctl.ctx, ctl.sendCh, clientCfg, serverUDPPort) @@ -282,7 +288,12 @@ func (ctl *Control) msgHandler() { case <-hbSend.C: // send heartbeat to server xl.Debug("send heartbeat to server") - ctl.sendCh <- &msg.Ping{} + pingMsg := &msg.Ping{} + if err := ctl.authSetter.SetPing(pingMsg); err != nil { + xl.Warn("error during ping authentication: %v", err) + return + } + ctl.sendCh <- pingMsg case <-hbCheck.C: if time.Since(ctl.lastPong) > time.Duration(ctl.clientCfg.HeartBeatTimeout)*time.Second { xl.Warn("heartbeat timeout") diff --git a/client/service.go b/client/service.go index 5ad08855..18416d2e 100644 --- a/client/service.go +++ b/client/service.go @@ -26,11 +26,11 @@ import ( "time" "github.com/fatedier/frp/assets" + "github.com/fatedier/frp/models/auth" "github.com/fatedier/frp/models/config" "github.com/fatedier/frp/models/msg" "github.com/fatedier/frp/utils/log" frpNet "github.com/fatedier/frp/utils/net" - "github.com/fatedier/frp/utils/util" "github.com/fatedier/frp/utils/version" "github.com/fatedier/frp/utils/xlog" @@ -46,6 +46,9 @@ type Service struct { ctl *Control ctlMu sync.RWMutex + // Sets authentication based on selected method + authSetter auth.Setter + cfg config.ClientCommonConf pxyCfgs map[string]config.ProxyConf visitorCfgs map[string]config.VisitorConf @@ -70,6 +73,7 @@ func NewService(cfg config.ClientCommonConf, pxyCfgs map[string]config.ProxyConf ctx, cancel := context.WithCancel(context.Background()) svr = &Service{ + authSetter: auth.NewAuthSetter(cfg), cfg: cfg, cfgFile: cfgFile, pxyCfgs: pxyCfgs, @@ -105,7 +109,7 @@ func (svr *Service) Run() error { } } else { // login success - ctl := NewControl(svr.ctx, svr.runId, conn, session, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort) + ctl := NewControl(svr.ctx, svr.runId, conn, session, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort, svr.authSetter) ctl.Run() svr.ctlMu.Lock() svr.ctl = ctl @@ -159,7 +163,7 @@ func (svr *Service) keepControllerWorking() { // reconnect success, init delayTime delayTime = time.Second - ctl := NewControl(svr.ctx, svr.runId, conn, session, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort) + ctl := NewControl(svr.ctx, svr.runId, conn, session, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort, svr.authSetter) ctl.Run() svr.ctlMu.Lock() svr.ctl = ctl @@ -212,17 +216,20 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) { conn = stream } - now := time.Now().Unix() loginMsg := &msg.Login{ - Arch: runtime.GOARCH, - Os: runtime.GOOS, - PoolCount: svr.cfg.PoolCount, - User: svr.cfg.User, - Version: version.Full(), - PrivilegeKey: util.GetAuthKey(svr.cfg.Token, now), - Timestamp: now, - RunId: svr.runId, - Metas: svr.cfg.Metas, + Arch: runtime.GOARCH, + Os: runtime.GOOS, + PoolCount: svr.cfg.PoolCount, + User: svr.cfg.User, + Version: version.Full(), + Timestamp: time.Now().Unix(), + RunId: svr.runId, + Metas: svr.cfg.Metas, + } + + // Add auth + if err = svr.authSetter.SetLogin(loginMsg); err != nil { + return } if err = msg.WriteMsg(conn, loginMsg); err != nil { diff --git a/go.mod b/go.mod index a71a97c0..c43c3646 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.12 require ( github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 + github.com/coreos/go-oidc v2.2.1+incompatible github.com/fatedier/beego v0.0.0-20171024143340-6c6a4f5bd5eb github.com/fatedier/golib v0.0.0-20181107124048-ff8cd814b049 github.com/fatedier/kcp-go v2.0.4-0.20190803094908-fe8645b0a904+incompatible @@ -17,6 +18,7 @@ require ( github.com/mattn/go-runewidth v0.0.4 // indirect github.com/pires/go-proxyproto v0.0.0-20190111085350-4d51b51e3bfc github.com/pkg/errors v0.8.0 // indirect + github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/rakyll/statik v0.1.1 github.com/rodaine/table v1.0.0 github.com/spf13/cobra v0.0.3 @@ -28,6 +30,8 @@ require ( github.com/vaughan0/go-ini v0.0.0-20130923145212-a98ad7ee00ec github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae // indirect golang.org/x/net v0.0.0-20190724013045-ca1201d0de80 + golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d golang.org/x/text v0.3.2 // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 + gopkg.in/square/go-jose.v2 v2.4.1 // indirect ) diff --git a/models/auth/auth.go b/models/auth/auth.go new file mode 100644 index 00000000..3b809b79 --- /dev/null +++ b/models/auth/auth.go @@ -0,0 +1,65 @@ +// Copyright 2020 guylewin, guy@lewin.co.il +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "github.com/fatedier/frp/models/config" + "github.com/fatedier/frp/models/consts" + "github.com/fatedier/frp/models/msg" +) + +type Setter interface { + SetLogin(*msg.Login) error + SetPing(*msg.Ping) error +} + +func NewAuthSetter(cfg config.ClientCommonConf) (authProvider Setter) { + switch cfg.AuthenticationMethod { + case consts.TokenAuthMethod: + authProvider = NewTokenAuth(cfg.Token) + case consts.OidcAuthMethod: + authProvider = NewOidcAuthSetter( + cfg.OidcClientId, + cfg.OidcClientSecret, + cfg.OidcAudience, + cfg.OidcTokenEndpointUrl, + cfg.AuthenticateHeartBeats, + ) + } + + return +} + +type Verifier interface { + VerifyLogin(*msg.Login) error + VerifyPing(*msg.Ping) error +} + +func NewAuthVerifier(cfg config.ServerCommonConf) (authVerifier Verifier) { + switch cfg.AuthenticationMethod { + case consts.TokenAuthMethod: + authVerifier = NewTokenAuth(cfg.Token) + case consts.OidcAuthMethod: + authVerifier = NewOidcAuthVerifier( + cfg.OidcIssuer, + cfg.OidcAudience, + cfg.OidcSkipExpiryCheck, + cfg.OidcSkipIssuerCheck, + cfg.AuthenticateHeartBeats, + ) + } + + return +} diff --git a/models/auth/oidc.go b/models/auth/oidc.go new file mode 100644 index 00000000..945d7080 --- /dev/null +++ b/models/auth/oidc.go @@ -0,0 +1,118 @@ +// Copyright 2020 guylewin, guy@lewin.co.il +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "fmt" + + "github.com/fatedier/frp/models/msg" + + "github.com/coreos/go-oidc" + "golang.org/x/oauth2/clientcredentials" +) + +type OidcAuthProvider struct { + tokenGenerator *clientcredentials.Config + authenticateHeartBeats bool +} + +func NewOidcAuthSetter(clientId string, clientSecret string, audience string, tokenEndpointUrl string, authenticateHeartBeats bool) *OidcAuthProvider { + tokenGenerator := &clientcredentials.Config{ + ClientID: clientId, + ClientSecret: clientSecret, + Scopes: []string{audience}, + TokenURL: tokenEndpointUrl, + } + + return &OidcAuthProvider{ + tokenGenerator: tokenGenerator, + authenticateHeartBeats: authenticateHeartBeats, + } +} + +func (auth *OidcAuthProvider) SetLogin(loginMsg *msg.Login) (err error) { + tokenObj, err := auth.tokenGenerator.Token(context.Background()) + if tokenObj == nil { + return fmt.Errorf("couldn't generate OIDC token for login: %s", err) + } + loginMsg.PrivilegeKey = tokenObj.AccessToken + return +} + +func (auth *OidcAuthProvider) SetPing(pingMsg *msg.Ping) (err error) { + if !auth.authenticateHeartBeats { + // if heartbeat authentication is disabled - don't set + return nil + } + + tokenObj, err := auth.tokenGenerator.Token(context.Background()) + if tokenObj == nil { + return fmt.Errorf("couldn't generate OIDC token for ping: %s", err) + } + pingMsg.PrivilegeKey = tokenObj.AccessToken + return +} + +type OidcAuthConsumer struct { + verifier *oidc.IDTokenVerifier + authenticateHeartBeats bool + subjectFromLogin string +} + +func NewOidcAuthVerifier(issuer string, audience string, skipExpiryCheck bool, skipIssuerCheck bool, authenticateHeartBeats bool) *OidcAuthConsumer { + provider, err := oidc.NewProvider(context.Background(), issuer) + if err != nil { + panic(err) + } + verifierConf := oidc.Config{ + ClientID: audience, + SkipClientIDCheck: audience == "", + SkipExpiryCheck: skipExpiryCheck, + SkipIssuerCheck: skipIssuerCheck, + } + return &OidcAuthConsumer{ + verifier: provider.Verifier(&verifierConf), + authenticateHeartBeats: authenticateHeartBeats, + } +} + +func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) { + token, err := auth.verifier.Verify(context.Background(), loginMsg.PrivilegeKey) + if token != nil { + auth.subjectFromLogin = token.Subject + return + } + return fmt.Errorf("invalid OIDC token in login: %v", err) +} + +func (auth *OidcAuthConsumer) VerifyPing(pingMsg *msg.Ping) (err error) { + if !auth.authenticateHeartBeats { + // if heartbeat authentication is disabled - don't verify + return nil + } + + token, err := auth.verifier.Verify(context.Background(), pingMsg.PrivilegeKey) + if token == nil { + return fmt.Errorf("invalid OIDC token in ping: %v", err) + } + if token.Subject != auth.subjectFromLogin { + return fmt.Errorf("received different OIDC subject in login and ping. "+ + "original subject: %s, "+ + "new subject: %s", + auth.subjectFromLogin, token.Subject) + } + return +} diff --git a/models/auth/token.go b/models/auth/token.go new file mode 100644 index 00000000..108cf114 --- /dev/null +++ b/models/auth/token.go @@ -0,0 +1,58 @@ +// Copyright 2020 guylewin, guy@lewin.co.il +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "fmt" + + "github.com/fatedier/frp/models/msg" + "github.com/fatedier/frp/utils/util" +) + +type TokenAuthProviderConsumer struct { + Token string +} + +func NewTokenAuth(token string) *TokenAuthProviderConsumer { + return &TokenAuthProviderConsumer{ + Token: token, + } +} + +func (auth *TokenAuthProviderConsumer) SetLogin(loginMsg *msg.Login) (err error) { + loginMsg.PrivilegeKey = util.GetAuthKey(auth.Token, loginMsg.Timestamp) + return nil +} + +func (auth *TokenAuthProviderConsumer) SetPing(*msg.Ping) error { + // ping doesn't include authentication in token method + return nil +} + +type TokenAuthConsumer struct { + Token string +} + +func (auth *TokenAuthProviderConsumer) VerifyLogin(loginMsg *msg.Login) error { + if util.GetAuthKey(auth.Token, loginMsg.Timestamp) != loginMsg.PrivilegeKey { + return fmt.Errorf("token in login doesn't match token from configuration") + } + return nil +} + +func (auth *TokenAuthProviderConsumer) VerifyPing(*msg.Ping) error { + // ping doesn't include authentication in token method + return nil +} diff --git a/models/config/client_common.go b/models/config/client_common.go index 2b5006b4..58e2bd66 100644 --- a/models/config/client_common.go +++ b/models/config/client_common.go @@ -60,6 +60,31 @@ type ClientCommonConf struct { // to the server. The server must have a matching token for authorization // to succeed. By default, this value is "". Token string `json:"token"` + // AuthenticationMethod specifies what authentication method to use to + // authenticate frpc with frps. If "token" is specified - token will be + // read into login message. If "oidc" is specified - OIDC (Open ID Connect) + // token will be issued using OIDC settings. By default, this value is "token". + AuthenticationMethod string `json:"authentication_method"` + // AuthenticateHeartBeats specifies whether to include authentication token in + // heartbeats sent to frps. By default, this value is false. + AuthenticateHeartBeats bool `json:"authenticate_heartbeats"` + + // OidcClientId specifies the client ID to use to get a token in OIDC + // authentication if AuthenticationMethod == "oidc". By default, this value + // is "". + OidcClientId string `json:"oidc_client_id"` + // OidcClientSecret specifies the client secret to use to get a token in OIDC + // authentication if AuthenticationMethod == "oidc". By default, this value + // is "". + OidcClientSecret string `json:"oidc_client_secret"` + // OidcAudience specifies the audience of the token in OIDC authentication + //if AuthenticationMethod == "oidc". By default, this value is "". + OidcAudience string `json:"oidc_audience"` + // OidcTokenEndpointUrl specifies the URL which implements OIDC Token Endpoint. + // It will be used to get an OIDC token if AuthenticationMethod == "oidc". + // By default, this value is "". + OidcTokenEndpointUrl string `json:"oidc_token_endpoint_url"` + // AdminAddr specifies the address that the admin server binds to. By // default, this value is "127.0.0.1". AdminAddr string `json:"admin_addr"` @@ -122,31 +147,37 @@ type ClientCommonConf struct { // GetDefaultClientConf returns a client configuration with default values. func GetDefaultClientConf() ClientCommonConf { return ClientCommonConf{ - ServerAddr: "0.0.0.0", - ServerPort: 7000, - HttpProxy: os.Getenv("http_proxy"), - LogFile: "console", - LogWay: "console", - LogLevel: "info", - LogMaxDays: 3, - DisableLogColor: false, - Token: "", - AdminAddr: "127.0.0.1", - AdminPort: 0, - AdminUser: "", - AdminPwd: "", - AssetsDir: "", - PoolCount: 1, - TcpMux: true, - User: "", - DnsServer: "", - LoginFailExit: true, - Start: make(map[string]struct{}), - Protocol: "tcp", - TLSEnable: false, - HeartBeatInterval: 30, - HeartBeatTimeout: 90, - Metas: make(map[string]string), + ServerAddr: "0.0.0.0", + ServerPort: 7000, + HttpProxy: os.Getenv("http_proxy"), + LogFile: "console", + LogWay: "console", + LogLevel: "info", + LogMaxDays: 3, + DisableLogColor: false, + Token: "", + AuthenticationMethod: "token", + AuthenticateHeartBeats: false, + OidcClientId: "", + OidcClientSecret: "", + OidcAudience: "", + OidcTokenEndpointUrl: "", + AdminAddr: "127.0.0.1", + AdminPort: 0, + AdminUser: "", + AdminPwd: "", + AssetsDir: "", + PoolCount: 1, + TcpMux: true, + User: "", + DnsServer: "", + LoginFailExit: true, + Start: make(map[string]struct{}), + Protocol: "tcp", + TLSEnable: false, + HeartBeatInterval: 30, + HeartBeatTimeout: 90, + Metas: make(map[string]string), } } @@ -207,6 +238,32 @@ func UnmarshalClientConfFromIni(content string) (cfg ClientCommonConf, err error cfg.Token = tmpStr } + if tmpStr, ok = conf.Get("common", "authentication_method"); ok { + cfg.AuthenticationMethod = tmpStr + } + + if tmpStr, ok = conf.Get("common", "authenticate_heartbeats"); ok && tmpStr == "true" { + cfg.AuthenticateHeartBeats = true + } else { + cfg.AuthenticateHeartBeats = false + } + + if tmpStr, ok = conf.Get("common", "oidc_client_id"); ok { + cfg.OidcClientId = tmpStr + } + + if tmpStr, ok = conf.Get("common", "oidc_client_secret"); ok { + cfg.OidcClientSecret = tmpStr + } + + if tmpStr, ok = conf.Get("common", "oidc_audience"); ok { + cfg.OidcAudience = tmpStr + } + + if tmpStr, ok = conf.Get("common", "oidc_token_endpoint_url"); ok { + cfg.OidcTokenEndpointUrl = tmpStr + } + if tmpStr, ok = conf.Get("common", "admin_addr"); ok { cfg.AdminAddr = tmpStr } diff --git a/models/config/server_common.go b/models/config/server_common.go index 20e92f9c..794337ae 100644 --- a/models/config/server_common.go +++ b/models/config/server_common.go @@ -105,6 +105,34 @@ type ServerCommonConf struct { // received from clients. Clients must have a matching token to be // authorized to use the server. By default, this value is "". Token string `json:"token"` + // AuthenticationMethod specifies what authentication method to use to + // authenticate frpc with frps. If "token" is specified - token comparison + // will be used. If "oidc" is specified - OIDC (Open ID Connect) will be + // used. By default, this value is "token". + AuthenticationMethod string `json:"authentication_method"` + // AuthenticateHeartBeats specifies whether to expect and verify authentication + // token in heartbeats sent from frpc. By default, this value is false. + AuthenticateHeartBeats bool `json:"authenticate_heartbeats"` + + // OidcIssuer specifies the issuer to verify OIDC tokens with. This issuer + // will be used to load public keys to verify signature and will be compared + // with the issuer claim in the OIDC token. It will be used if + // AuthenticationMethod == "oidc". By default, this value is "". + OidcIssuer string `json:"oidc_issuer"` + // OidcAudience specifies the audience OIDC tokens should contain when validated. + // If this value is empty, audience ("client ID") verification will be skipped. + // It will be used when AuthenticationMethod == "oidc". By default, this + // value is "". + OidcAudience string `json:"oidc_audience"` + // OidcSkipExpiryCheck specifies whether to skip checking if the OIDC token is + // expired. It will be used when AuthenticationMethod == "oidc". By default, this + // value is false. + OidcSkipExpiryCheck bool `json:"oidc_skip_expiry_check"` + // OidcSkipIssuerCheck specifies whether to skip checking if the OIDC token's + // issuer claim matches the issuer specified in OidcIssuer. It will be used when + // AuthenticationMethod == "oidc". By default, this value is false. + OidcSkipIssuerCheck bool `json:"oidc_skip_issuer_check"` + // SubDomainHost specifies the domain that will be attached to sub-domains // requested by the client when using Vhost proxying. For example, if this // value is set to "frps.com" and the client requested the subdomain @@ -169,6 +197,12 @@ func GetDefaultServerConf() ServerCommonConf { DisableLogColor: false, DetailedErrorsToClient: true, Token: "", + AuthenticationMethod: "token", + AuthenticateHeartBeats: false, + OidcIssuer: "", + OidcAudience: "", + OidcSkipExpiryCheck: false, + OidcSkipIssuerCheck: false, SubDomainHost: "", TcpMux: true, AllowPorts: make(map[int]struct{}), @@ -330,6 +364,36 @@ func UnmarshalServerConfFromIni(content string) (cfg ServerCommonConf, err error cfg.Token, _ = conf.Get("common", "token") + if tmpStr, ok = conf.Get("common", "authentication_method"); ok { + cfg.AuthenticationMethod = tmpStr + } + + if tmpStr, ok = conf.Get("common", "authenticate_heartbeats"); ok && tmpStr == "true" { + cfg.AuthenticateHeartBeats = true + } else { + cfg.AuthenticateHeartBeats = false + } + + if tmpStr, ok = conf.Get("common", "oidc_issuer"); ok { + cfg.OidcIssuer = tmpStr + } + + if tmpStr, ok = conf.Get("common", "oidc_audience"); ok { + cfg.OidcAudience = tmpStr + } + + if tmpStr, ok = conf.Get("common", "oidc_skip_expiry_check"); ok && tmpStr == "true" { + cfg.OidcSkipExpiryCheck = true + } else { + cfg.OidcSkipExpiryCheck = false + } + + if tmpStr, ok = conf.Get("common", "oidc_skip_issuer_check"); ok && tmpStr == "true" { + cfg.OidcSkipIssuerCheck = true + } else { + cfg.OidcSkipIssuerCheck = false + } + if allowPortsStr, ok := conf.Get("common", "allow_ports"); ok { // e.g. 1000-2000,2001,2002,3000-4000 ports, errRet := util.ParseRangeNumbers(allowPortsStr) diff --git a/models/consts/consts.go b/models/consts/consts.go index 9bf5880b..f3c480fe 100644 --- a/models/consts/consts.go +++ b/models/consts/consts.go @@ -29,4 +29,8 @@ var ( HttpsProxy string = "https" StcpProxy string = "stcp" XtcpProxy string = "xtcp" + + // authentication method + TokenAuthMethod string = "token" + OidcAuthMethod string = "oidc" ) diff --git a/models/msg/msg.go b/models/msg/msg.go index ce41c9ec..cba710fd 100644 --- a/models/msg/msg.go +++ b/models/msg/msg.go @@ -148,6 +148,7 @@ type NewVisitorConnResp struct { } type Ping struct { + PrivilegeKey string `json:"privilege_key"` } type Pong struct { diff --git a/server/control.go b/server/control.go index 4b9227ab..5a8d6b61 100644 --- a/server/control.go +++ b/server/control.go @@ -23,6 +23,7 @@ import ( "sync" "time" + "github.com/fatedier/frp/models/auth" "github.com/fatedier/frp/models/config" "github.com/fatedier/frp/models/consts" frpErr "github.com/fatedier/frp/models/errors" @@ -94,6 +95,9 @@ type Control struct { // stats collector to store stats info of clients and proxies statsCollector stats.Collector + // verifies authentication based on selected method + authVerifier auth.Verifier + // login message loginMsg *msg.Login @@ -149,6 +153,7 @@ func NewControl( pxyManager *proxy.ProxyManager, pluginManager *plugin.Manager, statsCollector stats.Collector, + authVerifier auth.Verifier, ctlConn net.Conn, loginMsg *msg.Login, serverCfg config.ServerCommonConf, @@ -163,6 +168,7 @@ func NewControl( pxyManager: pxyManager, pluginManager: pluginManager, statsCollector: statsCollector, + authVerifier: authVerifier, conn: ctlConn, loginMsg: loginMsg, sendCh: make(chan msg.Message, 10), @@ -454,6 +460,9 @@ func (ctl *Control) manager() { ctl.CloseProxy(m) xl.Info("close proxy [%s] success", m.ProxyName) case *msg.Ping: + if err := ctl.authVerifier.VerifyPing(m); err != nil { + xl.Warn("received invalid ping: %v", err) + } ctl.lastPing = time.Now() xl.Debug("receive heartbeat") ctl.sendCh <- &msg.Pong{} diff --git a/server/service.go b/server/service.go index 1ad7e281..7ad5ac70 100644 --- a/server/service.go +++ b/server/service.go @@ -30,6 +30,7 @@ import ( "time" "github.com/fatedier/frp/assets" + "github.com/fatedier/frp/models/auth" "github.com/fatedier/frp/models/config" "github.com/fatedier/frp/models/msg" "github.com/fatedier/frp/models/nathole" @@ -86,6 +87,9 @@ type Service struct { // All resource managers and controllers rc *controller.ResourceController + // Verifies authentication based on selected method + authVerifier auth.Verifier + // stats collector to store server and proxies stats info statsCollector stats.Collector @@ -105,6 +109,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { UdpPortManager: ports.NewPortManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts), }, httpVhostRouter: vhost.NewVhostRouters(), + authVerifier: auth.NewAuthVerifier(cfg), tlsConfig: generateTLSConfig(), cfg: cfg, } @@ -399,12 +404,11 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err } // Check auth. - if util.GetAuthKey(svr.cfg.Token, loginMsg.Timestamp) != loginMsg.PrivilegeKey { - err = fmt.Errorf("authorization failed") + if err = svr.authVerifier.VerifyLogin(loginMsg); err != nil { return } - ctl := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, svr.statsCollector, ctlConn, loginMsg, svr.cfg) + ctl := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, svr.statsCollector, svr.authVerifier, ctlConn, loginMsg, svr.cfg) if oldCtl := svr.ctlManager.Add(loginMsg.RunId, ctl); oldCtl != nil { oldCtl.allShutdown.WaitDone()