13.3 Upcalls and Callbacks

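Learn to create upcalls: native function pointers that invoke Java code, so Java methods can be passed to native code as callbacks. Together with downcalls, this enables bidirectional communication between Java and native code.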

Understanding Upcalls

Upcall Basics:

import java.lang.foreign.*;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

public class UpcallBasics {
    /**
     * Java method exposed to native code: logs its argument and returns it doubled.
     *
     * @param value the int supplied by the native caller
     * @return {@code value} multiplied by two
     */
    public static int simpleCallback(int value) {
        String message = "Callback invoked with: " + value;
        System.out.println(message);
        return 2 * value;
    }

    /**
     * Walks through the three steps of building an upcall stub for
     * {@link #simpleCallback(int)} and prints the resulting function-pointer address.
     */
    public void demonstrateUpcall() throws Throwable {
        // The C-side view of the callback: int (*callback)(int)
        FunctionDescriptor nativeSignature =
            FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT);
        // The Java-side view of the same function
        MethodType javaSignature = MethodType.methodType(int.class, int.class);

        try (Arena scope = Arena.ofConfined()) {
            MethodHandle target = MethodHandles.lookup()
                .findStatic(UpcallBasics.class, "simpleCallback", javaSignature);

            // The stub is a native function pointer that forwards to `target`;
            // it stays valid only while `scope` is open.
            MemorySegment stub = Linker.nativeLinker()
                .upcallStub(target, nativeSignature, scope);

            System.out.println("Upcall stub address: 0x" + 
                Long.toHexString(stub.address()));

            // Native code may call `stub` from here until the arena closes.
        }
    }
}

QSort Callback Example

Implementing a comparator for qsort:

public class QSortExample {
    private static final Linker LINKER = Linker.nativeLinker();
    private static final SymbolLookup STDLIB = LINKER.defaultLookup();

    // Comparator function signature
    // C: int compare(const void *a, const void *b);
    private static final FunctionDescriptor COMPARE_DESC = FunctionDescriptor.of(
        ValueLayout.JAVA_INT,
        ValueLayout.ADDRESS,
        ValueLayout.ADDRESS
    );

    // QSort signature
    // C: void qsort(void *base, size_t nmemb, size_t size,
    //               int (*compar)(const void *, const void *));
    private static final FunctionDescriptor QSORT_DESC = FunctionDescriptor.ofVoid(
        ValueLayout.ADDRESS,      // base
        ValueLayout.JAVA_LONG,    // nmemb
        ValueLayout.JAVA_LONG,    // size
        ValueLayout.ADDRESS       // compar (function pointer)
    );

    /**
     * Integer comparator (ascending)
     */
    public static int compareInts(MemorySegment a, MemorySegment b) {
        int valA = a.get(ValueLayout.JAVA_INT, 0);
        int valB = b.get(ValueLayout.JAVA_INT, 0);
        return Integer.compare(valA, valB);
    }

    /**
     * Sort array using qsort
     */
    public void sortArray(int[] array) throws Throwable {
        try (Arena arena = Arena.ofConfined()) {
            // Get qsort function
            MethodHandle qsort = LINKER.downcallHandle(
                STDLIB.find("qsort").orElseThrow(),
                QSORT_DESC
            );

            // Create comparator upcall
            MethodHandle compareHandle = MethodHandles.lookup().findStatic(
                QSortExample.class,
                "compareInts",
                MethodType.methodType(
                    int.class,
                    MemorySegment.class,
                    MemorySegment.class
                )
            );

            MemorySegment comparator = LINKER.upcallStub(
                compareHandle,
                COMPARE_DESC,
                arena
            );

            // Copy array to native memory
            MemorySegment nativeArray = arena.allocateArray(
                ValueLayout.JAVA_INT,
                array
            );

            // Call qsort
            qsort.invoke(
                nativeArray,
                (long) array.length,
                4L,  // sizeof(int)
                comparator
            );

            // Copy sorted array back
            MemorySegment.copy(
                nativeArray, ValueLayout.JAVA_INT, 0,
                array, 0,
                array.length
            );
        }
    }

    /**
     * Descending order comparator
     */
    public static int compareIntsDescending(MemorySegment a, MemorySegment b) {
        int valA = a.get(ValueLayout.JAVA_INT, 0);
        int valB = b.get(ValueLayout.JAVA_INT, 0);
        return Integer.compare(valB, valA);  // Reversed
    }

    public void sortDescending(int[] array) throws Throwable {
        try (Arena arena = Arena.ofConfined()) {
            MethodHandle qsort = LINKER.downcallHandle(
                STDLIB.find("qsort").orElseThrow(),
                QSORT_DESC
            );

            MethodHandle compareHandle = MethodHandles.lookup().findStatic(
                QSortExample.class,
                "compareIntsDescending",
                MethodType.methodType(
                    int.class,
                    MemorySegment.class,
                    MemorySegment.class
                )
            );

            MemorySegment comparator = LINKER.upcallStub(
                compareHandle,
                COMPARE_DESC,
                arena
            );

            MemorySegment nativeArray = arena.allocateArray(
                ValueLayout.JAVA_INT,
                array
            );

            qsort.invoke(nativeArray, (long) array.length, 4L, comparator);

            MemorySegment.copy(
                nativeArray, ValueLayout.JAVA_INT, 0,
                array, 0,
                array.length
            );
        }
    }

    // Example usage
    public static void main(String[] args) throws Throwable {
        QSortExample example = new QSortExample();

        int[] numbers = {5, 2, 8, 1, 9, 3, 7, 4, 6};

        System.out.println("Original: " + java.util.Arrays.toString(numbers));

        example.sortArray(numbers);
        System.out.println("Ascending: " + java.util.Arrays.toString(numbers));

        example.sortDescending(numbers);
        System.out.println("Descending: " + java.util.Arrays.toString(numbers));
    }
}

Signal Handler Registration

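Registering Signal Handlers (illustrative only: running Java code inside a native signal handler is not async-signal-safe, and the JVM installs handlers for several signals itself, so do not use this pattern in production):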

public class SignalHandlerExample {
    private static final Linker LINKER = Linker.nativeLinker();
    private static final SymbolLookup STDLIB = LINKER.defaultLookup();

    // SIG_ERR is ((void (*)(int)) -1): an all-bits-set pointer, i.e. -1 as a
    // 64-bit address. signal() returns it on failure.
    private static final long SIG_ERR = -1L;

    // Signal handler signature
    // C: void handler(int signum);
    private static final FunctionDescriptor HANDLER_DESC = FunctionDescriptor.ofVoid(
        ValueLayout.JAVA_INT
    );

    // signal() signature
    // C: void (*signal(int signum, void (*handler)(int)))(int);
    private static final FunctionDescriptor SIGNAL_DESC = FunctionDescriptor.of(
        ValueLayout.ADDRESS,      // Return: previous handler, or SIG_ERR on failure
        ValueLayout.JAVA_INT,     // signum
        ValueLayout.ADDRESS       // handler
    );

    /**
     * Signal handler callback: prints the signal number and its name.
     *
     * <p>Illustrative only. It builds strings and writes to the console from
     * inside a native signal handler, and neither is async-signal-safe.
     *
     * @param signum the signal number delivered by the OS
     */
    public static void handleSignal(int signum) {
        System.out.println("Received signal: " + signum);

        // Map signal numbers to names. 2 and 15 are the usual values everywhere;
        // 10 and 12 are the Linux (x86/ARM) numbers for SIGUSR1/SIGUSR2 -- other
        // systems differ (macOS uses 30/31).
        String name = switch (signum) {
            case 2 -> "SIGINT";
            case 15 -> "SIGTERM";
            case 10 -> "SIGUSR1";
            case 12 -> "SIGUSR2";
            default -> "UNKNOWN";
        };

        System.out.println("Signal name: " + name);
    }

    /**
     * Installs {@link #handleSignal(int)} as the C-level handler for a signal.
     *
     * <p>The handler replaces whatever was installed before, including handlers
     * the JVM relies on (for example for SIGINT/SIGTERM shutdown hooks).
     *
     * @param signalNumber the signal to handle
     * @param arena        owns the upcall stub; it must stay open as long as the
     *                     handler remains installed, or the OS will later call a
     *                     freed function pointer
     * @return the previously installed handler (as returned by {@code signal()})
     * @throws IllegalArgumentException if {@code signal()} rejects the request
     *         (returns SIG_ERR), e.g. for an invalid or uncatchable signal
     */
    public MemorySegment registerHandler(int signalNumber, Arena arena) throws Throwable {
        // Create handler upcall
        MethodHandle handler = MethodHandles.lookup().findStatic(
            SignalHandlerExample.class,
            "handleSignal",
            MethodType.methodType(void.class, int.class)
        );

        MemorySegment handlerStub = LINKER.upcallStub(
            handler,
            HANDLER_DESC,
            arena
        );

        // Get signal() function
        MethodHandle signal = LINKER.downcallHandle(
            STDLIB.find("signal").orElseThrow(),
            SIGNAL_DESC
        );

        // Register handler
        MemorySegment previousHandler = (MemorySegment) signal.invoke(
            signalNumber,
            handlerStub
        );

        // Surface failures instead of returning SIG_ERR as if it were a handler.
        if (previousHandler.address() == SIG_ERR) {
            throw new IllegalArgumentException(
                "signal() failed for signal " + signalNumber);
        }

        System.out.println("Registered handler for signal " + signalNumber);

        return previousHandler;
    }
}

Callback State Management

Capturing State in Callbacks:

public class StatefulCallback {
    /**
     * Stateful callback target: each invocation bumps a per-instance counter.
     *
     * <p>{@code increment} is not synchronized, so concurrent native callers
     * could race on {@code count}.
     */
    public static class Counter {
        private int count = 0;

        /**
         * Records one invocation and logs the value the native side passed.
         *
         * @param value argument supplied by the native caller
         * @return how many times this counter has been invoked so far, this call included
         */
        public int increment(int value) {
            int invocation = ++count;
            System.out.println("Callback #" + invocation + " with value: " + value);
            return invocation;
        }
    }

    /**
     * Builds an upcall stub whose target is bound to one specific Counter,
     * showing how a stub can carry Java state.
     */
    public void demonstrateStatefulCallback() throws Throwable {
        Counter counter = new Counter();
        FunctionDescriptor intToInt =
            FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT);

        MethodHandle unbound = MethodHandles.lookup().findVirtual(
            Counter.class, "increment", MethodType.methodType(int.class, int.class));
        // bindTo fixes the receiver, so every native call lands on `counter`
        MethodHandle boundToCounter = unbound.bindTo(counter);

        try (Arena arena = Arena.ofConfined()) {
            MemorySegment stub =
                Linker.nativeLinker().upcallStub(boundToCounter, intToInt, arena);

            // Hand `stub` to native code here; each call it makes updates counter's state.
        }
    }
}

Thread Safety Considerations

Thread-Safe Callbacks:

import java.util.concurrent.atomic.AtomicInteger;

public class ThreadSafeCallback {
    // Shared by every invocation of every stub created below; AtomicInteger keeps
    // the increment correct when native code calls in from several threads at once.
    private static final AtomicInteger callCount = new AtomicInteger(0);

    /**
     * Thread-safe callback: atomically bumps the shared call counter and logs
     * which thread made the call.
     *
     * @param value argument supplied by the native caller
     * @return {@code value} plus this call's 1-based sequence number
     */
    public static int threadSafeCallback(int value) {
        int count = callCount.incrementAndGet();
        // threadId() replaces Thread.getId(), which is deprecated since Java 19.
        System.out.println("[Thread " + Thread.currentThread().threadId() + 
                          "] Callback #" + count + " with value: " + value);
        return value + count;
    }

    /**
     * Creates an upcall stub for {@link #threadSafeCallback(int)} that can be
     * invoked from multiple native threads.
     *
     * @param arena owns the returned stub; the stub is only valid while it is open
     * @return a native function pointer with C signature {@code int (*)(int)}
     */
    public MemorySegment createThreadSafeCallback(Arena arena) throws Throwable {
        Linker linker = Linker.nativeLinker();

        MethodHandle callback = MethodHandles.lookup().findStatic(
            ThreadSafeCallback.class,
            "threadSafeCallback",
            MethodType.methodType(int.class, int.class)
        );

        FunctionDescriptor desc = FunctionDescriptor.of(
            ValueLayout.JAVA_INT,
            ValueLayout.JAVA_INT
        );

        return linker.upcallStub(callback, desc, arena);
    }
}

Error Handling in Upcalls

Safe Error Handling:

public class SafeUpcalls {
    /**
     * Callback with error handling: doubles non-negative values, and reports
     * any failure as the error code -1 instead of letting it escape.
     *
     * <p>Catches {@link Throwable}, not just {@link Exception}. An {@link Error}
     * (e.g. StackOverflowError) escaping an upcall target terminates the JVM just
     * like an exception would.
     *
     * @param value input from the native caller; negative values are rejected
     * @return {@code value * 2}, or -1 on any failure
     */
    public static int safeCallback(int value) {
        try {
            // Potentially failing operation
            if (value < 0) {
                throw new IllegalArgumentException("Negative value: " + value);
            }

            return processValue(value);

        } catch (Throwable t) {
            // Log error but don't let it escape to native code
            System.err.println("Callback error: " + t.getMessage());
            t.printStackTrace();

            // Return error code
            return -1;
        }
    }

    private static int processValue(int value) {
        // Processing logic
        return value * 2;
    }

    /**
     * Callback that never throws: any failure of the risky operation
     * (including division by zero) yields 0.
     *
     * @param value input from the native caller
     * @return {@code 100 / value}, or 0 on failure
     */
    public static int noThrowCallback(int value) {
        try {
            return riskyOperation(value);
        } catch (Throwable t) {
            System.err.println("Unexpected error in callback: " + t);
            return 0;  // Safe default
        }
    }

    private static int riskyOperation(int value) {
        return 100 / value;  // Could divide by zero
    }
}

Real-World Example: Event System

import java.lang.foreign.*;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.*;
import java.util.concurrent.CopyOnWriteArrayList;

public class EventSystem implements AutoCloseable {
    private final Linker linker = Linker.nativeLinker();
    // Owns every upcall stub created by registerWithNativeLibrary; a shared arena
    // can be used and closed from any thread.
    private final Arena arena = Arena.ofShared();
    private final List<EventListener> listeners = new CopyOnWriteArrayList<>();
    // Guarded by `this` (registerWithNativeLibrary and close are synchronized).
    private final Map<String, MemorySegment> registeredCallbacks = new HashMap<>();
    // Guarded by `this`; makes close() idempotent.
    private boolean closed = false;

    /** Receives events delivered from native code. */
    @FunctionalInterface
    public interface EventListener {
        /**
         * @param eventType numeric event type supplied by native code
         * @param eventData decoded event payload, or null if native code passed NULL
         */
        void onEvent(int eventType, String eventData);
    }

    // Native event handler signature
    // C: void event_handler(int type, const char *data);
    private static final FunctionDescriptor EVENT_HANDLER_DESC = FunctionDescriptor.ofVoid(
        ValueLayout.JAVA_INT,     // event type
        ValueLayout.ADDRESS       // event data (C string)
    );

    /**
     * Event handler callback: decodes the C string (if any) and forwards it to
     * every registered listener.
     *
     * <p>This method is the target of the upcall stub, so listener failures are
     * caught and logged; one failing listener does not stop the others.
     *
     * @param eventType    event type from native code
     * @param eventDataPtr pointer to a NUL-terminated string, or NULL
     */
    public void handleEvent(int eventType, MemorySegment eventDataPtr) {
        String eventData = null;

        if (eventDataPtr.address() != 0) {
            // Widen the incoming pointer so the string can be read up to its NUL terminator.
            eventData = eventDataPtr.reinterpret(Long.MAX_VALUE).getUtf8String(0);
        }

        // Notify all listeners
        for (EventListener listener : listeners) {
            try {
                listener.onEvent(eventType, eventData);
            } catch (Throwable t) {
                // Throwable, not Exception: an Error escaping an upcall terminates the JVM.
                System.err.println("Listener error: " + t);
            }
        }
    }

    /**
     * Register event listener
     */
    public void addEventListener(EventListener listener) {
        listeners.add(listener);
    }

    /**
     * Remove event listener
     */
    public void removeEventListener(EventListener listener) {
        listeners.remove(listener);
    }

    /**
     * Registers {@link #handleEvent} with a native registration function.
     *
     * <p>The function is looked up via {@link SymbolLookup#loaderLookup()}, i.e.
     * in libraries loaded through this class's loader. It is assumed to have the
     * C signature {@code void f(void (*handler)(int, const char*))}.
     *
     * @param libraryFunction name of the native registration function
     * @throws IllegalStateException if already registered with that function,
     *         or if this EventSystem has been closed
     */
    public synchronized void registerWithNativeLibrary(String libraryFunction) throws Throwable {
        if (closed) {
            throw new IllegalStateException("EventSystem is closed");
        }
        if (registeredCallbacks.containsKey(libraryFunction)) {
            throw new IllegalStateException("Already registered: " + libraryFunction);
        }

        // Create upcall stub
        MethodHandle callback = MethodHandles.lookup().findVirtual(
            EventSystem.class,
            "handleEvent",
            MethodType.methodType(
                void.class,
                int.class,
                MemorySegment.class
            )
        ).bindTo(this);

        MemorySegment callbackStub = linker.upcallStub(
            callback,
            EVENT_HANDLER_DESC,
            arena
        );

        // Call native registration function
        // C: void register_event_handler(void (*handler)(int, const char*));
        SymbolLookup lookup = SymbolLookup.loaderLookup();
        MethodHandle registerFunc = linker.downcallHandle(
            lookup.find(libraryFunction).orElseThrow(),
            FunctionDescriptor.ofVoid(ValueLayout.ADDRESS)
        );

        registerFunc.invoke(callbackStub);

        registeredCallbacks.put(libraryFunction, callbackStub);
        System.out.println("Registered event handler with: " + libraryFunction);
    }

    /**
     * Removes all listeners and frees every upcall stub created by this instance.
     *
     * <p>Native code that still holds one of those stubs must not call it after
     * this returns, so unregister on the native side first. Calling close() more
     * than once has no further effect.
     */
    @Override
    public synchronized void close() {
        if (closed) {
            return;
        }
        closed = true;
        listeners.clear();
        registeredCallbacks.clear();
        arena.close();
    }

    // Example usage
    public static void main(String[] args) {
        // try-with-resources closes the arena even if event handling fails
        try (EventSystem eventSystem = new EventSystem()) {
            // Add listeners
            eventSystem.addEventListener((type, data) -> {
                System.out.println("Listener 1 - Type: " + type + ", Data: " + data);
            });

            eventSystem.addEventListener((type, data) -> {
                System.out.println("Listener 2 - Type: " + type + ", Data: " + data);
            });

            // Simulate native event
            try (Arena testArena = Arena.ofConfined()) {
                MemorySegment testData = testArena.allocateUtf8String("Test Event");
                eventSystem.handleEvent(100, testData);
            }
        }
    }
}

Callback Patterns

Pattern 1: One-Time Callback:

public class OneTimeCallback {
    /**
     * Handler that does its work only on the first invocation; every later
     * call is rejected. {@code handle} is synchronized, so concurrent callers
     * cannot both see the handler as not-yet-executed.
     */
    public static class OnceHandler {
        private boolean executed = false;

        /**
         * Processes {@code value} the first time it is called.
         *
         * @param value argument supplied by the native caller
         * @return {@code value} on the first call, -1 on every subsequent call
         *         (a first call with -1 is therefore indistinguishable from a rejection)
         */
        public synchronized int handle(int value) {
            if (!executed) {
                executed = true;
                System.out.println("Executing once with: " + value);
                return value;
            }
            return -1;  // Already executed
        }
    }
}

Pattern 2: Timeout Callback:

public class TimeoutCallback {
    /**
     * Callback that processes values only until a timeout, measured from
     * construction, has elapsed.
     */
    public static class TimedHandler {
        // System.nanoTime() is monotonic, so wall-clock adjustments (NTP, DST,
        // manual changes) cannot expire the handler early or keep it alive too long.
        private final long startNanos;
        private final long timeoutNanos;

        /**
         * @param timeoutMillis how long the handler stays active, in milliseconds.
         *        Values of zero or less expire (almost) immediately; very large
         *        values effectively never expire.
         */
        public TimedHandler(long timeoutMillis) {
            this.startNanos = System.nanoTime();
            // toNanos saturates at Long.MAX_VALUE/MIN_VALUE instead of overflowing,
            // unlike the former `currentTimeMillis() + timeoutMillis`.
            this.timeoutNanos = java.util.concurrent.TimeUnit.MILLISECONDS.toNanos(timeoutMillis);
        }

        /**
         * @param value argument supplied by the native caller
         * @return {@code value} while the handler is active, -1 once it has expired
         */
        public int handle(int value) {
            // Compare elapsed time, not absolute timestamps: only differences
            // between nanoTime() readings are meaningful.
            if (System.nanoTime() - startNanos > timeoutNanos) {
                System.out.println("Callback expired");
                return -1;
            }

            System.out.println("Processing: " + value);
            return value;
        }
    }
}

Pattern 3: Filtering Callback:

public class FilteringCallback {
    /**
     * Callback that processes only the values accepted by a predicate.
     */
    public static class FilterHandler {
        private final java.util.function.Predicate<Integer> filter;

        /**
         * @param filter decides which values are processed; must not be null
         * @throws NullPointerException if {@code filter} is null
         */
        public FilterHandler(java.util.function.Predicate<Integer> filter) {
            // Fail at construction rather than on the first native callback,
            // where an exception escaping the upcall would be fatal.
            this.filter = java.util.Objects.requireNonNull(filter, "filter");
        }

        /**
         * @param value argument supplied by the native caller
         * @return {@code value} if the filter accepts it, otherwise 0
         *         (an accepted 0 is therefore indistinguishable from a rejection)
         */
        public int handle(int value) {
            if (!filter.test(value)) {
                return 0;  // Filtered out
            }

            System.out.println("Accepted: " + value);
            return value;
        }
    }
}

Best Practices

1. Use Appropriate Arena Lifetime:

// Good - native code gives up the callback before its arena closes
try (Arena arena = Arena.ofShared()) {
    MemorySegment callback = linker.upcallStub(handle, desc, arena);
    registerCallback(callback);
    try {
        // Use callback...
    } finally {
        // Native code must stop holding the pointer before the stub is freed
        unregisterCallback(callback);
    }
} // Stub freed here - safe, because nothing native still refers to it

// Bad - arena closes while native code still holds the callback
try (Arena arena = Arena.ofConfined()) {
    MemorySegment callback = linker.upcallStub(handle, desc, arena);
    registerCallback(callback);
} // DANGER: Callback still registered but invalid!

2. Never Let Exceptions Escape (an exception thrown out of an upcall does not reach the native caller; it terminates the JVM):

// Good - catch every Throwable (Errors included) so nothing escapes the upcall
public static int safeCallback(int value) {
    try {
        return processValue(value);
    } catch (Throwable t) {
        System.err.println("Error: " + t);
        return -1;  // Error code the native caller must check for
    }
}

// Bad - an exception thrown here does not propagate to native code;
// an exception escaping an upcall terminates the JVM
public static int unsafeCallback(int value) {
    return processValue(value);  // May throw!
}

3. Document Thread Safety:

/**
 * Thread-safe callback. Being {@code static synchronized}, concurrent
 * invocations from different native threads take turns on the class lock.
 *
 * @param value argument supplied by the native caller
 * @return {@code value}, unchanged
 */
public static synchronized int threadSafeCallback(int value) {
    final int result = value;  // Real work goes here, guarded by the class lock
    return result;
}

4. Clean Up Resources:

// Good - cleanup in finally or try-with-resources
// The arena, and the callback allocated in it, is released when the block exits,
// even if useCallback throws. This is only safe if native code has finished
// using the callback by then.
try (Arena arena = Arena.ofShared()) {
    MemorySegment callback = createCallback(arena);
    useCallback(callback);
} // Automatic cleanup

// Bad - no cleanup
// Nothing ever calls arena.close(), so the memory behind the callback is never released.
Arena arena = Arena.ofShared();
MemorySegment callback = createCallback(arena);
// Memory leak!

5. Test Callbacks Thoroughly:

// Recorded by testMethod; static because the upcall target is a static method.
private static final AtomicInteger callCount = new AtomicInteger(0);
private static final AtomicInteger lastValue = new AtomicInteger();

// The Java method under test, reached from native code through the stub.
private static void testMethod(int value) {
    callCount.incrementAndGet();
    lastValue.set(value);
}

@Test
public void testCallback() throws Throwable {
    callCount.set(0);
    Linker linker = Linker.nativeLinker();
    // Native view of testMethod: void (*)(int)
    FunctionDescriptor desc = FunctionDescriptor.ofVoid(ValueLayout.JAVA_INT);

    MethodHandle callback = MethodHandles.lookup().findStatic(
        getClass(),
        "testMethod",
        MethodType.methodType(void.class, int.class)
    );

    // Create and test callback
    try (Arena arena = Arena.ofConfined()) {
        MemorySegment stub = linker.upcallStub(callback, desc, arena);

        // Exercise the stub the way C code would: a downcall to its address.
        MethodHandle invoker = linker.downcallHandle(stub, desc);
        invoker.invokeExact(42);

        assertEquals(1, callCount.get());
        assertEquals(42, lastValue.get());
    }
}

These patterns enable safe, efficient bidirectional communication between Java and native code.