/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * license agreements; and to You under the Apache License, version 2.0:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * This file is part of the Apache Pekko project, which was derived from Akka.
 */

/*
 * Copyright (C) 2014 - 2016 Softwaremill <https://softwaremill.com>
 * Copyright (C) 2016 - 2020 Lightbend Inc. <https://www.lightbend.com>
 */

package org.apache.pekko.kafka.internal

import org.apache.pekko
import pekko.actor.ActorRef
import pekko.annotation.InternalApi
import pekko.kafka.scaladsl.PartitionAssignmentHandler
import pekko.kafka.javadsl
import pekko.kafka.{ AutoSubscription, RestrictedConsumer, TopicPartitionsAssigned, TopicPartitionsRevoked }
import pekko.stream.stage.AsyncCallback
import pekko.util.ccompat.JavaConverters._
import org.apache.kafka.common.TopicPartition

/**
 * Internal API.
 *
 * Implementations of [[PartitionAssignmentHandler]] for internal use, plus [[chain]] to compose two handlers.
 */
@InternalApi
object PartitionAssignmentHelpers {

  /**
   * A handler that ignores every callback.
   *
   * [[chain]] treats this object as a neutral element: chaining it with another handler returns the other
   * handler unchanged instead of wrapping both in a [[Chain]].
   */
  @InternalApi
  object EmptyPartitionAssignmentHandler extends PartitionAssignmentHandler {
    override def onRevoke(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()

    override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()

    override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()

    override def onStop(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()

    override def toString: String = "EmptyPartitionAssignmentHandler"
  }

  /**
   * Adapts a Java DSL [[javadsl.PartitionAssignmentHandler]] to the Scala DSL [[PartitionAssignmentHandler]].
   *
   * Each callback converts the Scala `Set` of partitions with `asJava` and forwards it, together with the
   * unchanged `consumer`, to the same-named method of `handler`.
   *
   * @param handler the Java DSL handler receiving every callback
   */
  @InternalApi
  final case class WrappedJava(handler: javadsl.PartitionAssignmentHandler) extends PartitionAssignmentHandler {
    override def onRevoke(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
      handler.onRevoke(revokedTps.asJava, consumer)

    override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
      handler.onAssign(assignedTps.asJava, consumer)

    override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
      handler.onLost(lostTps.asJava, consumer)

    override def onStop(currentTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
      handler.onStop(currentTps.asJava, consumer)

    override def toString: String = s"WrappedJava($handler)"
  }

  /**
   * Forwards partition assignment changes to the subscription's rebalance listener and to async callbacks.
   *
   * On assign/revoke, the subscription's rebalance listener (if present) is told a [[TopicPartitionsAssigned]] /
   * [[TopicPartitionsRevoked]] message with `sourceActor` as sender. This happens even when the partition set is
   * empty. The matching async callback, by contrast, is only invoked for a non-empty set.
   * Lost partitions are handled exactly like revoked ones; `onStop` does nothing.
   *
   * @param subscription        the subscription whose `rebalanceListener` is notified and which is embedded in the messages
   * @param sourceActor         sender used when telling the rebalance listener
   * @param partitionAssignedCB invoked with newly assigned partitions (only when non-empty)
   * @param partitionRevokedCB  invoked with revoked or lost partitions (only when non-empty)
   */
  @InternalApi
  final class AsyncCallbacks(subscription: AutoSubscription,
      sourceActor: ActorRef,
      partitionAssignedCB: AsyncCallback[Set[TopicPartition]],
      partitionRevokedCB: AsyncCallback[Set[TopicPartition]])
      extends PartitionAssignmentHandler {

    override def onRevoke(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = {
      // The listener is informed unconditionally; only the stage callback is skipped for an empty set.
      subscription.rebalanceListener.foreach {
        _.tell(TopicPartitionsRevoked(subscription, revokedTps), sourceActor)
      }
      if (revokedTps.nonEmpty) {
        partitionRevokedCB.invoke(revokedTps)
      }
    }

    override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = {
      // The listener is informed unconditionally; only the stage callback is skipped for an empty set.
      subscription.rebalanceListener.foreach {
        _.tell(TopicPartitionsAssigned(subscription, assignedTps), sourceActor)
      }
      if (assignedTps.nonEmpty) {
        partitionAssignedCB.invoke(assignedTps)
      }
    }

    // Lost partitions are reported through the same path as revoked partitions.
    override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
      onRevoke(lostTps, consumer)

    override def onStop(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()

    override def toString: String = s"AsyncCallbacks($subscription, $sourceActor)"
  }

  /**
   * Calls `handler1` and then `handler2` for every callback, passing both the same arguments.
   *
   * The calls are sequential with no exception handling. If `handler1` throws, `handler2` is not invoked for
   * that event.
   *
   * @param handler1 handler invoked first
   * @param handler2 handler invoked second
   */
  @InternalApi
  final class Chain(handler1: PartitionAssignmentHandler, handler2: PartitionAssignmentHandler)
      extends PartitionAssignmentHandler {
    override def onRevoke(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = {
      handler1.onRevoke(revokedTps, consumer)
      handler2.onRevoke(revokedTps, consumer)
    }

    override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = {
      handler1.onAssign(assignedTps, consumer)
      handler2.onAssign(assignedTps, consumer)
    }

    override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = {
      handler1.onLost(lostTps, consumer)
      handler2.onLost(lostTps, consumer)
    }

    override def onStop(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = {
      handler1.onStop(revokedTps, consumer)
      handler2.onStop(revokedTps, consumer)
    }

    override def toString: String = s"Chain($handler1, $handler2)"
  }

  /**
   * Composes two handlers so that `handler1` runs before `handler2` on every callback.
   *
   * If either argument is [[EmptyPartitionAssignmentHandler]], the other argument is returned as-is, so no
   * [[Chain]] wrapper is allocated.
   *
   * @return `handler2` if `handler1` is the empty handler, else `handler1` if `handler2` is the empty handler,
   *         otherwise a new [[Chain]] of both
   */
  def chain(handler1: PartitionAssignmentHandler, handler2: PartitionAssignmentHandler): PartitionAssignmentHandler =
    // Reference equality: the empty handler is a singleton, and `==` would delegate to handler1's
    // (possibly user-defined) `equals`.
    if (handler1 eq EmptyPartitionAssignmentHandler) handler2
    else if (handler2 eq EmptyPartitionAssignmentHandler) handler1
    else new Chain(handler1, handler2)

}
