feat: support for cancelling generations #1124
I was just checking this out, and it looks really promising! I made some edits to the design on my end, compartmentalizing the signal handler into its own object file so it could be reused between the cli and server. I also just had a successful test against sd-server receiving a cancel from a client hangup! I whipped up a quick patch of the changes if that's helpful. The only other thing I can think of that might be useful would be adding an initializer_list to set which signals get captured, but I think SIGUSR1 was a good default choice. Let me know what you think!
sd_cancel.patch
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 2dcd1d5..c1ae3b3 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1,4 +1,8 @@
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
+add_library(signal_handler OBJECT common/signal_handler.cpp)
+target_include_directories(signal_handler PUBLIC ../include)
+
add_subdirectory(cli)
-add_subdirectory(server)
\ No newline at end of file
+add_subdirectory(server)
+
diff --git a/examples/cli/CMakeLists.txt b/examples/cli/CMakeLists.txt
index b30a2e8..f24af51 100644
--- a/examples/cli/CMakeLists.txt
+++ b/examples/cli/CMakeLists.txt
@@ -3,4 +3,5 @@ set(TARGET sd-cli)
add_executable(${TARGET} main.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
\ No newline at end of file
+target_link_libraries(${TARGET} PRIVATE signal_handler)
+target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 503177c..e8e093c 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -474,12 +474,8 @@ bool save_results(const SDCliParams& cli_params,
return sucessful_reults != 0;
}
-#if defined(__unix__) || defined(__APPLE__) || defined(_POSIX_VERSION)
-#define SD_ENABLE_SIGNAL_HANDLER
-static void set_signal_cancel_handler(sd_ctx_t* sd_ctx);
-#else
-#define set_signal_cancel_handler(SD_CTX) ((void)SD_CTX)
-#endif
+#include "common/signal_handler.hpp"
+
int main(int argc, const char* argv[]) {
if (argc > 1 && std::string(argv[1]) == "--version") {
@@ -848,58 +844,3 @@ int main(int argc, const char* argv[]) {
return 0;
}
-#ifdef SD_ENABLE_SIGNAL_HANDLER
-
-#include <atomic>
-#include <csignal>
-#include <thread>
-#include <unistd.h>
-
-// this lock is needed to avoid a race condition between
-// free_sd_ctx and a pending sd_cancel_generation call
-std::atomic_flag signal_lock = ATOMIC_FLAG_INIT;
-static int g_sigint_cnt;
-static sd_ctx_t* g_sd_ctx;
-
-static void sig_cancel_handler(int /* signum */)
-{
- if (!signal_lock.test_and_set(std::memory_order_acquire)) {
- if (g_sd_ctx != nullptr) {
- if (g_sigint_cnt == 1) {
- char msg[] = "\ngot cancel signal, cancelling new generations\n";
- write(2, msg, sizeof(msg)-1);
- /* first signal cancels only the remaining latents on a batch */
- sd_cancel_generation(g_sd_ctx, SD_CANCEL_NEW_LATENTS);
- ++g_sigint_cnt;
- } else {
- char msg[] = "\ngot cancel signal, cancelling everything\n";
- write(2, msg, sizeof(msg)-1);
- /* cancels everything */
- sd_cancel_generation(g_sd_ctx, SD_CANCEL_ALL);
- }
- }
- signal_lock.clear(std::memory_order_release);
- }
-}
-
-static void set_signal_cancel_handler(sd_ctx_t* sd_ctx)
-{
- if (g_sigint_cnt == 0) {
- g_sigint_cnt++;
- struct sigaction sa{};
- sa.sa_handler = sig_cancel_handler;
- sa.sa_flags = SA_RESTART;
- sigaction(SIGUSR1, &sa, nullptr);
- }
-
- while (signal_lock.test_and_set(std::memory_order_acquire)) {
- std::this_thread::yield();
- }
-
- g_sd_ctx = sd_ctx;
-
- signal_lock.clear(std::memory_order_release);
-}
-
-#endif
-
diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt
index d191260..bc2d331 100644
--- a/examples/server/CMakeLists.txt
+++ b/examples/server/CMakeLists.txt
@@ -3,4 +3,5 @@ set(TARGET sd-server)
add_executable(${TARGET} main.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
\ No newline at end of file
+target_link_libraries(${TARGET} PRIVATE signal_handler)
+target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
diff --git a/examples/server/main.cpp b/examples/server/main.cpp
index 0fb10c7..d8d27e6 100644
--- a/examples/server/main.cpp
+++ b/examples/server/main.cpp
@@ -7,6 +7,7 @@
#include <mutex>
#include <sstream>
#include <vector>
+#include <future>
#include "httplib.h"
#include "stable-diffusion.h"
@@ -268,6 +269,8 @@ struct LoraEntry {
std::string path;
};
+#include "common/signal_handler.hpp"
+
int main(int argc, const char** argv) {
if (argc > 1 && std::string(argv[1]) == "--version") {
std::cout << version_string() << "\n";
@@ -346,6 +349,8 @@ int main(int argc, const char** argv) {
[&](const LoraEntry& e) { return e.path == path; });
};
+ set_signal_cancel_handler(sd_ctx);
+
httplib::Server svr;
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
@@ -507,11 +512,20 @@ int main(int argc, const char** argv) {
sd_image_t* results = nullptr;
int num_results = 0;
+ std::future<void> ft = std::async(std::launch::async, [&]()
{
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}
+ );
+
+ std::future_status ft_status;
+ do {
+ if (!ft.valid()) break;
+ ft_status = ft.wait_for(std::chrono::milliseconds(1000));
+ if (req.is_connection_closed()) std::raise(SIGUSR1);
+ } while (ft_status != std::future_status::ready);
for (int i = 0; i < num_results; i++) {
if (results[i].data == nullptr) {
Co-authored-by: donington <jandastroy@gmail.com>
That was very helpful, thanks! And actually, the signal handler just complicates things for this: we can call I'm even tempted to drop the signal handling stuff from this PR, or maybe move it to a separate one, since the hangup handler ended up simpler and more portable.
I literally just got back to the computer and was reading your changes, and this is extremely clean now! I guess I didn't take a close enough look to realize that your cancellation was entirely isolated already when I whipped up my example of using futures.
It's true. You could probably simplify it down to using
Adds an sd_cancel_generation function that can be called asynchronously to interrupt the current generation.
The log handling is still a bit rough around the edges, but I wanted to gather more feedback before polishing it. I've included a flag to allow finer control of what to cancel: everything, or keep and decode already-generated latents but cancel the current and next generations. Would an extra "finish the already started latent but cancel the batch" mode be useful? Or should I simplify it instead, keeping just the cancel-everything mode?
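To make the two modes concrete, here's a minimal sketch of how a batch loop might react to each one. Everything in it (CancelMode, MockContext, mock_cancel, mock_generate_batch) is a hypothetical stand-in for illustration, not the PR's implementation or anything from stable-diffusion.h; only the idea of a "cancel new latents" flag versus a "cancel everything" flag comes from the description above.

```cpp
#include <atomic>
#include <cassert>
#include <vector>

// Hypothetical stand-ins for illustration; not part of stable-diffusion.h.
enum class CancelMode : int { None = 0, NewLatents = 1, All = 2 };

struct MockContext {
    std::atomic<int> cancel{0};  // holds a CancelMode value
};

// Request cancellation. The stored mode only ever escalates
// (None -> NewLatents -> All); a weaker request never downgrades it.
inline void mock_cancel(MockContext& ctx, CancelMode mode) {
    const int wanted = static_cast<int>(mode);
    int seen = ctx.cancel.load(std::memory_order_relaxed);
    while (seen < wanted &&
           !ctx.cancel.compare_exchange_weak(seen, wanted,
                                             std::memory_order_release,
                                             std::memory_order_relaxed)) {
        // `seen` was refreshed by the failed exchange; retry.
    }
}

// Runs `batch_count` mock generations of `steps` steps each and returns the
// indices of the latents that would still be decoded. `on_step(b, s)` is
// invoked after each step, like a progress callback.
template <typename OnStep>
std::vector<int> mock_generate_batch(MockContext& ctx, int batch_count,
                                     int steps, OnStep on_step) {
    assert(batch_count >= 0 && steps >= 0);
    const int none = static_cast<int>(CancelMode::None);
    std::vector<int> finished;
    for (int b = 0; b < batch_count; ++b) {
        bool completed = true;
        for (int s = 0; s < steps; ++s) {
            // Cancellation is only observed between steps.
            if (ctx.cancel.load(std::memory_order_acquire) != none) {
                completed = false;
                break;
            }
            on_step(b, s);
        }
        if (!completed) break;  // drop the current latent and the rest of the batch
        finished.push_back(b);
    }
    // "Cancel everything" also discards latents that were already finished.
    if (ctx.cancel.load(std::memory_order_acquire) ==
        static_cast<int>(CancelMode::All)) {
        finished.clear();
    }
    return finished;
}
```

With a batch of 4 and a cancel requested during the first step of the second latent, the NewLatents mode in this sketch keeps only latent 0, while All keeps nothing.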
The function should be safe to be called from the progress or preview callbacks, a separate thread, or a signal handler.
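For the signal-handler case, here's a rough sketch of the pattern, assuming the cancel request boils down to a store on a lock-free atomic (that's an assumption about how such a call can be kept async-signal-safe, not a description of the PR's internals). The names CancelLevel, on_cancel_signal, install_cancel_handler and current_cancel_level are made up for the example. Each delivered signal bumps the level by one, similar in spirit to the SD_CANCEL_NEW_LATENTS then SD_CANCEL_ALL escalation in the handler from the patch.

```cpp
#include <atomic>
#include <cassert>
#include <csignal>

// Hypothetical cancellation levels for this sketch.
enum CancelLevel : int { kNone = 0, kNewLatents = 1, kAll = 2 };

static std::atomic<int> g_cancel_level{kNone};

// Only lock-free atomics are safe to touch from a signal handler.
static_assert(ATOMIC_INT_LOCK_FREE == 2,
              "this sketch relies on a lock-free std::atomic<int>");

extern "C" void on_cancel_signal(int signum) {
    // Re-install the handler: some platforms reset it to the default
    // action after the first delivery when std::signal is used.
    std::signal(signum, on_cancel_signal);

    // First signal -> kNewLatents, second -> kAll, later ones are no-ops.
    // The load/store pair is not an atomic increment; it assumes this is
    // the only writer while the handler runs.
    int level = g_cancel_level.load(std::memory_order_relaxed);
    if (level < kAll) {
        g_cancel_level.store(level + 1, std::memory_order_release);
    }
}

// Resets the level and routes `signum` to the handler above.
inline void install_cancel_handler(int signum) {
    g_cancel_level.store(kNone, std::memory_order_release);
    auto prev = std::signal(signum, on_cancel_signal);
    assert(prev != SIG_ERR);
    (void)prev;
}

// What a generation loop would poll between steps.
inline int current_cancel_level() {
    return g_cancel_level.load(std::memory_order_acquire);
}
```

The handler does no logging or locking on purpose; anything beyond the atomic update would be better done by the thread that polls current_cancel_level().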
I've included a Unix signal handler on main.cpp just to be able to test it: the first Ctrl+C cancels the batch and the current gen, but still finishes the already generated latents, while a second Ctrl+C cancels everything (although it won't interrupt it in the middle of a generation step anymore).
Edit: included sd-server support for canceling the current generation if the client disconnects.
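The disconnect case could look roughly like the sketch below: run the generation on a worker, poll it with a timeout, and flip a cancel flag once the client is gone. run_with_disconnect_watch and client_gone are hypothetical names, the sleep stands in for a generation step, and in the real server the check would be something like the req.is_connection_closed() call seen in the patch.

```cpp
#include <atomic>
#include <cassert>
#include <chrono>
#include <functional>
#include <future>
#include <thread>

// Hypothetical helper: runs `total_steps` mock steps on a worker thread and
// returns how many completed before the job finished or was cancelled.
inline int run_with_disconnect_watch(int total_steps,
                                     const std::function<bool()>& client_gone) {
    assert(total_steps >= 0);
    std::atomic<bool> cancel{false};

    std::future<int> job = std::async(std::launch::async, [&cancel, total_steps] {
        int done = 0;
        // The worker checks the flag between steps, never mid-step.
        while (done < total_steps && !cancel.load(std::memory_order_acquire)) {
            std::this_thread::sleep_for(std::chrono::milliseconds(5));  // pretend work
            ++done;
        }
        return done;
    });

    // Poll the worker; on each timeout, ask whether the client hung up.
    while (job.wait_for(std::chrono::milliseconds(10)) != std::future_status::ready) {
        if (client_gone()) {
            cancel.store(true, std::memory_order_release);
        }
    }
    return job.get();  // the worker has finished, so `cancel` is no longer in use
}
```

No signal is involved here: the request thread sets the flag directly, so the same approach should carry over to platforms without POSIX signals.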
fixes #1036