diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index 9a943c4..720466a 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "net/http" + "net/netip" "os" "path/filepath" "strings" @@ -21,7 +22,6 @@ import ( ) func ServeWeb() http.Handler { - gin.SetMode(gin.ReleaseMode) r := gin.Default() @@ -34,12 +34,39 @@ func ServeWeb() http.Handler { r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler)) } + r.Use(realIp) r.Use(recordPath) routers(r) return r } +func realIp(c *gin.Context) { + if singleton.Conf.RealIPHeader == "" { + c.Next() + return + } + + if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP { + c.Set(model.CtxKeyRealIPStr, c.RemoteIP()) + c.Next() + return + } + + vals := c.Request.Header.Get(singleton.Conf.RealIPHeader) + if vals == "" { + c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"}) + return + } + ip, err := netip.ParseAddr(vals) + if err != nil { + c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()}) + return + } + c.Set(model.CtxKeyRealIPStr, ip.String()) + c.Next() +} + func routers(r *gin.Engine) { authMiddleware, err := jwt.New(initParams()) if err != nil { @@ -127,6 +154,7 @@ func routers(r *gin.Engine) { } func recordPath(c *gin.Context) { + log.Printf("bingo web real ip: %s", c.GetString(model.CtxKeyRealIPStr)) url := c.Request.URL.String() for _, p := range c.Params { url = strings.Replace(url, p.Value, ":"+p.Key, 1) diff --git a/cmd/dashboard/rpc/rpc.go b/cmd/dashboard/rpc/rpc.go index 8aecbfc..ce5f059 100644 --- a/cmd/dashboard/rpc/rpc.go +++ b/cmd/dashboard/rpc/rpc.go @@ -1,11 +1,15 @@ package rpc import ( + "context" "fmt" "net/http" + "net/netip" "time" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" "github.com/hashicorp/go-uuid" "github.com/naiba/nezha/model" @@ -16,12 +20,42 @@ import ( ) func ServeRPC() *grpc.Server { - server := grpc.NewServer() + server := grpc.NewServer(grpc.UnaryInterceptor(getRealIp)) rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler() proto.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton) return server } +func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if singleton.Conf.RealIPHeader == "" { + return handler(ctx, req) + } + + if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP { + p, ok := peer.FromContext(ctx) + if !ok { + return nil, fmt.Errorf("peer not found") + } + addrPort, err := netip.ParseAddrPort(p.Addr.String()) + if err != nil { + return nil, err + } + ctx = context.WithValue(ctx, model.CtxKeyRealIP{}, addrPort.Addr().String()) + return handler(ctx, req) + } + + vals := metadata.ValueFromIncomingContext(ctx, singleton.Conf.RealIPHeader) + if len(vals) == 0 { + return nil, fmt.Errorf("real ip header not found") + } + ip, err := netip.ParseAddr(vals[0]) + if err != nil { + return nil, err + } + ctx = context.WithValue(ctx, model.CtxKeyRealIP{}, ip.String()) + return handler(ctx, req) +} + func DispatchTask(serviceSentinelDispatchBus <-chan model.Service) { workedServerIndex := 0 for task := range serviceSentinelDispatchBus { diff --git a/model/common.go b/model/common.go index 0d0cbf4..10394e9 100644 --- a/model/common.go +++ b/model/common.go @@ -4,7 +4,12 @@ import ( "time" ) -const CtxKeyAuthorizedUser = "ckau" +const ( + CtxKeyAuthorizedUser = "ckau" + CtxKeyRealIPStr = "ckri" +) + +type CtxKeyRealIP struct{} type Common struct { ID uint64 `gorm:"primaryKey" json:"id,omitempty"` diff --git a/model/config.go b/model/config.go index 9887e4a..9557cbb 100644 --- a/model/config.go +++ b/model/config.go @@ -11,12 +11,14 @@ import ( ) const ( - ConfigCoverAll = iota + ConfigUsePeerIP = "NZ::Use-Peer-IP" + ConfigCoverAll = iota ConfigCoverIgnoreAll ) type Config struct { - Debug bool `mapstructure:"debug" json:"debug,omitempty"` // debug模式开关 + Debug bool `mapstructure:"debug" json:"debug,omitempty"` // debug模式开关 + RealIPHeader string `mapstructure:"real_ip_header" json:"real_ip_header,omitempty"` // 真实IP Language string `mapstructure:"language" json:"language"` // 系统语言,默认 zh_CN SiteName string `mapstructure:"site_name" json:"site_name"` diff --git a/service/rpc/auth.go b/service/rpc/auth.go index d1762f1..b965958 100644 --- a/service/rpc/auth.go +++ b/service/rpc/auth.go @@ -2,6 +2,7 @@ package rpc import ( "context" + "log" "sync" "google.golang.org/grpc/codes" @@ -24,6 +25,9 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) { return 0, status.Errorf(codes.Unauthenticated, "获取 metaData 失败") } + realIp := ctx.Value(model.CtxKeyRealIP{}) + log.Printf("bingo rpc realIp: %s, metadata: %v", realIp, md) + var clientSecret string if value, ok := md["client_secret"]; ok { clientSecret = value[0]