Waiting on results of Replicate predictions using webhooks

almirsarajcic

Created 1 month ago

Have you spent days figuring out how to package a regular Phoenix app with CUDA drivers and keep the Docker image reasonably sized? No? So it was just me?

Anyway, I gave up on that (temporarily 🤕) and decided to invest more time in the core business logic and less in infrastructure. Replicate seemed like the perfect fit: you pick an existing model (or package your own), send a request to create a prediction, and wait for it to finish.

Now, I’m not gonna poll the Replicate API for updates like some peasant. Instead, I let Replicate call a webhook endpoint in my app when the prediction completes, and use Elixir's message passing to hand the result to the process that's waiting for it.
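For Replicate to know where to send that notification, the webhook URL goes into the request that creates the prediction. The sketch below only builds the JSON body for `POST https://api.replicate.com/v1/predictions`, using Replicate's `webhook` and `webhook_events_filter` fields. The module name, `build_prediction_body/3`, and the example URL are hypothetical, and the HTTP call itself (for example with Req, authenticated with your API token as a Bearer token) is left out.

```elixir
defmodule MyApp.ReplicateRequest do
  # Hypothetical helper: builds the body for POST /v1/predictions.
  # "webhook" tells Replicate where to POST status updates, and
  # "webhook_events_filter" limits them to the terminal ("completed") event,
  # so the app isn't sent start/output/logs progress updates.
  def build_prediction_body(version, input, webhook_url)
      when is_binary(version) and is_map(input) and is_binary(webhook_url) do
    %{
      "version" => version,
      "input" => input,
      "webhook" => webhook_url,
      "webhook_events_filter" => ["completed"]
    }
  end
end
```

With the router below, the webhook URL would be your app's public host followed by `/webhook/replicate`.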

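```elixir
# lib/my_app_web/router.ex
# Minimal :webhook pipeline (an assumption; no session or CSRF plugs)
pipeline :webhook do
  plug :accepts, ["json"]
end

scope "/webhook", MyAppWeb do
  pipe_through :webhook

  post "/replicate", WebhookController, :replicate
end
```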

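```elixir
# lib/my_app_web/controllers/webhook_controller.ex
defmodule MyAppWeb.WebhookController do
  use MyAppWeb, :controller

  alias MyApp.WebhookHandler

  # The endpoint parses the JSON body; its top-level keys end up in params
  def replicate(conn, payload) do
    WebhookHandler.handle_webhook(payload)

    conn
    |> put_status(:ok)
    |> json(%{status: "ok"})
  end
end
```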

```elixir
# lib/my_app/webhook_handler.ex
defmodule MyApp.WebhookHandler do
  require Logger

  # Success: forward the output to the process registered under this ID.
  # No unregister here: Registry.unregister/2 only removes entries owned by
  # the calling process, so the waiting process unregisters itself.
  def handle_webhook(%{"id" => prediction_id, "status" => "succeeded", "output" => output}) do
    case Registry.lookup(MyApp.PredictionRegistry, prediction_id) do
      [{pid, _}] ->
        send(pid, {:prediction_completed, prediction_id, output})

      [] ->
        Logger.warning("No process waiting for prediction #{prediction_id}")
    end
  end

  # Failure or cancellation: forward the error the same way
  def handle_webhook(%{"id" => prediction_id, "status" => status, "error" => error})
      when status in ["failed", "canceled"] do
    case Registry.lookup(MyApp.PredictionRegistry, prediction_id) do
      [{pid, _}] ->
        send(pid, {:prediction_failed, prediction_id, error})

      [] ->
        Logger.warning("No process waiting for prediction #{prediction_id}")
    end
  end

  # Ignore non-terminal statuses ("starting", "processing") and unknown payloads
  def handle_webhook(_payload), do: :ok
end
```

We use a Registry to map each prediction ID (the key) to the PID of the process that sent the request to Replicate. When a webhook arrives, the handler looks up the prediction ID and sends the result straight to that process.
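The post doesn't show `register_for_prediction/2` and `unregister_prediction/1`. One plausible way to write them, sketched here rather than taken from the author's code, is a thin wrapper over `Registry.register/3` and `Registry.unregister/2`. Both of those act on the calling process, which is why the waiting process has to register and unregister itself.

```elixir
defmodule MyApp.WebhookHandler do
  # Only the registration helpers are sketched here; handle_webhook/1
  # would live in the same module.
  @registry MyApp.PredictionRegistry

  # Registry.register/3 always registers the calling process, so the pid
  # argument must be self(); the guard makes that requirement explicit.
  def register_for_prediction(prediction_id, pid) when pid == self() do
    case Registry.register(@registry, prediction_id, nil) do
      {:ok, _owner} -> :ok
      {:error, {:already_registered, _other}} -> {:error, :already_registered}
    end
  end

  # Registry.unregister/2 likewise only removes the calling process's entry.
  # Entries are also cleaned up automatically if the waiting process exits.
  def unregister_prediction(prediction_id) do
    Registry.unregister(@registry, prediction_id)
  end
end
```

Because the registry uses `keys: :unique`, a second registration for the same prediction ID returns `{:error, :already_registered}` instead of silently adding another waiter.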

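So, that was the receiving part. Now for the waiting part: start the registry in the application's supervision tree, then have the process that created the prediction register itself and block until the webhook arrives: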

```elixir
# lib/my_app/application.ex (in the children list of start/2)
{Registry, keys: :unique, name: MyApp.PredictionRegistry},
```

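```elixir
# lib/my_app/replicate_client.ex
# Inside MyApp.ReplicateClient, which needs require Logger and
# alias MyApp.WebhookHandler; handle_output/1 is defined elsewhere in it.
defp wait_for_webhook(prediction_id, timeout) do
  # Register this process so the webhook handler can find it by prediction ID
  WebhookHandler.register_for_prediction(prediction_id, self())

  # Block until the handler forwards a result for this prediction (note the
  # pinned ^prediction_id) or the timeout expires; always unregister afterwards
  receive do
    {:prediction_completed, ^prediction_id, output} ->
      WebhookHandler.unregister_prediction(prediction_id)
      {:ok, handle_output(output)}

    {:prediction_failed, ^prediction_id, error} ->
      Logger.error("Prediction failed: #{error}")
      WebhookHandler.unregister_prediction(prediction_id)
      {:error, "Prediction failed: #{error}"}
  after
    timeout ->
      WebhookHandler.unregister_prediction(prediction_id)
      {:error, "Prediction timed out"}
  end
end
```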
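To see the whole round trip without Phoenix or Replicate, the sketch below wires the same pattern together in one file: a `Task` plays the process blocked in `wait_for_webhook/2`, and a plain function call plays the webhook controller. The module, registry, and prediction names are made up for the demo. Here the waiter tells the caller once it has registered. In the real flow the process can only register after Replicate returns the prediction ID, so a prediction that finishes faster than that could in principle be delivered before anyone is listening and land in the "No process waiting" branch.

```elixir
defmodule PredictionRoundTrip do
  # Hypothetical registry name used only for this demo
  @registry DemoPredictionRegistry

  # Waiting side: register under the prediction ID, signal readiness,
  # then block until a result arrives or the timeout expires.
  def wait(prediction_id, notify, timeout) do
    {:ok, _owner} = Registry.register(@registry, prediction_id, nil)
    send(notify, {:registered, prediction_id})

    result =
      receive do
        {:prediction_completed, ^prediction_id, output} -> {:ok, output}
        {:prediction_failed, ^prediction_id, error} -> {:error, error}
      after
        timeout -> {:error, :timeout}
      end

    Registry.unregister(@registry, prediction_id)
    result
  end

  # Webhook side: find the waiting process and forward the output.
  def deliver(%{"id" => id, "status" => "succeeded", "output" => output}) do
    case Registry.lookup(@registry, id) do
      [{pid, _}] ->
        send(pid, {:prediction_completed, id, output})
        :ok

      [] ->
        :no_waiter
    end
  end

  def run do
    {:ok, _} = Registry.start_link(keys: :unique, name: @registry)
    parent = self()
    task = Task.async(fn -> wait("pred-123", parent, 1_000) end)

    # Only simulate the webhook once the waiter is registered
    receive do
      {:registered, "pred-123"} -> :ok
    end

    :ok = deliver(%{"id" => "pred-123", "status" => "succeeded", "output" => ["done"]})
    Task.await(task)
  end
end
```

Calling `PredictionRoundTrip.run()` returns `{:ok, ["done"]}`, and a later delivery for the same ID returns `:no_waiter` because the waiter has already unregistered.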

That’s pretty much it. We don’t send unnecessary polling requests to the Replicate API, and we handle only the events we’re interested in.