package jupyter

import (
	"bytes"
	"context"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path"
	"runtime/debug"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"github.com/mattn/go-colorable"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)

const (
	// JavascriptISOString is a Go time layout intended for JavaScript-style ISO-8601 timestamps.
	JavascriptISOString = "2006-01-02T15:04:05.999Z07:00"

	// kernelServiceApi is the Jupyter Server REST path prefix for the kernels service.
	kernelServiceApi = "api/kernels"
)

// Values of KernelConnectionStatus, describing the state of a BasicKernelConnection.
const (
	KernelConnectionInit KernelConnectionStatus = "initializing" // When the BasicKernelConnection struct is first created.
	KernelConnecting     KernelConnectionStatus = "connecting"   // When we are creating the kernel websocket.
	KernelConnected      KernelConnectionStatus = "connected"    // Once we've connected.
	KernelDisconnected   KernelConnectionStatus = "disconnected" // We're not connected to the kernel, but we're unsure if it is dead or not.
	KernelDead           KernelConnectionStatus = "dead"         // Kernel is dead. We're not connected.
)

// Message types exchanged with the kernel.
const (
	ExecuteRequest          MessageType = "execute_request"
	ExecuteReply            MessageType = "execute_reply"
	KernelInfoRequest       MessageType = "kernel_info_request"
	StopRunningTrainingCode MessageType = "stop_running_training_code_request"
	DummyMessage            MessageType = "dummy_message_request"
	AckMessage              MessageType = "ACK"
	CommCloseMessage        MessageType = "comm_close"
)

// Reserved metadata keys. Neither can be overwritten in the kernel's metadata dictionary via AddMetadata.
const (
	// KernelIdMetadataKey is the reserved metadata key for the kernel's ID.
	KernelIdMetadataKey = "kernel_id"

	// SendTimestampMetadataKey is the reserved metadata key for the send timestamp (Unix milliseconds, per the key name).
	SendTimestampMetadataKey = "send_timestamp_unix_milli"
)

// Sentinel errors reported by BasicKernelConnection. Callers should compare against them with errors.Is,
// since they are frequently wrapped with additional context.
var (
	ErrWebsocketAlreadySetup   = errors.New("the kernel connection's websocket has already been setup")
	ErrWebsocketCreationFailed = errors.New("creation of websocket connection to kernel has failed")
	ErrKernelNotFound          = errors.New("received HTTP 404 status when requesting info for kernel")
	ErrNetworkIssue            = errors.New("received HTTP 503 or HTTP 424 in response to request")
	ErrUnexpectedFailure       = errors.New("the request could not be completed for some unexpected reason")
	ErrKernelIsDead            = errors.New("kernel is dead")
	ErrNotConnected            = errors.New("kernel is not connected")
	ErrCantAckNotRegistered    = errors.New("cannot ACK message as registration for associated channel has not yet completed")
	ErrInitialConnectCompleted = errors.New("the initial connection attempt has already completed successfully for the target kernel")
	ErrSetupInProgress         = errors.New("cannot perform websocket connection setup as another setup procedure is already underway")
	ErrReservedMetadataKey     = errors.New("cannot add metadata using specified key as key is reserved")
)

// errReconnectionInProgress is an unexported sentinel reporting that a reconnection attempt is already underway.
var errReconnectionInProgress = errors.New("cannot reconnect to kernel as another reconnection attempt is already underway")

// MessageType is the type of a Jupyter message, e.g. "execute_request".
type MessageType string

// String returns the message type as a plain string.
func (t MessageType) String() string {
	return string(t)
}

// getBaseMessageType strips a trailing "request" or "reply" from the message type. This lets a request
// and its reply map to the same value. The separating underscore is kept: both "execute_request" and
// "execute_reply" yield "execute_".
//
// It panics if the message type ends with neither "request" nor "reply" (e.g. "ACK").
func (t MessageType) getBaseMessageType() string {
	raw := t.String()

	switch {
	case strings.HasSuffix(raw, "request"):
		return strings.TrimSuffix(raw, "request")
	case strings.HasSuffix(raw, "reply"):
		return strings.TrimSuffix(raw, "reply")
	default:
		panic(fmt.Sprintf("Invalid message type: \"%s\"", t))
	}
}

// KernelConnectionStatus describes the state of the connection to a remote kernel.
type KernelConnectionStatus string

// String returns the status as a plain string.
func (status KernelConnectionStatus) String() string {
	asString := string(status)
	return asString
}

// BasicKernelConnection is a connection to a single Jupyter kernel, made over a websocket to the Jupyter Server.
//
// It tracks the connection status and routes responses back to the requests that caused them (responseChannels).
// It dispatches IOPub messages to registered handlers and carries user-supplied metadata that can be
// embedded in outgoing kernel messages.
type BasicKernelConnection struct {
	// logger writes to stdout at the level controlled by atom; sugaredLogger is logger.Sugar().
	logger        *zap.Logger
	sugaredLogger *zap.SugaredLogger
	atom          *zap.AtomicLevel

	// TODO: The response delivery mechanism is wrong. For one, control/shell messages should receive messages on that channel.
	// But if we receive an IOPub message with a parent as a shell message, then the IOPub message is considered to be the response.
	// So, we need to look at parent header request ID, channel type, AND we also need to check the message type.
	// The request is in the form <action>_request and the response is <action>_reply.

	// Register callbacks for responses to particular messages.
	//
	// For now, we only support responses for SHELL and CONTROL messages.
	//
	// Keys for this map are generated by the 'getResponseChannelKeyX' functions defined in "internal/server/jupyter/utils.go".
	// See the documentation of those functions for additional details.
	responseChannels map[string]chan KernelMessage

	// IOPub message handlers, keyed by the ID they were registered under. Guarded by iopubHandlerMutex.
	iopubMessageHandlers map[string]IOPubMessageHandler

	// Connection state.
	// NOTE(review): in the visible code several of these (e.g. connectionStatus in Connected/ConnectionStatus,
	// registeredShell/registeredControl in sendAck) are read without holding mu. Confirm whether they can be
	// accessed concurrently.
	messageCount                  int                     // How many messages we've sent. Used when creating message IDs.
	connectionStatus              KernelConnectionStatus  // Connection status with the remote kernel.
	kernelId                      string                  // ID of the associated kernel
	jupyterServerAddress          string                  // Jupyter server IP address
	clientId                      string                  // Jupyter client ID
	username                      string                  // Jupyter username
	webSocket                     *websocket.Conn         // The websocket that is connected to Jupyter
	originalWebsocketCloseHandler func(int, string) error // The original close handler method of the websocket; we replace this with our own, and we call the original from ours.
	model                         *jupyterKernel          // Jupyter kernel model.
	kernelStdout                  []string                // STDOUT history from the kernel, as extracted from IOPub messages.
	kernelStderr                  []string                // STDERR history from the kernel, as extracted from IOPub messages.
	registeredShell               bool                    // True if we've successfully registered our shell channel as a Golang frontend.
	registeredControl             bool                    // True if we've successfully registered our control channel as a Golang frontend.
	waitingForExecuteResponses    atomic.Int32            // waitingForExecuteResponses is the number of active "execute_request" requests that we have active.

	// Presumably flags used to detect concurrent websocket setup / reconnection attempts
	// (see ErrSetupInProgress and errReconnectionInProgress) — TODO confirm against their users.
	setupInProgress        atomic.Int32
	reconnectionInProgress atomic.Int32

	// metadata is a map containing basic metadata used for labeling kernelMetricsManager.
	metadata map[string]interface{}
	// serializedMetadata contains the same mapping as the metadata mapping, except that the values are serialized
	// to JSON before being added to serializedMetadata. This is so we don't have to perform the serialization step
	// every time we want to embed the metadata in a kernel message.
	serializedMetadata map[string]string
	// metadataMutex ensures atomic access to the metadata and serializedMetadata mappings.
	metadataMutex sync.Mutex

	// Gorilla Websockets support 1 concurrent reader and 1 concurrent writer on the same websocket.
	// What this means is that we can read from the websocket with one goroutine while we write to the websocket with another goroutine.
	// However, we cannot have > 1 goroutines reading at the same time, nor can we have > 1 goroutines write at the same time.
	// So, we have two locks: one for reading, and one for writing.

	mu                sync.Mutex // Internal mutex, not directly related/used for operations on the underlying websocket itself.
	rlock             sync.Mutex // Synchronizes read operations on the websocket.
	wlock             sync.Mutex // Synchronizes write operations on the websocket.
	iopubHandlerMutex sync.Mutex // Synchronizes access to state related to the IOPub message handlers.

	// onError, if non-nil, is invoked by tryCallOnError to report errors to the owner of this connection.
	onError func(err error)

	// Used to publish metrics to Prometheus.
	metricsConsumer MetricsConsumer
}

// NewKernelConnection creates and returns a pointer to a new BasicKernelConnection struct. The connection
// is for the kernel with ID kernelId, reached through the Jupyter Server at jupyterServerAddress.
//
// Parameters:
//   - clientId: if empty, a random UUID is used instead.
//   - atom: controls the connection logger's level. If nil, a fresh zap.AtomicLevel (InfoLevel) is used so
//     that logging does not dereference a nil level.
//   - metricsConsumer: used to publish metrics; may be nil, in which case metrics are skipped.
//   - onError: invoked (via tryCallOnError) for reportable errors; may be nil.
//
// Before returning, the constructor calls setupWebsocket. If that fails, the error is logged, forwarded
// to onError, and returned together with a nil connection.
//
// NOTE(review): earlier documentation stated that the connection is not established until InitialConnect
// is called; confirm how that relates to the setupWebsocket call made here.
func NewKernelConnection(kernelId string, clientId string, username string, jupyterServerAddress string,
	atom *zap.AtomicLevel, metricsConsumer MetricsConsumer, onError func(err error)) (*BasicKernelConnection, error) {
	if len(clientId) == 0 {
		clientId = uuid.NewString()
	}

	if atom == nil {
		// The zap core consults the level enabler on every log call, so a nil *zap.AtomicLevel would panic
		// the first time anything is logged.
		defaultLevel := zap.NewAtomicLevel()
		atom = &defaultLevel
	}

	conn := &BasicKernelConnection{
		clientId:             clientId,
		kernelId:             kernelId,
		username:             username,
		atom:                 atom,
		jupyterServerAddress: jupyterServerAddress,
		connectionStatus:     KernelConnectionInit,
		responseChannels:     make(map[string]chan KernelMessage),
		registeredShell:      false,
		registeredControl:    false,
		messageCount:         0,
		kernelStdout:         make([]string, 0),
		kernelStderr:         make([]string, 0),
		iopubMessageHandlers: make(map[string]IOPubMessageHandler),
		metadata:             make(map[string]interface{}),
		serializedMetadata:   make(map[string]string),
		metricsConsumer:      metricsConsumer,
		onError:              onError,
	}

	zapConfig := zap.NewDevelopmentEncoderConfig()
	zapConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
	core := zapcore.NewCore(zapcore.NewConsoleEncoder(zapConfig), zapcore.AddSync(colorable.NewColorableStdout()), atom)

	// zap.New never returns nil (a nil core falls back to a no-op logger), so no nil check is needed here.
	conn.logger = zap.New(core, zap.Development())
	conn.sugaredLogger = conn.logger.Sugar()

	if err := conn.setupWebsocket(conn.jupyterServerAddress); err != nil {
		conn.logger.Error("Failed to setup websocket for new kernel.", zap.Error(err))
		conn.tryCallOnError(err)
		return nil, err
	}

	return conn, nil
}

// tryCallOnError forwards err to the registered onError callback, if there is one.
//
// A nil err is ignored. Errors whose message contains "insufficient hosts available" are not forwarded,
// since they are treated as legitimate internal errors.
func (conn *BasicKernelConnection) tryCallOnError(err error) {
	// Guard against nil: calling err.Error() on a nil error would panic.
	if err == nil || conn.onError == nil {
		return
	}

	// We don't need to call the error handler for a legitimate internal error.
	if strings.Contains(err.Error(), "insufficient hosts available") {
		return
	}

	conn.onError(err)
}

// SetOnError replaces the callback that tryCallOnError invokes. Passing nil disables error forwarding.
//
// NOTE(review): this write is not synchronized with reads of conn.onError.
func (conn *BasicKernelConnection) SetOnError(onError func(err error)) {
	conn.onError = onError
}

// AddMetadata attaches some metadata to the BasicKernelConnection under the given key.
//
// The value is stored as-is in the metadata mapping. Its JSON serialization is stored in the
// serializedMetadata mapping, so it can be embedded in kernel messages without re-serializing.
//
// Errors:
//   - key is KernelIdMetadataKey or SendTimestampMetadataKey: an error wrapping ErrReservedMetadataKey.
//   - the value cannot be serialized to JSON: the json.Marshal error. Neither mapping is modified, so the
//     two mappings never disagree.
//
// This particular implementation of AddMetadata is thread-safe.
func (conn *BasicKernelConnection) AddMetadata(key string, value interface{}) error {
	if key == KernelIdMetadataKey || key == SendTimestampMetadataKey {
		conn.logger.Error("Cannot add requested entry to kernel connection metadata. Key is reserved.",
			zap.String("key", key),
			zap.Any("value", value))

		return fmt.Errorf("%w: \"%s\"", ErrReservedMetadataKey, key)
	}

	// Serialize before touching either mapping. Previously the raw value was stored first, so a
	// serialization failure left metadata and serializedMetadata out of sync.
	serializedValue, err := json.Marshal(value)
	if err != nil {
		conn.logger.Error("Failed to serialize new metadata added to kernel connection.",
			zap.String("kernel_id", conn.kernelId),
			zap.String("metadata_key", key),
			zap.Any("metadata_value", value),
			zap.Error(err))

		return err
	}

	conn.metadataMutex.Lock()
	defer conn.metadataMutex.Unlock()

	conn.metadata[key] = value
	conn.serializedMetadata[key] = string(serializedValue)

	return nil
}

// GetSerializedMetadata looks up the JSON-serialized form of the metadata stored under key.
//
// The boolean result reports whether the key was found. On a miss, a warning is logged and the returned
// value is the empty string (held in an interface{}).
//
// This particular implementation of GetSerializedMetadata is thread-safe.
func (conn *BasicKernelConnection) GetSerializedMetadata(key string) (interface{}, bool) {
	conn.metadataMutex.Lock()
	defer conn.metadataMutex.Unlock()

	if serialized, found := conn.serializedMetadata[key]; found {
		return serialized, true
	}

	conn.logger.Warn("Could not find metadata with specified key attached to BasicKernelConnection.",
		zap.String("kernel_id", conn.kernelId),
		zap.String("key", key))

	return "", false
}

// GetMetadata retrieves the raw (unserialized) metadata value stored under key.
//
// The boolean result reports whether the key was found. On a miss, a warning is logged and the returned
// value is nil.
//
// This particular implementation of GetMetadata is thread-safe.
func (conn *BasicKernelConnection) GetMetadata(key string) (interface{}, bool) {
	conn.metadataMutex.Lock()
	defer conn.metadataMutex.Unlock()

	value, ok := conn.metadata[key]
	if !ok {
		// Include the kernel ID, matching GetSerializedMetadata, so the warning can be attributed.
		conn.logger.Warn("Could not find metadata with specified key attached to BasicKernelConnection.",
			zap.String("kernel_id", conn.kernelId),
			zap.String("key", key))
	}
	return value, ok
}

// Stdout returns the kernel's STDOUT history, as extracted from IOPub messages.
//
// The connection's internal slice is returned directly, not a copy.
func (conn *BasicKernelConnection) Stdout() []string {
	history := conn.kernelStdout
	return history
}

// Stderr returns the kernel's STDERR history, as extracted from IOPub messages.
//
// The connection's internal slice is returned directly, not a copy.
func (conn *BasicKernelConnection) Stderr() []string {
	history := conn.kernelStderr
	return history
}

// waitForResponseWithTimeout blocks until a message arrives on responseChan or timeoutInterval elapses.
//
// messageType is the type of the request being awaited and is used only for logging. On timeout, the
// error is logged and (nil, context.DeadlineExceeded) is returned. Otherwise, the received message is returned.
func (conn *BasicKernelConnection) waitForResponseWithTimeout(responseChan chan KernelMessage, timeoutInterval time.Duration, messageType MessageType) (KernelMessage, error) {
	startedAt := time.Now()
	deadline := time.NewTimer(timeoutInterval)
	defer deadline.Stop()

	conn.logger.Debug("Awaiting response from kernel.",
		zap.String("request_message_type", messageType.String()),
		zap.String("kernel_id", conn.kernelId))

	select {
	case resp := <-responseChan:
		conn.logger.Debug("Received response from kernel.",
			zap.String("response_message_type", resp.GetHeader().MessageType.String()),
			zap.String("request_message_type", messageType.String()),
			zap.String("kernel_id", conn.kernelId),
			zap.Duration("time_elapsed", time.Since(startedAt)))
		return resp, nil
	case <-deadline.C:
		timeoutErr := context.DeadlineExceeded
		conn.logger.Error("Timeout while waiting for response from kernel.",
			zap.String("request_message_type", messageType.String()),
			zap.String("kernel_id", conn.kernelId),
			zap.Duration("timeout_interval", timeoutInterval),
			zap.Error(timeoutErr))
		return nil, timeoutErr
	}
}

// RegisterIoPubHandler registers handler as a consumer of IOPub messages under the given id.
//
// If a handler is already registered under id, the existing one is left in place and
// ErrHandlerAlreadyExists is returned. The handler map is accessed under iopubHandlerMutex.
func (conn *BasicKernelConnection) RegisterIoPubHandler(id string, handler IOPubMessageHandler) error {
	conn.iopubHandlerMutex.Lock()
	defer conn.iopubHandlerMutex.Unlock()

	_, taken := conn.iopubMessageHandlers[id]
	if !taken {
		conn.iopubMessageHandlers[id] = handler
		conn.logger.Debug("Registered IOPub message handler.", zap.String("id", id))
		return nil
	}

	conn.logger.Error("Could not register IOPub message handler.", zap.String("id", id), zap.Error(ErrHandlerAlreadyExists))
	return ErrHandlerAlreadyExists
}

// UnregisterIoPubHandler removes the IOPub message handler registered under the given id.
//
// It returns ErrNoHandlerFound if nothing is registered under id. The handler map is accessed under
// iopubHandlerMutex.
func (conn *BasicKernelConnection) UnregisterIoPubHandler(id string) error {
	conn.iopubHandlerMutex.Lock()
	defer conn.iopubHandlerMutex.Unlock()

	_, registered := conn.iopubMessageHandlers[id]
	if registered {
		delete(conn.iopubMessageHandlers, id)
		conn.logger.Debug("Unregistered IOPub message handler.", zap.String("id", id))
		return nil
	}

	conn.logger.Error("Could not unregister IOPub message handler.", zap.String("id", id), zap.Error(ErrNoHandlerFound))
	return ErrNoHandlerFound
}

// SendDummyMessage sends a "dummy_message_request" carrying content on the given channel.
//
// If waitForResponse is false, it returns (nil, nil) once the message has been written. Otherwise it waits
// up to 20 seconds for the response and returns it, or the timeout error. A write failure is logged and
// returned.
func (conn *BasicKernelConnection) SendDummyMessage(channel KernelSocketChannel, content interface{}, waitForResponse bool) (KernelMessage, error) {
	message, responseChan := conn.createKernelMessage(DummyMessage, channel, content)

	if sendErr := conn.sendMessage(message); sendErr != nil {
		conn.logger.Error("Error while writing `dummy_message` message.", zap.String("kernel_id", conn.kernelId), zap.Error(sendErr))
		return nil, sendErr
	}

	if !waitForResponse {
		return nil, nil
	}

	const responseTimeout = 20 * time.Second
	return conn.waitForResponseWithTimeout(responseChan, responseTimeout, DummyMessage)
}

// StopRunningTrainingCode sends a 'stop_running_training_code_request' message on the CONTROL channel.
//
// If waitForResponse is false, it returns as soon as the message has been written, or with the write error.
//
// If waitForResponse is true, it waits up to 20 seconds for the response. If that wait fails, a
// "dummy_message_request" is sent on the CONTROL channel purely as a diagnostic, to see whether the
// channel responds at all. The diagnostic's outcome is only logged; the original wait error is returned.
func (conn *BasicKernelConnection) StopRunningTrainingCode(waitForResponse bool) error {
	message, responseChan := conn.createKernelMessage(StopRunningTrainingCode, ControlChannel, nil)

	if err := conn.sendMessage(message); err != nil {
		conn.logger.Error("Error while writing 'stop_running_training_code_request' message.", zap.String("kernel_id", conn.kernelId), zap.Error(err))
		return err
	}

	if !waitForResponse {
		return nil
	}

	_, err := conn.waitForResponseWithTimeout(responseChan, time.Second*20, StopRunningTrainingCode)
	if err == nil {
		return nil
	}

	conn.logger.Warn("Sending 'dummy' control request to see if we receive a response, seeing as our 'stop_running_training_code_request' request timed-out...", zap.String("kernel_id", conn.kernelId))
	dummyResp, dummyErr := conn.SendDummyMessage(ControlChannel, nil, true)
	if dummyErr != nil {
		conn.logger.Error("'dummy_message' request failed as well (in addition to the failed 'stop_running_training_code_request' request).", zap.String("kernel_id", conn.kernelId), zap.Error(dummyErr))
	} else {
		conn.logger.Warn("Successfully received response to 'dummy' request.", zap.String("kernel_id", conn.kernelId), zap.Any("dummy-response", dummyResp))
	}

	return err // Return the original error.
}

// sendAck sends an ACK for msg on the given channel to the Jupyter Server (and subsequently the Cluster Gateway).
// It returns an error if the ACK could not be sent.
//
// ACKs are meant for SHELL and CONTROL messages. For any other channel a warning is logged, but the ACK is
// still sent. If the SHELL or CONTROL channel has not completed registration, an error wrapping
// ErrCantAckNotRegistered is returned. The ACK's parent header is set to msg's parent header.
func (conn *BasicKernelConnection) sendAck(msg KernelMessage, channel KernelSocketChannel) error {
	conn.logger.Debug("Attempting to ACK message.", zap.String("message_id", msg.GetHeader().MessageId), zap.String("channel", string(msg.GetChannel())), zap.String("kernel_id", conn.kernelId))

	if channel != ShellChannel && channel != ControlChannel {
		// NOTE(review): despite this warning's wording, execution continues and the ACK is still sent.
		conn.sugaredLogger.Warnf("Cannot ACK message of type \"%s\"...", channel)
	}

	if (channel == ShellChannel && !conn.registeredShell) || (channel == ControlChannel && !conn.registeredControl) {
		conn.sugaredLogger.Warnf("Cannot ACK '%s' '%s' message '%s' as %s channel registration has not yet completed.", channel, msg.GetHeader().MessageType, msg.GetHeader().MessageId, channel)
		return fmt.Errorf("%w: %s", ErrCantAckNotRegistered, channel)
	}

	content := map[string]interface{}{
		"sender-identity": fmt.Sprintf("GoJupyter-%s", conn.kernelId),
	}

	ackMessage, _ := conn.createKernelMessage(AckMessage, channel, content)
	ackMessage.(*Message).ParentHeader = msg.GetParentHeader()

	// Rendering the styled log line (which also stringifies the whole ACK message) is comparatively costly,
	// so only do it when DEBUG logging is actually enabled.
	if conn.logger.Core().Enabled(zapcore.DebugLevel) {
		firstPart := fmt.Sprintf(LightBlueStyle.Render("Sending ACK for %v \"%v\""), channel, msg.GetParentHeader().MessageType)
		secondPart := fmt.Sprintf("(MsgId=%v)", LightPurpleStyle.Render(msg.GetParentHeader().MessageId))
		thirdPart := fmt.Sprintf(LightBlueStyle.Render("message: %v"), ackMessage)
		conn.sugaredLogger.Debugf("%s %s %s", firstPart, secondPart, thirdPart)
	}

	if err := conn.sendMessage(ackMessage); err != nil {
		conn.logger.Error("Error while writing 'ACK' message.", zap.String("kernel_id", conn.kernelId), zap.Error(err))
		return err
	}

	return nil
}

// JupyterServerAddress reports the address of the Jupyter Server this connection's kernel lives on.
func (conn *BasicKernelConnection) JupyterServerAddress() string {
	return conn.jupyterServerAddress
}

// Connected reports whether the current connection status is KernelConnected.
func (conn *BasicKernelConnection) Connected() bool {
	return conn.ConnectionStatus() == KernelConnected
}

// ConnectionStatus reports the current status of the connection to the kernel.
func (conn *BasicKernelConnection) ConnectionStatus() KernelConnectionStatus {
	status := conn.connectionStatus
	return status
}

// KernelId reports the ID of the kernel this connection targets.
func (conn *BasicKernelConnection) KernelId() string {
	return conn.kernelId
}

// handleExecuteRequestResponse waits for and processes the response to an "execute_request" message.
//
// It blocks until the response arrives on responseChan; there is no timeout. It then:
//   - decrements the count of outstanding "execute_request" messages;
//   - reports the latency measured from sentAt to the metrics consumer, when one is configured and the
//     WorkloadIdMetadataKey metadata holds a non-empty string;
//   - invokes requestArgs.ExtraArguments.ResponseCallback, if provided.
//
// It returns the received response. This method may run on its own goroutine (RequestExecute starts it
// with `go` for non-blocking requests), so shared state must only be read under the appropriate locks.
func (conn *BasicKernelConnection) handleExecuteRequestResponse(request KernelMessage, requestArgs *RequestExecuteArgs, responseChan chan KernelMessage, sentAt time.Time) (response KernelMessage) {
	response = <-responseChan
	latency := time.Since(sentAt)
	conn.logger.Debug("Received response to `execute_request` message.",
		zap.String("kernel_id", conn.kernelId),
		zap.String("message_id", request.GetHeader().MessageId),
		zap.Duration("latency", latency),
		zap.Any("response", response))

	conn.waitingForExecuteResponses.Add(-1)

	// Resolve the workload ID used to label metrics. AddMetadata accepts arbitrary values, so check the type
	// instead of asserting blindly (an unchecked assertion would panic this goroutine).
	var workloadId string
	if val, loaded := conn.GetMetadata(WorkloadIdMetadataKey); !loaded {
		// Read the entry count under the lock: conn.metadata may be written concurrently by AddMetadata.
		conn.metadataMutex.Lock()
		numEntries := len(conn.metadata)
		conn.metadataMutex.Unlock()

		conn.logger.Warn("Could not load WorkloadId metadata from KernelConnection.",
			zap.String("kernel_id", conn.kernelId),
			zap.Int("num_metadata_entries", numEntries))
	} else if id, isString := val.(string); isString {
		workloadId = id
	} else {
		conn.logger.Warn("WorkloadId metadata attached to KernelConnection is not a string.",
			zap.String("kernel_id", conn.kernelId),
			zap.String("value_type", fmt.Sprintf("%T", val)))
	}

	if conn.metricsConsumer != nil && workloadId != "" {
		latencyMs := latency.Milliseconds()
		conn.metricsConsumer.ObserveJupyterExecuteRequestE2ELatency(latencyMs, workloadId)
		conn.metricsConsumer.AddJupyterRequestExecuteTime(latencyMs, conn.kernelId, workloadId)
	}

	if requestArgs.ExtraArguments != nil && requestArgs.ExtraArguments.ResponseCallback != nil {
		conn.logger.Debug("Calling ResponseCallback for \"execute_reply\" message.",
			zap.String("kernel_id", conn.kernelId),
			zap.String("request_id", request.GetHeader().MessageId),
			zap.String("reply_id", response.GetHeader().MessageId),
			zap.Duration("latency", latency))
		requestArgs.ExtraArguments.ResponseCallback(response)
	}

	return response
}

// CreateExecuteRequestMessage builds an "execute_request" message for the SHELL channel from args, along
// with the channel on which its response will be delivered.
//
// The message content is args.StripNonstandardArguments(). Every entry of
// args.ExtraArguments.RequestMetadata, when present, is added to the message's metadata.
func (conn *BasicKernelConnection) CreateExecuteRequestMessage(args *RequestExecuteArgs) (KernelMessage, chan KernelMessage) {
	message, responseChan := conn.createKernelMessage(ExecuteRequest, ShellChannel, args.StripNonstandardArguments())

	extra := args.ExtraArguments
	if extra == nil || extra.RequestMetadata == nil {
		return message, responseChan
	}

	for metadataKey, metadataValue := range extra.RequestMetadata {
		conn.logger.Debug("Adding metadata entry to \"execute_request\" message.",
			zap.String("kernel_id", conn.kernelId),
			zap.String("message_id", message.GetHeader().MessageId),
			zap.String("metadata_key", metadataKey),
			zap.Any("metadata_value", metadataValue))

		message.AddMetadata(metadataKey, metadataValue)
	}

	return message, responseChan
}

// RequestExecute sends an `execute_request` message, built from args, on the SHELL channel.
//
// #### Notes
// See [Messaging in Jupyter](https://jupyter-client.readthedocs.io/en/latest/messaging.html#execute).
//
// The request is described by args, which covers the execute-request options:
// - code (string): The code to execute.
// - silent (bool): Whether to execute the code as quietly as possible. The default is `false`.
// - storeHistory (bool): Whether to store history of the execution. The default `true` if silent is False. It is forced to  `false ` if silent is `true`.
// - userExpressions (map[string]interface{}): A mapping of names to expressions to be evaluated in the kernel's interactive namespace.
// - allowStdin (bool): Whether to allow stdin requests. The default is `true`.
// - stopOnError (bool): Whether to the abort execution queue on an error. The default is `false`.
// - waitForResponse (bool): Whether to wait for a response from the kernel, or just return immediately.
//
// When args.AwaitResponse() is true, RequestExecute blocks until the reply arrives, processes it, and
// returns it. Otherwise the reply is awaited and processed on a new goroutine, and the returned message
// is nil.
//
// Returns the response (nil when not awaited or when sending fails), the ID of the execute_request, and
// an error. The error is non-nil only if the message could not be sent.
func (conn *BasicKernelConnection) RequestExecute(args *RequestExecuteArgs) (KernelMessage, string, error) {
	message, responseChan := conn.CreateExecuteRequestMessage(args)

	sentAt := time.Now()
	if sendErr := conn.sendMessage(message); sendErr != nil {
		conn.logger.Error("Error while writing 'execute_request' message.",
			zap.String("kernel_id", conn.kernelId),
			zap.Error(sendErr))

		return nil, message.GetHeader().MessageId, sendErr
	}

	conn.waitingForExecuteResponses.Add(1)

	if !args.AwaitResponse() {
		// Non-blocking: the reply is handled in the background.
		go conn.handleExecuteRequestResponse(message, args, responseChan, sentAt)
		return nil, message.GetHeader().MessageId, nil
	}

	// Blocking: wait for (and process) the reply before returning.
	reply := conn.handleExecuteRequestResponse(message, args, responseChan, sentAt)
	return reply, message.GetHeader().MessageId, nil
}

// RequestKernelInfo sends a "kernel_info_request" on the SHELL channel and waits up to 5 seconds for the reply.
//
// A send failure is returned as-is. If no reply arrives in time, the returned error wraps ErrRequestTimedOut.
func (conn *BasicKernelConnection) RequestKernelInfo() (KernelMessage, error) {
	content := map[string]interface{}{
		"sender-id": fmt.Sprintf("GoJupyter-%s", conn.kernelId),
	}

	message, responseChan := conn.createKernelMessage(KernelInfoRequest, ShellChannel, content)

	conn.logger.Debug("Sending 'request-info' message now.",
		zap.String("message_id", message.GetHeader().MessageId),
		zap.String("kernel_id", conn.kernelId),
		zap.String("session", message.GetHeader().Session),
		zap.String("message", message.String()))

	if sendErr := conn.sendMessage(message); sendErr != nil {
		return nil, sendErr
	}

	replyDeadline := time.NewTimer(5 * time.Second)
	defer replyDeadline.Stop()

	select {
	case resp := <-responseChan:
		conn.logger.Debug("Received response to 'request-info' request.",
			zap.String("kernel_id", conn.kernelId),
			zap.String("message_id", resp.GetHeader().MessageId),
			zap.String("response", resp.String()))
		return resp, nil
	case <-replyDeadline.C:
		conn.logger.Error("Request of type \"kernel_info_request\" has timed out.",
			zap.String("kernel_id", conn.kernelId), zap.String("message_id", message.GetHeader().MessageId))
		return nil, fmt.Errorf("ErrRequestTimedOut %w : %s", ErrRequestTimedOut, context.DeadlineExceeded)
	}
}

// InterruptKernel interrupts a kernel.
//
// #### Notes
// Uses the [Jupyter Server API](https://petstore.swagger.io/?url=https://raw.githubusercontent.com/jupyter-server/jupyter_server/main/jupyter_server/services/api/api.yaml#!/kernels).
//
// The interrupt is requested with an HTTP POST to "{jupyterServerAddress}/api/kernels/{kernelId}/interrupt".
// If jupyterServerAddress carries no scheme, "http://" is assumed.
//
// It is assumed that the API call does not mutate the kernel id or name.
//
// An error is returned if:
//   - the connection status is KernelDead (wraps ErrKernelIsDead);
//   - the request URL or request cannot be built, the request cannot be sent (30-second timeout), or the
//     response body cannot be read;
//   - the server replies with a non-2xx status: HTTP 404 wraps ErrKernelNotFound, HTTP 503 and HTTP 424
//     wrap ErrNetworkIssue, and any other non-2xx status wraps ErrUnexpectedFailure.
func (conn *BasicKernelConnection) InterruptKernel() error {
	if conn.connectionStatus == KernelDead {
		// Cannot interrupt a dead kernel.
		return fmt.Errorf("%w: no connection to kernel \"%s\"", ErrKernelIsDead, conn.kernelId)
	}

	conn.logger.Debug("Attempting to Interrupt kernel.", zap.String("kernel_id", conn.kernelId))

	requestBody := map[string]interface{}{"kernel_id": conn.kernelId}
	requestBodyEncoded, err := json.Marshal(requestBody)
	if err != nil {
		conn.logger.Error("Failed to marshal request body for kernel interruption request", zap.Error(err))
		return err
	}

	endpoint, err := conn.kernelInterruptEndpoint()
	if err != nil {
		conn.logger.Error("Failed to construct URL for kernel interruption request.",
			zap.String("kernel_id", conn.kernelId),
			zap.String("jupyter_server_address", conn.jupyterServerAddress),
			zap.Error(err))
		return err
	}

	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(requestBodyEncoded))
	if err != nil {
		conn.logger.Error("Failed to create HTTP request for kernel interruption.", zap.String("url", endpoint), zap.Error(err))
		return err
	}

	// Bound the request so an unresponsive Jupyter Server cannot block the caller indefinitely.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		conn.logger.Error("Error while issuing HTTP request to interrupt kernel.", zap.String("url", endpoint), zap.Error(err))
		return err
	}
	// Always release the connection; previously the body was never closed.
	defer resp.Body.Close()

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		conn.logger.Error("Failed to read response to interrupting kernel.", zap.Error(err))
		return err
	}

	conn.logger.Debug("Received response to interruption request.", zap.Int("status-code", resp.StatusCode), zap.String("status", resp.Status), zap.ByteString("response", data))

	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		return nil
	}

	// Map failure statuses onto this package's sentinel errors so callers can use errors.Is.
	var sentinel error
	switch resp.StatusCode {
	case http.StatusNotFound:
		sentinel = ErrKernelNotFound
	case http.StatusServiceUnavailable, http.StatusFailedDependency:
		sentinel = ErrNetworkIssue
	default:
		sentinel = ErrUnexpectedFailure
	}

	err = fmt.Errorf("%w: interrupt request for kernel \"%s\" returned HTTP %d: %s", sentinel, conn.kernelId, resp.StatusCode, string(data))
	conn.logger.Error("Jupyter Server rejected kernel interruption request.",
		zap.String("kernel_id", conn.kernelId),
		zap.String("url", endpoint),
		zap.Int("status-code", resp.StatusCode),
		zap.Error(err))
	return err
}

// kernelInterruptEndpoint builds the absolute URL of this kernel's interrupt endpoint on the Jupyter Server.
//
// url.JoinPath is used instead of path.Join on the full address because path.Join would collapse the "//"
// after the scheme, turning "http://host" into "http:/host". A scheme-less address gets "http://" prepended.
func (conn *BasicKernelConnection) kernelInterruptEndpoint() (string, error) {
	base := conn.jupyterServerAddress
	if !strings.Contains(base, "://") {
		base = "http://" + base
	}

	return url.JoinPath(base, path.Join(kernelServiceApi, conn.kernelId, "interrupt"))
}

// Close sends a "comm_close" message to the kernel on the SHELL channel and then closes the websocket.
//
// A failure to send "comm_close" is logged but does not prevent the websocket from being closed. The
// returned error comes from closing the websocket and is nil on success.
func (conn *BasicKernelConnection) Close() error {
	commClose, _ := conn.createKernelMessage(CommCloseMessage, ShellChannel, nil)
	if sendErr := conn.sendMessage(commClose); sendErr != nil {
		conn.logger.Error("Failed to send 'comm_closed' message to kernel.", zap.String("kernel_id", conn.kernelId), zap.String("error-message", sendErr.Error()))
	}

	conn.logger.Warn("Closing WebSocket connection to kernel now.", zap.String("kernel_id", conn.kernelId))

	closeErr := conn.webSocket.Close()
	if closeErr == nil {
		return nil
	}

	conn.logger.Error("Error while closing WebSocket connection to kernel.",
		zap.String("kernel_id", conn.kernelId),
		zap.String("client_id", conn.clientId),
		zap.String("username", conn.username),
		zap.Error(closeErr))
	return closeErr
}

// ClientId reports the Jupyter client ID used by this connection (a random UUID if none was supplied at construction).
func (conn *BasicKernelConnection) ClientId() string {
	id := conn.clientId
	return id
}

// Username reports the Jupyter username associated with this connection.
func (conn *BasicKernelConnection) Username() string {
	name := conn.username
	return name
}

// decodeKernelMessage decodes a kernel message received over the websocket and returns the decoded
// result, or an error if the message is malformed.
//
// Messages without binary buffers are plain JSON. Messages with buffers use binary framing:
//   - a big-endian uint32 count N (N >= 2);
//   - N big-endian uint32 byte offsets into buf;
//   - the JSON message, in buf[offsets[0]:offsets[1]];
//   - binary buffers, each starting at a later offset. The last buffer runs to the end of buf.
//
// buf arrives from the network, so the header size and every offset are bounds-checked before use.
// Malformed input therefore produces an error rather than a panic or an oversized allocation.
//
// - WebSocket protocol: https://jupyter-server.readthedocs.io/en/latest/developers/websocket-protocols.html
func (conn *BasicKernelConnection) decodeKernelMessage(buf []byte) (*Message, error) {
	// Decode JSON from the buffer.
	// This will usually work, but if the message has buffers attached to it, then it won't.
	// A JSON `null` decodes without error into a nil pointer, so treat that as "not a plain JSON message".
	var msg *Message
	if err := json.Unmarshal(buf, &msg); err == nil && msg != nil {
		return msg, nil
	}

	if len(buf) < 4 {
		conn.logger.Error("Jupyter message is invalid -- too small.",
			zap.Int("size", len(buf)), zap.String("kernel_id", conn.kernelId))
		return nil, errors.New("buffer too small for header")
	}

	// Read the number of buffers (the JSON part counts as one of them).
	numBuffers := int(binary.BigEndian.Uint32(buf[0:4]))
	if numBuffers < 2 {
		return nil, errors.New("invalid incoming Kernel Message")
	}

	// The header holds the count plus numBuffers offsets, 4 bytes each. Compare by division so that a huge
	// declared count can neither overflow nor drive a huge allocation below.
	if numBuffers > len(buf)/4-1 {
		return nil, fmt.Errorf("invalid incoming Kernel Message: header declares %d buffers but message is only %d bytes",
			numBuffers, len(buf))
	}

	// Read the offsets. Each must lie within buf and they must never decrease, which makes every slice
	// expression below valid.
	offsets := make([]int, numBuffers)
	for i := range offsets {
		headerPos := (i + 1) * 4
		offset := int(binary.BigEndian.Uint32(buf[headerPos : headerPos+4]))
		if offset < 0 || offset > len(buf) || (i > 0 && offset < offsets[i-1]) {
			return nil, fmt.Errorf("invalid incoming Kernel Message: offset #%d (%d) is out of range for a %d-byte message",
				i, offset, len(buf))
		}
		offsets[i] = offset
	}

	// Decode the JSON part into a fresh pointer, so nothing half-populated by the failed attempt above leaks in.
	var decoded *Message
	if err := json.Unmarshal(buf[offsets[0]:offsets[1]], &decoded); err != nil {
		return nil, err
	}
	if decoded == nil {
		return nil, errors.New("invalid incoming Kernel Message: JSON part is null")
	}

	// Collect the binary buffers; the last one extends to the end of buf.
	decoded.Buffers = make([][]byte, numBuffers-1)
	for i := 1; i < numBuffers; i++ {
		stop := len(buf)
		if i+1 < numBuffers {
			stop = offsets[i+1]
		}
		decoded.Buffers[i-1] = buf[offsets[i]:stop]
	}

	return decoded, nil
}

// serveMessages is the message-receiver loop for the kernel WebSocket; setupWebsocket starts it in its
// own goroutine. Each iteration:
//   - reads one WebSocket frame while holding conn.rlock,
//   - decodes it into a *Message via decodeKernelMessage (frames that fail to decode are logged and skipped),
//   - for Shell/Control messages, looks up the response channel registered under
//     getResponseChannelKeyFromReply(msg) and delivers the message to it (or logs a warning if none exists),
//   - for IOPub messages, hands the message to handleIOPubMessage.
//
// On a read error the connection status is set to KernelDead. The loop then exits unless conn.Connected()
// still reports a live connection, in which case it keeps reading from conn.webSocket.
// NOTE(review): presumably Connected() can only be true here if a concurrent reconnection has re-marked
// the connection; confirm against Connected().
func (conn *BasicKernelConnection) serveMessages() {
	for {
		conn.rlock.Lock()
		messageType, data, err := conn.webSocket.ReadMessage()
		conn.rlock.Unlock()

		if err != nil {
			conn.logger.Error("Error while reading from kernel WebSocket.",
				zap.String("kernel_id", conn.kernelId),
				zap.String("client_id", conn.clientId),
				zap.String("username", conn.username),
				zap.Int("websocket_message_type", messageType),
				zap.ByteString("data_byte_string", data),
				zap.Binary("data_binary", data),
				zap.Error(err))
			st := time.Now()
			connErr := conn.updateConnectionStatus(KernelDead)

			if !conn.Connected() || connErr != nil {
				conn.logger.Error("Failed to re-establish connection with kernel.",
					zap.String("kernel_id", conn.kernelId),
					zap.String("client_id", conn.clientId),
					zap.String("username", conn.username),
					zap.Error(connErr))

				return
			}

			conn.logger.Debug("Successfully re-established WebSocket connection to kernel following connection loss.",
				zap.String("kernel_id", conn.kernelId),
				zap.String("client_id", conn.clientId),
				zap.String("username", conn.username),
				zap.Duration("time_elapsed", time.Since(st)))

			continue
		}

		kernelMessage, err := conn.decodeKernelMessage(data)

		if errors.Is(err, io.EOF) {
			// The frame ended before a complete message could be decoded; drop it and keep reading.
			conn.logger.Error("Got EOF while trying to decode message from kernel.",
				zap.String("kernel_id", conn.kernelId),
				zap.String("client_id", conn.clientId),
				zap.String("username", conn.username),
				zap.Int("websocket_message_type", messageType))
			continue
		}
		if err != nil {
			conn.logger.Error("Failed to decode message from kernel using JSON.",
				zap.String("kernel_id", conn.kernelId),
				zap.String("client_id", conn.clientId),
				zap.String("username", conn.username),
				zap.Int("websocket_message_type", messageType),
				zap.ByteString("data_byte_string", data),
				zap.Binary("data_binary", data),
				zap.Any("raw_data", data),
				zap.Error(err))
			continue
		}

		// IOPub messages are dispatched to the registered (or default) IOPub handlers; any other
		// non-Shell/Control channel is ignored.
		if kernelMessage.Channel != ShellChannel && kernelMessage.Channel != ControlChannel {
			if kernelMessage.Channel == IOPubChannel {
				// TODO: Make it so we can query/view all of the output generated by a Session via the Workload Driver console/frontend.
				conn.handleIOPubMessage(kernelMessage)
			}
			continue
		}

		// We send ACKs for Shell and Control messages.
		// We will also attempt to pair the message with its original request.
		conn.logger.Debug("Received message from kernel.",
			zap.String("kernel_id", conn.kernelId),
			zap.String("client_id", conn.clientId),
			zap.String("username", conn.username),
			zap.String("channel", kernelMessage.Channel.String()),
			zap.String("message_type", kernelMessage.Header.MessageType.String()),
			zap.String("message_id", kernelMessage.Header.MessageId),
			zap.String("message", kernelMessage.String()))

		// Commented-out; for now, we're not ACK-ing anything.
		// We do this in another goroutine so as not to block this message-receiver goroutine.
		// go conn.sendAck(kernelMessage, kernelMessage.Channel)

		// NOTE(review): conn.responseChannels is read here without a lock, while createKernelMessage writes
		// it from the sending goroutine. Entries are also never removed in this block. Confirm whether
		// this is synchronized/cleaned up elsewhere.
		responseChannelKey := getResponseChannelKeyFromReply(kernelMessage)
		responseChannel, ok := conn.responseChannels[responseChannelKey]
		if !ok {
			conn.logger.Warn("Could not find response channel associated with message.",
				zap.String("request_message_id", kernelMessage.GetParentHeader().MessageId),
				zap.String("response_message_id", kernelMessage.GetHeader().MessageId),
				zap.String("message_type", string(kernelMessage.Header.MessageType)),
				zap.String("channel", kernelMessage.Channel.String()),
				zap.String("response_channel_key", responseChannelKey),
				zap.String("kernel_id", conn.kernelId),
				zap.String("client_id", conn.clientId),
				zap.String("username", conn.username))
			continue
		}

		conn.logger.Debug("Found response channel for websocket message.",
			zap.String("request_message_id", kernelMessage.GetParentHeader().MessageId),
			zap.String("response_message_id", kernelMessage.GetHeader().MessageId),
			zap.String("message_type", string(kernelMessage.Header.MessageType)),
			zap.String("channel", kernelMessage.Channel.String()),
			zap.String("response_channel_key", responseChannelKey),
			zap.String("kernel_id", conn.kernelId),
			zap.String("client_id", conn.clientId),
			zap.String("username", conn.username))
		// The channel is created with capacity 1 (see createKernelMessage). A second reply for the same
		// key before the first is consumed would block this loop.
		responseChannel <- kernelMessage
		conn.logger.Debug("Response delivered (via channel) for websocket message.",
			zap.String("request_message_id", kernelMessage.GetParentHeader().MessageId),
			zap.String("response_message_id", kernelMessage.GetHeader().MessageId),
			zap.String("response_message_type", kernelMessage.GetHeader().MessageType.String()),
			zap.String("response_session", kernelMessage.GetHeader().Session))
	}
}

// handleIOPubMessage dispatches an IOPub message while holding iopubHandlerMutex.
//
// Every registered IOPub handler is invoked in its own goroutine. When no handlers are registered,
// defaultHandleIOPubMessage is invoked in its own goroutine instead. This method never waits for
// the handlers to finish.
func (conn *BasicKernelConnection) handleIOPubMessage(kernelMessage KernelMessage) {
	conn.iopubHandlerMutex.Lock()
	defer conn.iopubHandlerMutex.Unlock()

	if len(conn.iopubMessageHandlers) > 0 {
		for _, registered := range conn.iopubMessageHandlers {
			go registered(conn, kernelMessage)
		}
		return
	}

	// Fallback: nothing registered, so use the built-in stream-capturing handler.
	go conn.defaultHandleIOPubMessage(kernelMessage)
}

// defaultHandleIOPubMessage provides a default handler for IOPub messages.
// It extracts the text of "stream" IOPub messages and appends it to conn.kernelStdout or
// conn.kernelStderr, according to the message's "name" field ("stdout" / "stderr").
// Non-stream messages are ignored. Malformed content and unknown stream names are logged and dropped.
//
// Important: this will be called in its own goroutine (see handleIOPubMessage).
// NOTE(review): because each message gets its own goroutine, concurrent appends to kernelStdout /
// kernelStderr are not synchronized here. Confirm whether a lock is needed.
//
// Also, this function does not match the definition of the 'IOPubMessageHandler' type.
func (conn *BasicKernelConnection) defaultHandleIOPubMessage(kernelMessage KernelMessage) {
	// We just want to extract the output from 'stream' IOPub messages.
	// We don't care about non-stream-type IOPub messages here, so we'll just return.
	if kernelMessage.GetHeader().MessageType != "stream" {
		return
	}

	// Checked assertion: an unexpected content type used to panic the handler goroutine.
	content, ok := kernelMessage.GetContent().(map[string]interface{})
	if !ok {
		conn.logger.Warn("Content of IOPub \"stream\" message is not a JSON object.",
			zap.Any("message", kernelMessage), zap.String("kernel_id", conn.kernelId),
			zap.String("client_id", conn.clientId), zap.String("username", conn.username))
		return
	}

	stream, ok := content["name"].(string)
	if !ok {
		conn.logger.Warn("Content of IOPub message did not contain an entry with key \"name\" and value of type string.",
			zap.Any("content", content), zap.Any("message", kernelMessage), zap.String("kernel_id", conn.kernelId),
			zap.String("client_id", conn.clientId), zap.String("username", conn.username))
		return
	}

	text, ok := content["text"].(string)
	if !ok {
		conn.logger.Warn("Content of IOPub message did not contain an entry with key \"text\" and value of type string.",
			zap.Any("content", content), zap.Any("message", kernelMessage), zap.String("kernel_id", conn.kernelId),
			zap.String("client_id", conn.clientId), zap.String("username", conn.username))
		return
	}

	switch stream {
	case "stdout":
		conn.kernelStdout = append(conn.kernelStdout, text)
	case "stderr":
		// Append to the stderr buffer itself. The previous code appended to kernelStdout, which
		// overwrote kernelStderr with a copy of all stdout output plus this line.
		conn.kernelStderr = append(conn.kernelStderr, text)
	default:
		conn.logger.Error("Unknown or unsupported stream found in IOPub message.",
			zap.String("stream", stream), zap.String("kernel_id", conn.kernelId), zap.Any("message", kernelMessage),
			zap.String("client_id", conn.clientId), zap.String("username", conn.username))
	}
}

// RangeOverMap calls f once for each key/value pair in dict, in Go's unspecified map iteration order.
// Iteration stops as soon as f returns false. A nil or empty map results in no calls to f.
func RangeOverMap[K comparable, V any](f func(key K, value V) bool, dict map[K]V) {
	for k, v := range dict {
		if keepGoing := f(k, v); !keepGoing {
			break
		}
	}
}

// RangeOverSerializedMetadata visits every key/value pair of the connection's serialized metadata
// while holding metadataMutex for the whole walk, so the map cannot change mid-iteration.
//
// Returning false from f ends the walk early.
func (conn *BasicKernelConnection) RangeOverSerializedMetadata(f func(key string, value string) bool) {
	conn.metadataMutex.Lock()
	defer conn.metadataMutex.Unlock()

	for k, v := range conn.serializedMetadata {
		if !f(k, v) {
			return
		}
	}
}

// RangeOverMetadata visits every key/value pair of the connection's (unserialized) metadata while
// holding metadataMutex for the whole walk, so the map cannot change mid-iteration.
//
// Returning false from f ends the walk early.
func (conn *BasicKernelConnection) RangeOverMetadata(f func(key string, value interface{}) bool) {
	conn.metadataMutex.Lock()
	defer conn.metadataMutex.Unlock()

	for k, v := range conn.metadata {
		if !f(k, v) {
			return
		}
	}
}

// createKernelMessage builds a Jupyter message of the given type for the given channel. For Shell and
// Control messages, it also registers a response channel so that serveMessages can deliver the reply.
//
// The header carries a fresh message ID (getNextMessageId), the current time, the connection's client ID
// as session, its username, and VERSION. A nil content is replaced by an empty map.
//
// Message metadata is assembled as follows:
//  1. every entry of the serialized metadata, copied verbatim;
//  2. every entry of the regular metadata whose key is not already present, JSON-encoded to a string
//     (entries that fail to encode are logged and skipped);
//  3. KernelIdMetadataKey and SendTimestampMetadataKey, added through the message builder.
//
// The returned channel is nil for non-Shell/Control messages and whenever
// getResponseChannelKeyFromRequest yields an empty key. Otherwise it has capacity 1, so serveMessages can
// deposit the reply even if the requester is not yet waiting on it.
func (conn *BasicKernelConnection) createKernelMessage(messageType MessageType, channel KernelSocketChannel, content interface{}) (KernelMessage, chan KernelMessage) {
	messageId := conn.getNextMessageId()

	header := NewKernelMessageHeaderBuilder().
		WithDate(time.Now()).
		WithMessageId(messageId).
		WithMessageType(messageType).
		WithSession(conn.clientId).
		WithUsername(conn.username).
		WithVersion(VERSION).
		Build()

	if content == nil {
		content = make(map[string]interface{})
	}

	messageMetadata := make(map[string]interface{})

	// Add all the registered (already-serialized) metadata to the dictionary.
	conn.RangeOverSerializedMetadata(func(key string, value string) bool {
		messageMetadata[key] = value
		conn.logger.Debug("Added metadata to Jupyter kernel message.",
			zap.String("message_id", messageId),
			zap.String("message_type", messageType.String()),
			zap.String("metadata_key", key),
			zap.Any("metadata_value", value))

		return true
	})

	conn.RangeOverMetadata(func(key string, value interface{}) bool {
		if _, loaded := messageMetadata[key]; loaded {
			// Serialized metadata takes precedence; skip duplicates.
			return true
		}

		marshalledValue, err := json.Marshal(value)
		if err != nil {
			conn.logger.Error("Failed to serialize piece of metadata while creating kernel message.",
				zap.String("message_id", messageId),
				zap.String("message_type", messageType.String()),
				zap.String("metadata_key", key),
				zap.Any("metadata_value", value),
				zap.Error(err))
			return true
		}

		// Store the JSON text as a string. Storing the raw []byte would make encoding/json emit it as a
		// base64 string on the wire. It would also be inconsistent with the string values copied from the
		// serialized metadata above.
		messageMetadata[key] = string(marshalledValue)
		conn.logger.Debug("Added metadata to Jupyter kernel message.",
			zap.String("message_id", messageId),
			zap.String("message_type", messageType.String()),
			zap.String("metadata_key", key),
			zap.Any("metadata_value", value))

		return true
	})

	message := NewMessageBuilder().
		WithChannel(channel).
		WithHeader(header).
		WithContent(content).
		WithMetadataDictionary(messageMetadata).
		WithMetadata(KernelIdMetadataKey, conn.kernelId).
		WithMetadata(SendTimestampMetadataKey, time.Now().UnixMilli()).
		Build()

	// Only Shell and Control requests have replies that serveMessages pairs with a response channel.
	if channel != ShellChannel && channel != ControlChannel {
		return message, nil
	}

	responseChannelKey := getResponseChannelKeyFromRequest(message)
	if len(responseChannelKey) == 0 {
		conn.logger.Debug("Returning nil response channel.",
			zap.String("message_type", messageType.String()), zap.String("channel", channel.String()),
			zap.String("client_id", conn.clientId), zap.String("username", conn.username))
		return message, nil
	}

	// We create a buffered channel so that the 'message-receiver' goroutine cannot get blocked trying to put
	// a result into a response channel for which the receiver is not actively listening/waiting for said response.
	responseChannel := make(chan KernelMessage, 1)

	// NOTE(review): this map write is not guarded by a lock here, while serveMessages reads the map from
	// its own goroutine. Confirm synchronization.
	conn.responseChannels[responseChannelKey] = responseChannel

	conn.sugaredLogger.Debugf("Stored response channel for %s \"%s\" message under key \"%s\" for kernel %s.",
		channel, messageType, responseChannelKey, conn.kernelId)

	return message, responseChannel
}

// getNextMessageId returns a new message ID of the form "<clientId>_<pid>_<sequence>".
// The sequence number is taken from conn.messageCount, which is incremented under conn.mu,
// so concurrent callers never receive the same ID.
func (conn *BasicKernelConnection) getNextMessageId() string {
	conn.mu.Lock()
	seq, session := conn.messageCount, conn.clientId
	conn.messageCount++
	conn.mu.Unlock()

	return fmt.Sprintf("%s_%d_%d", session, os.Getpid(), seq)
}

// updateConnectionStatus sets conn.connectionStatus to status. It is a no-op returning nil when the status
// is unchanged.
//
// When the new status is KernelConnected, a "kernel_info_request" is sent via RequestKernelInfo. This
// ensures at least one message round-trips and the kernel's status is reported back. Up to three attempts
// are made, with a linear back-off (1.25s, then 2.5s) between them, and each failure is reported through
// tryCallOnError. If every attempt fails, the status is downgraded to KernelDisconnected, the last error is
// reported once more via tryCallOnError, and that error is returned.
func (conn *BasicKernelConnection) updateConnectionStatus(status KernelConnectionStatus) error {
	if conn.connectionStatus == status {
		return nil
	}

	conn.connectionStatus = status

	// Send a kernel info request to make sure we send at least one
	// message to get kernel status back. Always request kernel info
	// first, to get kernel status back and ensure iopub is fully
	// established. If we are restarting, this message will skip the queue
	// and be sent immediately.
	if status == KernelConnected {
		const maxNumTries = 3

		conn.logger.Debug("Connection status is being updated to 'connected'. Attempting to retrieve kernel info.",
			zap.String("kernel_id", conn.kernelId))
		st := time.Now()

		var (
			statusMessage KernelMessage
			err           error
			success       bool
		)

		for attempt := 1; attempt <= maxNumTries; attempt++ {
			statusMessage, err = conn.RequestKernelInfo()
			if err == nil {
				success = true
				conn.logger.Debug("Successfully retrieved kernel info on connected-status-changed.",
					zap.String("kernel-info", statusMessage.String()),
					zap.Duration("time-elapsed", time.Since(st)))
				break
			}

			conn.sugaredLogger.Errorf("Attempt %d/%d to request-info from kernel %s FAILED. Error: %s",
				attempt, maxNumTries, conn.kernelId, err)
			conn.tryCallOnError(err)

			if attempt < maxNumTries {
				// Scale in float nanoseconds *before* converting to time.Duration. The previous form,
				// time.Duration(1.25*float64(n)) * time.Second, truncated 1.25*n to an integer first,
				// so the intended 1.25s/2.5s back-off silently became 1s/2s. No sleep after the final
				// attempt, since we are about to give up.
				time.Sleep(time.Duration(1.25 * float64(attempt) * float64(time.Second)))
			}
		}

		if !success {
			conn.logger.Error("Failed to issue \"kernel_info_request\" message.",
				zap.String("kernel_id", conn.kernelId),
				zap.Int("num_attempts", maxNumTries))

			conn.connectionStatus = KernelDisconnected

			if err == nil {
				err = fmt.Errorf("failed to issue \"kernel_info_request\" message")
			}

			conn.tryCallOnError(err)
			return err
		}
	}

	conn.sugaredLogger.Debugf("Kernel %s connection status set to '%s'", conn.kernelId, conn.connectionStatus)
	return nil
}

// setupWebsocket sets up the WebSocket connection to the Jupyter Server.
// Side-effect: updates the BasicKernelConnection's `webSocket` field.
//
// Steps:
//  1. set status to KernelConnecting;
//  2. dial ws://<jupyterServerAddress>/api/kernels/<kernel_id>/channels?session_id=<client_id>;
//  3. install websocketClosed as the close handler, remembering the socket's original handler the first time;
//  4. start the serveMessages reader goroutine;
//  5. set status to KernelConnected, which issues a kernel_info_request.
//
// Only one setup may run at a time; a concurrent caller gets ErrSetupInProgress. URL-construction and
// dial failures are returned wrapped with ErrWebsocketCreationFailed. All failures are also reported
// through tryCallOnError.
func (conn *BasicKernelConnection) setupWebsocket(jupyterServerAddress string) error {
	if !conn.setupInProgress.CompareAndSwap(0, 1) {
		conn.logger.Warn("Cannot setup WebSocket. Another setup procedure is already underway.")
		return ErrSetupInProgress
	}
	// Release the guard on every exit path. Previously it was only cleared on full success, so a single
	// failed setup made every later call (including reconnect's retries) fail with ErrSetupInProgress.
	defer conn.setupInProgress.Store(0)

	if conn.webSocket != nil {
		// NOTE(review): the previous socket is dropped without being closed. Closing it here could make an
		// older serveMessages goroutine mark the new connection dead, so this is left as-is. Confirm intent.
		conn.logger.Warn("Existing WebSocket found. Recreating anyway.", zap.String("kernel_id", conn.kernelId))
		conn.webSocket = nil
	}

	originalStatus := conn.connectionStatus
	if err := conn.updateConnectionStatus(KernelConnecting); err != nil {
		conn.logger.Error("Failed to set kernel connection status.",
			zap.String("initial_status", originalStatus.String()),
			zap.String("target_status", KernelConnecting.String()),
			zap.String("kernel_id", conn.kernelId))
		conn.tryCallOnError(err)
		return err
	}

	wsUrl := "ws://" + jupyterServerAddress
	idUrl := url.PathEscape(conn.kernelId)

	partialUrl, err := url.JoinPath(wsUrl, kernelServiceApi, idUrl)
	if err != nil {
		conn.logger.Error("Error when creating partial URL.", zap.String("wsUrl", wsUrl), zap.String("kernelServiceApi", kernelServiceApi), zap.String("idUrl", idUrl), zap.Error(err))
		err = fmt.Errorf("ErrWebsocketCreationFailed %w : %s", ErrWebsocketCreationFailed, err.Error())
		conn.tryCallOnError(err)
		return err
	}

	conn.sugaredLogger.Debugf("Created partial kernel websocket URL: '%s'", partialUrl)
	endpoint := partialUrl + "/" + fmt.Sprintf("channels?session_id=%s", url.PathEscape(conn.clientId))

	conn.sugaredLogger.Debugf("Created full kernel websocket URL: '%s'", endpoint)

	st := time.Now()

	dialer := &websocket.Dialer{
		Proxy:            http.ProxyFromEnvironment,
		HandshakeTimeout: 90 * time.Second,
	}
	ws, _, err := dialer.Dial(endpoint, nil)
	if err != nil {
		conn.logger.Error("Failed to dial kernel websocket.", zap.String("endpoint", endpoint), zap.String("kernel_id", conn.kernelId), zap.Error(err))
		err = fmt.Errorf("ErrWebsocketCreationFailed %w : %s", ErrWebsocketCreationFailed, err.Error())
		conn.tryCallOnError(err)
		return err
	}

	conn.logger.Debug("Successfully connected to the kernel.", zap.Duration("time-taken-to-connect", time.Since(st)), zap.String("kernel_id", conn.kernelId))

	// Install the close handler, which tries to reconnect, *before* the reader goroutine starts, so
	// that a close frame arriving immediately is not handled by the default handler.
	if conn.originalWebsocketCloseHandler == nil {
		conn.originalWebsocketCloseHandler = ws.CloseHandler()
	}
	ws.SetCloseHandler(conn.websocketClosed)

	conn.webSocket = ws

	go conn.serveMessages()

	originalStatus = conn.connectionStatus
	err = conn.updateConnectionStatus(KernelConnected)
	if err != nil {
		conn.logger.Error("Failed to set kernel connection status.",
			zap.String("initial_status", originalStatus.String()),
			zap.String("target_status", KernelConnected.String()),
			zap.String("kernel_id", conn.kernelId))
		conn.tryCallOnError(err)
		return err
	}

	// Skip for now... we may or may not need this.
	// The registration idea was so we could figure out a way to add support for ACKs between the Cluster Gateway and the Golang Jupyter frontends.
	// conn.registerAsGolangFrontend()

	return nil
}

// websocketClosed is installed as the WebSocket close handler by setupWebsocket.
//
// It fetches the kernel model over HTTP to decide what to do:
//   - model fetch fails with ErrNetworkIssue: try to reconnect. Return nil if reconnection succeeded or
//     another reconnection is already underway.
//   - model fetch fails otherwise, or the network reconnection failed: mark the connection KernelDead and
//     delegate to the original close handler.
//   - model reports execution state "dead": store the model, mark KernelDead and delegate to the original handler.
//   - otherwise: store the model and try to reconnect. Return nil on success or if another reconnection is
//     underway, else delegate to the original handler.
//
// Panics if the original close handler was never recorded.
func (conn *BasicKernelConnection) websocketClosed(code int, text string) error {
	if conn.originalWebsocketCloseHandler == nil {
		panic("Original websocket close-handler is not set.")
	}

	conn.logger.Warn("WebSocket::Closed called.", zap.String("kernel_id", conn.kernelId), zap.Int("code", code), zap.String("text", text))
	debug.PrintStack()

	// markDead moves the connection to KernelDead; a failure is logged but otherwise ignored.
	markDead := func() {
		previous := conn.connectionStatus
		if statusErr := conn.updateConnectionStatus(KernelDead); statusErr != nil {
			conn.logger.Error("Failed to set kernel connection status.",
				zap.String("initial_status", previous.String()),
				zap.String("target_status", KernelDead.String()),
				zap.String("kernel_id", conn.kernelId))
		}
	}

	model, modelErr := conn.getKernelModel()
	if modelErr != nil {
		conn.logger.Error("Exception encountered while trying to retrieve kernel model.",
			zap.String("kernel_id", conn.kernelId), zap.Error(modelErr))

		if errors.Is(modelErr, ErrNetworkIssue) {
			reconnected, attempted := conn.reconnect()
			// Not attempted: another reconnection owns the recovery. Reconnected: nothing left to do.
			if !attempted || reconnected {
				return nil
			}
		}

		// Not a network error, or it was and reconnecting failed.
		markDead()
		return conn.originalWebsocketCloseHandler(code, text)
	}

	conn.model = model

	if model.ExecutionState == string(KernelDead) {
		markDead()
		return conn.originalWebsocketCloseHandler(code, text)
	}

	// The kernel is still alive on the server side, so try to reconnect to it.
	reconnected, attempted := conn.reconnect()
	if !attempted || reconnected {
		return nil
	}
	return conn.originalWebsocketCloseHandler(code, text)
}

// reconnect attempts to reconnect to the kernel by re-running setupWebsocket.
//
// The first boolean returned indicates whether the reconnection was successful.
// The second boolean returned indicates whether the reconnection was attempted; it is false only when
// another reconnect call is already underway, in which case this call returns immediately.
//
// Failures wrapping ErrNetworkIssue are retried up to five times. Before each retry the status is set to
// KernelDisconnected, followed by a sleep of 2s, 4s, 6s, .... Any other setup error marks the connection
// KernelDead and gives up at once. Exhausting the retries returns (false, true).
func (conn *BasicKernelConnection) reconnect() (bool, bool) {
	if !conn.reconnectionInProgress.CompareAndSwap(0, 1) {
		conn.logger.Warn("Cannot attempt to reconnect. Another reconnection attempt is already underway.", zap.String("kernel_id", conn.kernelId))
		return false /* reconnection failed */, false /* we did not try to reconnect */
	}
	// Clear the guard on every exit path. Previously it was only cleared after the retry budget ran out,
	// so after any successful or fatally-failed attempt every later reconnect() was refused.
	defer conn.reconnectionInProgress.Store(0)

	const maxNumTries = 5

	conn.logger.Warn("Attempting to reconnect to kernel.", zap.String("kernel_id", conn.kernelId))

	for numTries := 1; numTries <= maxNumTries; numTries++ {
		err := conn.setupWebsocket(conn.jupyterServerAddress)
		if err == nil {
			return true /* reconnection succeeded */, true /* we did try to reconnect */
		}

		if !errors.Is(err, ErrNetworkIssue) {
			conn.logger.Error("Connection to kernel is dead.", zap.String("kernel_id", conn.kernelId), zap.Error(err))
			originalStatus := conn.connectionStatus
			if statusErr := conn.updateConnectionStatus(KernelDead); statusErr != nil {
				conn.logger.Error("Failed to set kernel connection status.",
					zap.String("initial_status", originalStatus.String()),
					zap.String("target_status", KernelDead.String()),
					zap.String("kernel_id", conn.kernelId))
			}
			return false /* reconnection failed */, true /* we did try to reconnect */
		}

		sleepInterval := time.Second * time.Duration(2*numTries)
		conn.logger.Error("Network error encountered while trying to reconnect to kernel.", zap.String("kernel_id", conn.kernelId), zap.Error(err), zap.Duration("next-sleep-interval", sleepInterval))
		originalStatus := conn.connectionStatus
		if statusErr := conn.updateConnectionStatus(KernelDisconnected); statusErr != nil {
			conn.logger.Error("Failed to set kernel connection status.",
				zap.String("initial_status", originalStatus.String()),
				zap.String("target_status", KernelDisconnected.String()),
				zap.String("kernel_id", conn.kernelId))
		}

		// Don't sleep after the final attempt; we are about to give up.
		if numTries < maxNumTries {
			time.Sleep(sleepInterval)
		}
	}

	return false /* reconnection failed */, true /* we did try to reconnect */
}

// getKernelModel fetches this kernel's model from the Jupyter Server REST API
// (GET http://<jupyterServerAddress>/api/kernels/<kernel_id>) and decodes it as JSON.
//
// Errors:
//   - request-construction or transport failures are returned as-is and reported via tryCallOnError;
//   - HTTP 404 returns ErrKernelNotFound;
//   - HTTP 503 or 424 returns ErrNetworkIssue, so the caller may retry;
//   - any other non-200 status marks the connection KernelDead and returns an error wrapping
//     ErrUnexpectedFailure;
//   - a body read failure, a JSON decode failure, or a JSON `null` body returns an error.
func (conn *BasicKernelConnection) getKernelModel() (*jupyterKernel, error) {
	conn.logger.Debug("Retrieving kernel model via HTTP Rest API.", zap.String("kernel_id", conn.kernelId))

	address := path.Join(conn.jupyterServerAddress, fmt.Sprintf("/api/kernels/%s", conn.kernelId))
	endpoint := fmt.Sprintf("http://%s", address)
	req, err := http.NewRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		conn.logger.Error("Error encountered while creating HTTP request to get model for kernel.", zap.String("kernel_id", conn.kernelId), zap.String("endpoint", endpoint), zap.Error(err))
		conn.tryCallOnError(err)
		return nil, err
	}

	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		conn.logger.Error("Received error while requesting model for kernel.", zap.String("kernel_id", conn.kernelId), zap.String("endpoint", endpoint), zap.Error(err))
		conn.tryCallOnError(err)
		return nil, err
	}

	// Close the body on every path. It used to be deferred only after the status checks, so the
	// 404 / 503 / 424 / other-error branches leaked the response body and its connection.
	defer func() {
		if closeErr := resp.Body.Close(); closeErr != nil {
			conn.logger.Warn("Error while attempting to close body of HTTP response (for getting a kernel's model).",
				zap.String("kernel_id", conn.kernelId),
				zap.Int("resp_status_code", resp.StatusCode),
				zap.String("resp_status", resp.Status),
				zap.Error(closeErr))
		}
	}()

	switch {
	case resp.StatusCode == http.StatusNotFound:
		conn.logger.Error("Received HTTP 404 when retrieving model for kernel.", zap.String("kernel_id", conn.kernelId))
		conn.tryCallOnError(ErrKernelNotFound)
		return nil, ErrKernelNotFound
	case resp.StatusCode == http.StatusServiceUnavailable /* 503 */ || resp.StatusCode == http.StatusFailedDependency /* 424 */ :
		// Network errors. We should retry.
		// The body is read only for logging; a read failure here doesn't change the outcome.
		associatedMessage, _ := io.ReadAll(resp.Body)
		conn.logger.Warn("Network error encountered while retrieving kernel model.",
			zap.Int("status-code", resp.StatusCode), zap.String("status", resp.Status), zap.String("message", string(associatedMessage)))
		return nil, ErrNetworkIssue
	case resp.StatusCode != http.StatusOK:
		conn.logger.Error("Kernel died unexpectedly.", zap.String("kernel_id", conn.kernelId),
			zap.Int("http-status-code", resp.StatusCode), zap.String("http-status", resp.Status))
		originalStatus := conn.connectionStatus
		if statusErr := conn.updateConnectionStatus(KernelDead); statusErr != nil {
			conn.logger.Error("Failed to set kernel connection status.",
				zap.String("initial_status", originalStatus.String()),
				zap.String("target_status", KernelDead.String()),
				zap.String("kernel_id", conn.kernelId))
			conn.tryCallOnError(statusErr)
		}

		err = fmt.Errorf("ErrUnexpectedFailure %w : HTTP %d -- %s", ErrUnexpectedFailure, resp.StatusCode, resp.Status)
		conn.tryCallOnError(err)
		return nil, err
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		conn.logger.Error("Failed to read response body when requesting model for kernel.", zap.String("kernel_id", conn.kernelId), zap.String("endpoint", endpoint), zap.Error(err))
		return nil, err
	}

	var model *jupyterKernel
	if err = json.Unmarshal(body, &model); err != nil {
		conn.logger.Error("Failed to unmarshal JSON response when requesting model for new kernel.", zap.String("kernel_id", conn.kernelId), zap.String("endpoint", endpoint), zap.Error(err))
		return nil, err
	}

	if model == nil {
		// A literal JSON `null` body decodes without error but leaves model nil. Callers such as
		// websocketClosed (model.ExecutionState) would then dereference a nil pointer.
		err = fmt.Errorf("received null model for kernel %s from %s", conn.kernelId, endpoint)
		conn.logger.Error("Received empty model for kernel.", zap.String("kernel_id", conn.kernelId), zap.String("endpoint", endpoint), zap.Error(err))
		return nil, err
	}

	conn.logger.Debug("Successfully retrieved model for kernel.", zap.String("model", model.String()))
	return model, nil
}

// sendMessage writes message to the kernel WebSocket as JSON.
//
// The message is written only while the connection status is KernelConnected. Just before writing, a
// "sent_at_unix_micro" metadata entry (current Unix time in microseconds) is added. The metadata update
// and the write happen under conn.wlock, so concurrent senders do not interleave frames.
//
// Returns an error wrapping ErrKernelIsDead if the kernel is dead, or one wrapping ErrNotConnected for any
// other non-connected status. A WebSocket write error is returned unchanged.
func (conn *BasicKernelConnection) sendMessage(message KernelMessage) error {
	if conn.connectionStatus == KernelDead {
		conn.logger.Error("Cannot send message. Kernel is dead.", zap.String("kernel_id", conn.kernelId))
		return fmt.Errorf("%w: no connection to kernel \"%s\"", ErrKernelIsDead, conn.kernelId)
	}

	if conn.connectionStatus != KernelConnected {
		conn.sugaredLogger.Errorf("Could not send %s message (ID=%s) of type '%s' now to kernel %s. Kernel is not connected.", message.GetChannel(), message.GetHeader().MessageId, message.GetHeader().MessageType, conn.kernelId)
		return fmt.Errorf("%w: no connection to kernel \"%s\"", ErrNotConnected, conn.kernelId)
	}

	conn.sugaredLogger.Debugf("Writing %s message (ID=%s) of type '%s' now to kernel %s.", message.GetChannel(), message.GetHeader().MessageId, message.GetHeader().MessageType, conn.kernelId)

	conn.wlock.Lock()
	message.AddMetadata("sent_at_unix_micro", time.Now().UnixMicro())
	err := conn.webSocket.WriteJSON(message)
	conn.wlock.Unlock()

	if err != nil {
		// Pass err itself to the printf-style logger. The previous zap.Error(err) argument was a zap.Field
		// struct, which %v rendered as its internal fields instead of the error text.
		conn.sugaredLogger.Errorf("Error while writing %s message (ID=%s) of type '%s' now to kernel %s. Error: %v", message.GetChannel(), message.GetHeader().MessageId, message.GetHeader().MessageType, conn.kernelId, err)
		return err
	}

	conn.sugaredLogger.Debugf("Successfully sent %s message %s of type %s to kernel %s.", message.GetChannel(), message.GetHeader().MessageId, message.GetHeader().MessageType, conn.kernelId)
	return nil
}
