package api

import (
	"errors"
	"net"
	"net/http"
	"reflect"
	"time"

	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
	grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
	grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
	"github.com/grpc-ecosystem/go-grpc-middleware/ratelimit"
	grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
	grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
	"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
	"github.com/hashicorp/go-multierror"
	"github.com/jbenet/goprocess"
	goprocessctx "github.com/jbenet/goprocess/context"
	"github.com/soheilhy/cmux"
	"github.com/wiryls/pkg/errors/cerrors"
	"google.golang.org/grpc"
	"google.golang.org/grpc/reflection"

	apigrpc "bdware.org/bdledger/pkg/api/grpc"
	apipb "bdware.org/bdledger/pkg/api/grpc/pb"
	"bdware.org/bdledger/pkg/bdservice"
	"bdware.org/bdledger/pkg/common/log"
)

// ErrUnexpectedInternal is the error the recovery interceptor installed by New
// returns in place of a panic raised while handling a gRPC request.
var ErrUnexpectedInternal = errors.New("unexpected internal error")

// New builds the API service without opening any port. It creates a gRPC
// server with the Node, Ledger and Query services registered and wires the
// service's run/stop hooks; listeners are only created once the service runs.
//
// The unary interceptor chain is assembled from conf, outermost first:
// rate limiting, optional request logging (conf.RequestLog), optional
// access-token authentication (conf.AccessTokens), and panic recovery.
// If conf.GRPCLogVerbosity is set and parses, the gRPC library's process-wide
// logger is replaced so its output goes through the "grpc" sub-logger of the
// "api" logger.
//
// conf is checked with cerrors.TestNilArgumentIfNoErr, which presumably rejects
// a nil conf; any such error is returned with a nil Service.
func New(
	conf *Conf,
	node apipb.NodeServer,
	ledger apipb.LedgerServer,
	query apipb.QueryServer,
) (srv Service, err error) {
	if err = cerrors.TestNilArgumentIfNoErr(err, conf, "conf *api.Conf"); err != nil {
		return nil, err
	}

	apiLog := log.Get("api")
	grpcLog := apiLog.Named("grpc")
	grpcZap := grpcLog.Desugar()

	// Interceptors run in slice order, the first one being the outermost.
	// At most five are ever added.
	chain := make([]grpc.UnaryServerInterceptor, 0, 5)

	// The rate limiter goes first, so requests it rejects never reach the
	// interceptors appended after it (including request logging).
	chain = append(chain, ratelimit.UnaryServerInterceptor(apigrpc.NewLimiter()))

	if conf.RequestLog {
		// ctxtags must precede grpc_zap so the logger can pick up the request
		// fields it extracts into the context.
		chain = append(chain,
			grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)),
			grpc_zap.UnaryServerInterceptor(grpcZap, grpc_zap.WithLevels(grpc_zap.DefaultCodeToLevel)),
		)
	}

	// Route the gRPC library's own log output through grpcZap. An unparsable
	// verbosity is ignored and leaves the library logger untouched.
	if len(conf.GRPCLogVerbosity) != 0 {
		lvl, lvlErr := conf.GRPCLogVerbosity.ToLogLevel()
		if lvlErr == nil {
			grpc_zap.ReplaceGrpcLoggerV2WithVerbosity(grpcZap, int(lvl))
		}
	}

	if len(conf.AccessTokens) != 0 {
		chain = append(chain, grpc_auth.UnaryServerInterceptor(apigrpc.AuthFunc(conf.AccessTokens)))
	}

	// Recovery is appended last, making it the innermost interceptor: a handler
	// panic is turned into ErrUnexpectedInternal before it reaches the outer
	// interceptors (e.g. logging), which then only observe an ordinary error.
	recoverPanic := func(p interface{}) error {
		grpcLog.DPanicf("panic triggered: %v", p)
		return ErrUnexpectedInternal
	}
	chain = append(chain, grpc_recovery.UnaryServerInterceptor(grpc_recovery.WithRecoveryHandler(recoverPanic)))

	server := grpc.NewServer(grpc_middleware.WithUnaryServerChain(chain...))
	apipb.RegisterNodeServer(server, node)
	apipb.RegisterLedgerServer(server, ledger)
	apipb.RegisterQueryServer(server, query)

	svc := &service{
		conf:   conf,
		grpcS:  server,
		node:   node,
		ledger: ledger,
		query:  query,
	}
	svc.BaseService = bdservice.BaseService{
		Log:    apiLog,
		RunFn:  svc.run,
		StopFn: svc.stop,
	}
	return svc, nil
}

// service is the Service implementation built by New. It owns one gRPC server
// and, once running, the listeners for the gRPC and JSON-over-HTTP APIs.
type service struct {
	// BaseService provides the logger (Log), the accumulated errors (Merr) and
	// the RunFn/StopFn hooks that New points at run and stop.
	bdservice.BaseService

	// conf is the API configuration passed to New.
	conf   *Conf
	// grpcS serves the Node, Ledger and Query services registered in New.
	grpcS  *grpc.Server
	// node, ledger and query are the service implementations; run also exposes
	// them over HTTP through grpc-gateway handlers.
	node   apipb.NodeServer
	ledger apipb.LedgerServer
	query  apipb.QueryServer
	// grpcL and httpL are set by run. When both APIs share one port they are
	// cmux sub-listeners of the same TCP listener; otherwise plain listeners.
	grpcL  net.Listener
	httpL  net.Listener
}

// run starts the enabled APIs and blocks until proc begins closing or, in
// shared-port mode, the TCP multiplexer stops on its own.
//
// When both APIs are enabled and the HTTP address is empty or equal to the
// gRPC address, a single TCP listener is opened and split with cmux: HTTP/2
// requests with content-type application/grpc go to the gRPC server and
// HTTP/1 requests go to the JSON gateway. Otherwise each enabled API gets its
// own listener. A listen failure ends run early.
//
// Serving errors are logged and accumulated in s.Merr.
//
// NOTE(review): s.Merr is appended from the serving goroutines as well as from
// run and stop without visible synchronization — confirm whether
// bdservice.BaseService guards it, otherwise this is a data race.
func (s *service) run(proc goprocess.Process) {
	var tcpMux cmux.CMux
	var err error

	// Use the same port for the gRPC and HTTP APIs.
	if s.conf.GRPC.Enabled && s.conf.HTTP.Enabled && (s.conf.HTTP.Addr == "" || s.conf.GRPC.Addr == s.conf.HTTP.Addr) {
		var l net.Listener
		if l, err = s.listen(s.conf.GRPC.Addr); err != nil {
			// Give up like the per-API paths below instead of silently falling
			// back and listening on the same address a second time.
			return
		}
		tcpMux = cmux.New(l)
		// gRPC clients block until they receive the server's SETTINGS frame,
		// hence the SendSettings matcher with a writer.
		s.grpcL = tcpMux.MatchWithWriters(cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"))
		s.httpL = tcpMux.Match(cmux.HTTP1Fast())
	}

	// Enable the gRPC API.
	if s.conf.GRPC.Enabled {
		if s.grpcL == nil {
			if s.grpcL, err = s.listen(s.conf.GRPC.Addr); err != nil {
				return
			}
		}

		if s.conf.GRPC.Reflection {
			reflection.Register(s.grpcS)
			s.Log.Info("gRPC server reflection service enabled")
		}

		go func() {
			s.Log.Info("Starting gRPC API over HTTP/2")
			// A local err keeps this goroutine from writing run's err.
			if err := s.grpcS.Serve(s.grpcL); err != nil && err != cmux.ErrListenerClosed {
				s.Log.DPanic("gRPC API over HTTP/2 failed with fatal error: ", err)
				s.Merr = multierror.Append(s.Merr, err)
			} else {
				s.Log.Info("gRPC API over HTTP/2 stopped")
			}
		}()
	}

	// Enable the JSON API over HTTP.
	if s.conf.HTTP.Enabled {
		if s.httpL == nil {
			if s.httpL, err = s.listen(s.conf.HTTP.Addr); err != nil {
				return
			}
		}

		ctx := goprocessctx.OnClosingContext(proc)
		mux := runtime.NewServeMux()
		err = apipb.RegisterNodeHandlerServer(ctx, mux, s.node)
		if err == nil {
			err = apipb.RegisterLedgerHandlerServer(ctx, mux, s.ledger)
		}
		if err == nil {
			err = apipb.RegisterQueryHandlerServer(ctx, mux, s.query)
		}
		if err != nil {
			// DPanicf, not DPanic: the message carries a format verb.
			s.Log.DPanicf("Failed to register the HTTP handlers for service: %s", err)
		}

		go func() {
			s.Log.Info("Starting JSON API over HTTP")
			// http.Serve always returns a non-nil error.
			if err := http.Serve(s.httpL, mux); err != cmux.ErrListenerClosed {
				s.Log.DPanic("JSON API over HTTP failed: ", err)
				s.Merr = multierror.Append(s.Merr, err)
			} else {
				s.Log.Info("JSON API over HTTP stopped: ", err)
			}
		}()
	}

	// errc only ever delivers when the shared-port multiplexer is in use.
	errc := make(chan error, 1)
	if tcpMux != nil {
		go func() {
			defer close(errc)
			errc <- tcpMux.Serve()
		}()
	}

	select {
	case <-proc.Closing():
	case err = <-errc:
		if err != nil {
			s.Log.DPanic("TCP listener failed: ", err)
			s.Merr = multierror.Append(s.Merr, err)
		} else {
			s.Log.Debug("TCP listener stopped")
		}
	}
}

// listen opens a TCP listener on addr. A failure is logged with Fatalf (for a
// zap-style logger this presumably exits the process — confirm) and the error
// is returned; success is logged at debug level.
func (s *service) listen(addr string) (lis net.Listener, err error) {
	lis, err = net.Listen("tcp", addr)
	if err != nil {
		s.Log.Fatalf("Failed to listen on TCP %s: %s", addr, err)
		return lis, err
	}
	s.Log.Debugf("Started listening on TCP %s", addr)
	return lis, nil
}

// stop shuts the gRPC server down and returns the errors accumulated in
// s.Merr while the service was running (nil if there were none).
//
// The server is stopped gracefully first: no new connections are accepted,
// the listeners passed to Serve (s.grpcL) are closed, and in-flight RPCs may
// finish. If that takes longer than one minute the server is stopped forcibly.
// In shared-port mode closing s.grpcL also closes the underlying TCP listener,
// since cmux sub-listeners inherit Close from it (TODO confirm for the cmux
// version in use), which in turn ends the HTTP side.
func (s *service) stop() error {
	// How long in-flight RPCs get to complete before a forced stop.
	const gracefulStopTimeout = time.Minute

	if s.grpcS == nil {
		// No server owns the listener, so close it directly.
		if s.grpcL != nil {
			if err := s.grpcL.Close(); err != nil {
				s.Log.DPanic("TCP listener failed to close: ", err)
				s.Merr = multierror.Append(s.Merr, err)
			} else {
				s.Log.Debug("TCP listener closed")
			}
		}
		return s.Merr.ErrorOrNil()
	}

	// Stop the server BEFORE touching its listener. GracefulStop first marks the
	// server as quitting and then closes the listeners it serves, so Serve
	// returns nil. Closing s.grpcL first would make Serve return the Accept
	// error instead, which run reports as a fatal gRPC failure on every
	// ordinary shutdown.
	stopped := make(chan struct{})
	go func() {
		s.grpcS.GracefulStop()
		close(stopped)
	}()

	timer := time.NewTimer(gracefulStopTimeout)
	defer timer.Stop()

	select {
	case <-stopped:
		s.Log.Debug("gRPC server stopped")
	case <-timer.C:
		s.Log.Error("gRPC server timed out gracefully stopping, will force close")
		s.grpcS.Stop()
	}

	return s.Merr.ErrorOrNil()
}

// Register attaches an additional service to the gRPC server by calling
// regFn(s.grpcS, handler) through reflection. regFn is expected to be a
// generated registration function such as apipb.RegisterNodeServer, with
// handler implementing the matching server interface. Mismatched arguments
// surface as a panic from reflect.Value.Call.
//
// grpc.Server does not accept registrations after Serve has started, so this
// should be called before the service runs.
func (s *service) Register(regFn interface{}, handler interface{}) {
	server := reflect.ValueOf(s.grpcS)
	impl := reflect.ValueOf(handler)
	register := reflect.ValueOf(regFn)
	register.Call([]reflect.Value{server, impl})
}
