Feed m Birds with One Scone: Accelerating Multi-task Gradient Balancing via Bi-level Optimization

This paper introduces MARIGOLD, a unified bi-level optimization framework that uses zeroth-order methods to solve multi-task learning problems efficiently. By dynamically balancing task weights without computing every task's gradient, it overcomes the computational inefficiency of existing MGDA-type approaches.

Xuxing Chen, Yun He, Jiayi Xu, Minhui Huang, Xiaoyi Liu, Boyang Liu, Fei Tian, Xiaohan Wei, Rong Jin, Sem Park, Bo Long, Xue Feng

Published Tue, 10 Ma

Here is an explanation of the paper "Feed m Birds with One Scone" (MARIGOLD) using simple language and creative analogies.

The Big Picture: The "Too Many Cooks" Problem

Imagine you are a chef (the AI model) trying to cook a single meal that satisfies five different guests (the tasks).

  • Guest A wants it spicy.
  • Guest B wants it sweet.
  • Guest C wants it salty.
  • Guest D wants it bland.
  • Guest E wants it crunchy.

If you just follow Guest A's instructions, you ruin the meal for Guest B. If you try to please everyone at once without a plan, you end up with a flavorless, mushy disaster. This is the core problem of Multi-Task Learning (MTL): how do you train one AI to do many different things well without the instructions for one thing messing up the others?
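The gradient conflict can be seen in a tiny toy example. Below, two made-up quadratic "taste" losses (nothing from the paper; the functions and step size are invented for illustration) pull a single shared parameter in opposite directions, so a step that helps one guest provably hurts the other:

```python
# Two hypothetical task losses over one shared parameter w.
def loss_a(w):        # Guest A is happiest at w = +1
    return (w - 1.0) ** 2

def loss_b(w):        # Guest B is happiest at w = -1
    return (w + 1.0) ** 2

def grad_a(w):
    return 2.0 * (w - 1.0)

def grad_b(w):
    return 2.0 * (w + 1.0)

w = 0.0                        # the current "recipe"
w_next = w - 0.1 * grad_a(w)   # follow only Guest A's instructions

print(loss_a(w_next) < loss_a(w))  # True: A is happier...
print(loss_b(w_next) > loss_b(w))  # True: ...but B is worse off
```

This is the whole dilemma in two lines: each task's gradient is locally "right" for that task and locally wrong for another.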

The Old Way: The "All-Hands Meeting" (MGDA)

In the past, the best way to solve this was a method called MGDA (Multiple Gradient Descent Algorithm).

Imagine that to decide what to cook next, the chef calls a meeting with five sous-chefs (the gradients). Each sous-chef holds a clipboard with a detailed report on exactly how the current dish is failing their specific guest.

  • The chef has to read all five reports, compare them, calculate a complex compromise, and then decide on the next step.

The Problem: This is incredibly slow. If you have 100 guests (tasks), the chef has to read 100 reports before taking a single step. In the world of AI, this means the computer has to do a massive amount of math for every single update, making training take forever and requiring huge amounts of memory.
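To make the bottleneck concrete, here is a minimal sketch of the MGDA-style loop. Everything is toy and illustrative: quadratic losses stand in for real networks, an analytic gradient stands in for a backward pass, and a plain average replaces the min-norm weighting MGDA actually computes (which solves a small quadratic program over the gradients). The point is only the counter: the gradient work per step scales with the number of tasks.

```python
import random

m = 100                      # number of tasks ("guests")
dim = 5                      # tiny model for illustration
targets = [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(m)]

def task_grad(w, t):
    # Gradient of 0.5 * ||w - target_t||^2 -- a stand-in for one
    # full backward pass through the model for task t.
    return [wi - ti for wi, ti in zip(w, targets[t])]

w = [0.0] * dim
backward_passes = 0
for step in range(3):
    grads = []
    for t in range(m):       # MGDA must collect ALL m "reports"...
        grads.append(task_grad(w, t))
        backward_passes += 1
    # ...then combine them (here a plain average, as a placeholder for
    # MGDA's min-norm weighting).
    combined = [sum(g[i] for g in grads) / m for i in range(dim)]
    w = [wi - 0.1 * ci for wi, ci in zip(w, combined)]

print(backward_passes)       # 300: 3 steps x 100 tasks
```

With a real model, each of those 300 "reports" is a full backward pass, which is exactly the cost MARIGOLD is designed to avoid.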

The New Solution: MARIGOLD (The "One Scone" Strategy)

The authors of this paper propose a new method called MARIGOLD. Their title, "Feed m Birds with One Scone," is a play on the phrase "kill two birds with one stone." They want to feed many birds (tasks) with just one scone (computation).

Here is how MARIGOLD works, using a simple analogy:

1. The Two-Level Game (Bi-Level Optimization)

Instead of treating the "cooking" and the "planning" as one giant, messy job, MARIGOLD splits them into two levels:

  • The Lower Level (The Cooking): The chef actually cooks the dish (updates the model) based on a current recipe.
  • The Upper Level (The Planning): A manager watches the cooking and asks, "Is this recipe making everyone happy? If not, how should we tweak the recipe weights?"

In the old days, the manager had to wait for the chef to finish cooking, then read all the sous-chefs' reports to adjust the recipe. MARIGOLD makes this a continuous loop where the manager and chef talk constantly.
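The two-level loop can be sketched as follows. This is an illustration of the bi-level idea, not MARIGOLD's actual update rules: the losses, step sizes, and the manager's weight-adjustment rule (shift weight toward whichever task is currently doing worse) are all made up for the example.

```python
def losses(w):                       # two toy tasks with opposite optima
    return [(w - 1.0) ** 2, (w + 1.0) ** 2]

def grads(w):
    return [2.0 * (w - 1.0), 2.0 * (w + 1.0)]

w = 0.5                              # the dish (model parameters)
weights = [0.5, 0.5]                 # the recipe (task weights)
for _ in range(50):
    # Lower level (the cooking): one weighted gradient step on the model.
    g = sum(a * gi for a, gi in zip(weights, grads(w)))
    w -= 0.1 * g
    # Upper level (the planning): nudge weight toward the worse-off task.
    la, lb = losses(w)
    weights[0] += 0.01 * (la - lb)
    weights[1] -= 0.01 * (la - lb)

print(abs(w) < 0.1)                  # True: settles near the fair compromise w = 0
```

Because the two levels alternate every step, the manager never has to wait for the chef to "finish cooking" before adjusting the recipe.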

2. The Magic Trick: "Zeroth-Order" (Feeling the Heat)

This is the most important part. The old methods required the manager to read the exact math reports from every single sous-chef (calculating all gradients).

MARIGOLD uses a trick called Zeroth-Order Optimization. Instead of reading the reports, the manager just feels the result.

  • The Analogy: Imagine the chef is cooking a soup. Instead of asking 5 people to taste it and write down exactly how much salt is needed, the manager just adds a tiny pinch of salt, tastes the soup, and asks, "Is it better or worse?"
  • If it's better, keep going that way. If it's worse, go the other way.

The manager doesn't need to know the exact chemical composition of the soup (the complex gradients of every task). They just need to know if the overall situation improved or got worse. This lets the computer skip the heavy math of reading 100 reports (one backward pass per task) and get by with a roughly constant cost per step: one backward pass for the model plus a few cheap forward "tastes."
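In math terms, the "pinch of salt" is a finite-difference gradient estimate: perturb in a random direction, re-taste, and step based only on the change in a single scalar score. The sketch below applies this to a made-up "overall happiness" score over the task weights; the score function, perturbation size, and step size are all illustrative, not the paper's.

```python
import random

random.seed(0)

def overall_score(weights):
    # Hypothetical stand-in for "how unhappy are all guests overall":
    # lowest when the two weights are balanced at 0.5 each.
    # Crucially, only its VALUE is ever used -- never its gradient.
    return (weights[0] - 0.5) ** 2 + (weights[1] - 0.5) ** 2

weights = [0.9, 0.1]
mu = 1e-3                                  # size of the "pinch of salt"
for _ in range(200):
    u = [random.gauss(0, 1) for _ in weights]          # random direction
    bumped = [wi + mu * ui for wi, ui in zip(weights, u)]
    # Better or worse after the pinch? (forward difference)
    delta = (overall_score(bumped) - overall_score(weights)) / mu
    # Step against the estimated slope along direction u.
    weights = [wi - 0.05 * delta * ui for wi, ui in zip(weights, u)]

print(overall_score(weights) < 0.01)       # True: tasting alone balances the weights
```

Two score evaluations per step, no matter how many tasks there are, is what replaces the m gradient "reports" of the old approach.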

Why is this a Big Deal?

  1. Speed: The old method was like reading a 100-page book before making a decision. MARIGOLD is like glancing at the cover. Per update, it cuts the gradient work from growing with the number of tasks (100 tasks means 100 backward passes) to roughly the cost of a single backward pass, independent of how many tasks you have.
  2. Flexibility: The old methods were picky about how you cooked (they only worked with specific types of math updates). MARIGOLD works with any "chef" (optimizer), including the popular Adam optimizer used in most modern AI.
  3. Real-World Results: The authors tested this on:
    • Public Datasets: Like teaching an AI to recognize objects and depth in images at the same time. MARIGOLD was faster and often more accurate than the old methods.
    • Industrial Data: They tested it on a massive Meta advertising system. Even with millions of users and complex goals, MARIGOLD improved the system's ability to predict clicks and conversions better than the standard "equal weight" approach.

Summary

The Problem: Training AI to do many things at once is slow because it has to check every single task's progress individually, like a teacher grading 100 essays one by one before moving to the next class.

The Solution (MARIGOLD): Instead of grading every essay individually, the teacher takes a quick "sample" of the class's overall performance to decide how to adjust the lesson plan.

The Result: You get the same (or better) quality of learning, but you do it 10 to 100 times faster and with much less computer memory. It's the difference between a slow, heavy truck and a nimble sports car that can carry the same load.