From ddf138435c5ddc569ff184b5581ae5b8e0523621 Mon Sep 17 00:00:00 2001 From: hu198021688500 Date: Wed, 15 May 2024 16:07:11 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A7=A3=E5=86=B3=E5=8D=8E=E4=B8=BA=E6=B5=8F?= =?UTF-8?q?=E8=A7=88=E5=99=A8=E6=97=A0=E6=B3=95=E4=B8=8B=E8=BD=BD=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/util/net/http.go | 72 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 67 insertions(+), 5 deletions(-) diff --git a/pkg/util/net/http.go b/pkg/util/net/http.go index e9fc5260..72d4a534 100644 --- a/pkg/util/net/http.go +++ b/pkg/util/net/http.go @@ -16,25 +16,36 @@ package net import ( "compress/gzip" + "github.com/google/uuid" "io" "net/http" "strings" "time" "github.com/fatedier/frp/pkg/util/util" + + "github.com/fatedier/frp/pkg/util/log" ) type HTTPAuthMiddleware struct { user string passwd string authFailDelay time.Duration + + expires time.Duration + sessions map[string]time.Time } func NewHTTPAuthMiddleware(user, passwd string) *HTTPAuthMiddleware { - return &HTTPAuthMiddleware{ + middleware := &HTTPAuthMiddleware{ user: user, passwd: passwd, + + expires: 10 * time.Minute, + sessions: make(map[string]time.Time), } + middleware.cleanSession() + return middleware } func (authMid *HTTPAuthMiddleware) SetAuthFailDelay(delay time.Duration) *HTTPAuthMiddleware { @@ -42,12 +53,63 @@ func (authMid *HTTPAuthMiddleware) SetAuthFailDelay(delay time.Duration) *HTTPAu return authMid } +func (authMid *HTTPAuthMiddleware) signIn(w http.ResponseWriter, r *http.Request) bool { + reqUser, reqPasswd, hasAuth := r.BasicAuth() + if (authMid.user == "" && authMid.passwd == "") || + (hasAuth && util.ConstantTimeEqString(reqUser, authMid.user) && + util.ConstantTimeEqString(reqPasswd, authMid.passwd)) { + sessionToken := uuid.NewString() + expiresAt := time.Now().Add(authMid.expires) + + authMid.sessions[sessionToken] = expiresAt + http.SetCookie(w, &http.Cookie{ + Name: "session_token", + Value: sessionToken, + Expires: expiresAt, + }) + log.Debugf("signIn success and set cookie %s", sessionToken) + return true + } else { + log.Debugf("signIn fail") + return false + } +} + +func (authMid *HTTPAuthMiddleware) auth(r *http.Request) bool { + c, err := r.Cookie("session_token") + if err != nil { + log.Debugf("get cookie error: %v", err) + return false + } + _, exists := authMid.sessions[c.Value] + if exists { + log.Debugf("exist session %s and refresh it", c.Value) + authMid.sessions[c.Value] = time.Now().Add(authMid.expires) + } + return exists +} + +func (authMid *HTTPAuthMiddleware) cleanSession() { + ticker := time.NewTicker(authMid.expires) + go func() { + for { + <-ticker.C + log.Debugf("start clean session...") + for k, v := range authMid.sessions { + if v.Before(time.Now()) { + log.Debugf("delete session %s", k) + delete(authMid.sessions, k) + } + } + } + }() +} + func (authMid *HTTPAuthMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - reqUser, reqPasswd, hasAuth := r.BasicAuth() - if (authMid.user == "" && authMid.passwd == "") || - (hasAuth && util.ConstantTimeEqString(reqUser, authMid.user) && - util.ConstantTimeEqString(reqPasswd, authMid.passwd)) { + if authMid.auth(r) { + next.ServeHTTP(w, r) + } else if authMid.signIn(w, r) { next.ServeHTTP(w, r) } else { if authMid.authFailDelay > 0 {