Multi-Task Learning with Deep Neural Networks

Abhishek Bais
Nov 26, 2021


Acquiring and assimilating knowledge from multiple domains is a central tenet of human intelligence. For example, when a newborn baby learns to walk, it acquires and assimilates general motor skills. Later in life, it uses and builds on these skills to perform more complex tasks such as playing soccer.

A baby learns general motor skills while learning to walk, then builds on them later in life to perform more complex tasks such as playing soccer

In much the same way, Multi-Task Learning reflects the human learning process more closely than traditional single-task learning does. This Medium article walks through the paper “Multi-Task Learning With Deep Neural Networks: A Survey” by Michael Crawshaw of the Department of Computer Science, George Mason University, which describes the various facets of Multi-Task Learning.

What is Multi-Task Learning?

Multi-Task Learning (MTL) is a branch of machine learning in which multiple tasks are trained simultaneously, using shared representations to learn the common ideas between a collection of related tasks. These shared representations offer advantages such as improved data efficiency, reduced overfitting, and faster learning on downstream tasks, helping to alleviate the large-scale data requirements and computational demands of deep learning. However, learning multiple tasks simultaneously also introduces design and optimization challenges, and even deciding which tasks should be learned jointly is a non-trivial problem.

What does the paper focus on?

Traditionally, MTL methods have been partitioned into two groups: hard parameter sharing and soft parameter sharing. Hard parameter sharing is the practice of sharing model weights between tasks, so that each shared weight is trained to jointly minimize multiple loss functions. Under soft parameter sharing, different tasks have individual task-specific models with separate weights, and the distance between the model parameters of different tasks is added to the joint objective function.
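To make the soft-sharing objective concrete, here is a minimal Python sketch. It assumes each task model's parameters can be flattened into a plain list of floats, and it uses the squared L2 distance as the penalty. The function names and the `penalty_weight` parameter are hypothetical. A real implementation would operate on framework tensors.

```python
# Soft parameter sharing sketch: each task keeps its own parameter vector,
# and pairwise distances between those vectors are added to the joint loss.


def squared_l2_distance(params_a, params_b):
    """Squared Euclidean distance between two equally sized parameter vectors."""
    return sum((a - b) ** 2 for a, b in zip(params_a, params_b))


def soft_sharing_objective(task_losses, task_params, penalty_weight):
    """Sum of task losses plus a weighted penalty on every pair of task models.

    task_losses: one scalar loss per task.
    task_params: one flattened parameter vector per task (all the same length).
    penalty_weight: hypothetical regularization strength (often called lambda).
    """
    total = sum(task_losses)
    for i in range(len(task_params)):
        for j in range(i + 1, len(task_params)):
            total += penalty_weight * squared_l2_distance(task_params[i], task_params[j])
    return total


# Two tasks whose parameter vectors differ by (0.5, 1.0):
# joint loss = 0.8 + 1.2 + 0.1 * (0.25 + 1.0)
objective = soft_sharing_objective([0.8, 1.2], [[1.0, 2.0], [1.5, 1.0]], penalty_weight=0.1)
```

Under hard parameter sharing, all tasks would use the very same parameter vector, so this penalty would always be zero. Soft sharing instead lets the vectors drift apart, but at a cost.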

Expanded Multi-Task Learning Methods

As the field of multi-task methods grows rapidly, the author argues that these two categories alone are no longer broad enough to accurately describe the entire MTL field. The paper therefore widens the scope of MTL methods to cover more ground. The class of hard parameter sharing methods is generalized into Multi-Task Learning Architectures, while soft parameter sharing is broadened into Multi-Task Optimization Methods.

In addition, the author discusses Multi-Task Relationship Learning methods, which find relationships between tasks. These relationships can be used in two ways. They can help choose one or two auxiliary tasks (such as POS tagging, syntactic chunking, and word counting) that support learning a main task (such as named entity recognition or semantic frame detection). They can also guide the transfer of knowledge from pre-trained, similar tasks to a new task.

What are different Multi-Task Learning Architectures?

Multi-Task Architectures

The author partitions MTL architectures into four groups: Architectures for a particular Task Domain, Multi-Modal Architectures, Learned Architectures, and Conditional Architectures.

Task Domain Architectures

Task Domain Architectures are focused on tasks for a single domain such as Computer Vision or Natural Language Processing (NLP).

Architectures for Computer Vision

In the single-domain setting, MTL architectures for computer vision typically partition the network into shared and task-specific components, in a way that allows for generalization through sharing and information flow between tasks. For example, in their TCDCN architecture, (Zhang et al., 2014) build a feature extractor from a series of convolutional layers shared between all tasks. The extracted features are then used as input to task-specific output heads.

Architecture for TCDCN (Zhang et al., 2014)
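The shared-extractor-plus-heads template described above can be sketched in a few lines of plain Python. This is only an illustration, and it rests on a few assumptions. A single fully connected layer stands in for the convolutional feature extractor. The weights and the task names (`landmarks`, `pose`) are hypothetical.

```python
# Hard parameter sharing sketch: one shared feature extractor feeds
# several task-specific output heads.


def linear(weights, bias, x):
    """Compute W @ x + b for a weight matrix given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def relu(v):
    return [max(0.0, a) for a in v]


class SharedTrunkMTL:
    def __init__(self, trunk_w, trunk_b, heads):
        # Shared extractor parameters: used (and trained) by every task.
        self.trunk_w, self.trunk_b = trunk_w, trunk_b
        # heads maps task name -> (weights, bias) used only by that task.
        self.heads = heads

    def forward(self, x):
        features = relu(linear(self.trunk_w, self.trunk_b, x))  # shared features
        return {task: linear(w, b, features) for task, (w, b) in self.heads.items()}


model = SharedTrunkMTL(
    trunk_w=[[1.0, -1.0], [0.5, 0.5]],
    trunk_b=[0.0, 0.0],
    heads={
        "landmarks": ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]),  # two outputs
        "pose": ([[1.0, 1.0]], [0.0]),                        # one output
    },
)
outputs = model.forward([2.0, 1.0])
```

During training, the losses of all heads would be summed and backpropagated, so the shared trunk receives gradients from every task.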

Architectures for Natural Language Processing (NLP)

NLP lends itself well to MTL for two reasons: there are many related questions one can ask about a given piece of text, and modern NLP techniques often use task-agnostic representations. Feed-forward (non-attention-based) architectures, such as the one by (Liu et al., 2015), are an example. They bear a strong structural resemblance to the shared architectures for computer vision discussed in the previous section. Here, the input is converted to a bag-of-words representation and hashed into letter 3-grams, followed by a shared linear transformation and a nonlinear activation function. This shared representation is passed to task-specific heads that compute the final output for each task.

The network architecture of (Liu et al., 2015)
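The letter-3-gram hashing step described above can be sketched as follows. This minimal version assumes whitespace tokenization, lower-casing, and `#` as a word-boundary marker. These are choices made for illustration, not details taken from the survey. The resulting trigram counts would then feed the shared linear layer.

```python
from collections import Counter


def letter_trigrams(word):
    """Split a word, wrapped in boundary markers, into overlapping letter 3-grams."""
    marked = f"#{word.lower()}#"  # '#' boundary marker is an assumption of this sketch
    return [marked[i:i + 3] for i in range(len(marked) - 2)]


def hash_text(text):
    """Bag-of-letter-trigrams representation of a whitespace-tokenized text."""
    counts = Counter()
    for word in text.split():
        counts.update(letter_trigrams(word))
    return counts


vec = hash_text("good good food")
```

Hashing into letter trigrams keeps the input dimension bounded by the number of possible trigrams rather than the vocabulary size. It also lets unseen words share features with known ones.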

Multi-Modal Architectures

While Task Domain Architectures focus on data from a single domain, Multi-Modal Architectures handle tasks using data from multiple modalities, usually a combination of visual and linguistic data. Here, representations are shared across both tasks and modalities, providing another layer of abstraction through which learned representations must generalize. For example, in the OmniNet architecture proposed by (Pramanik et al., 2019), each modality has a separate network to handle its inputs. The aggregated outputs are processed by an encoder-decoder called the Central Neural Processor, whose output is passed to several task-specific output heads.

OmniNet architecture proposed in (Pramanik et al., 2019)

Learned Architectures

An alternative approach to MTL architecture design is to learn the architecture as well as the weights of the resulting model. Two popular ways to do this are Branched Sharing and Modular Sharing.

Branched Sharing

Branched sharing methods are a coarse-grained way to share parameters between tasks: once the computation graphs for two tasks diverge, they never rejoin. For example, in the Branched Sharing Architecture proposed by (Lu et al., 2017), each task shares all layers of the network at the beginning of training. As training goes on, less related tasks branch off into separate clusters, so that only highly related tasks continue to share most of their parameters.

Branched sharing architecture proposed in (Lu et al., 2017)

Modular Sharing

Modular sharing is a more fine-grained approach. A set of neural network modules is shared between tasks, and the architecture for each task is built from a task-specific combination of some or all of the modules. For example, in AdaShare, a modular MTL algorithm proposed by (Sun et al., 2019), each task architecture consists of a sequence of network layers drawn from a shared set. Each layer in that set is either included in or omitted from the network for each task.

AdaShare is a modular sharing scheme proposed by (Sun et al., 2019)
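The select-or-skip idea behind AdaShare can be illustrated with a toy sketch. It assumes that a skipped layer acts as the identity, as when a residual block is bypassed, and it uses hand-written per-task policies. In AdaShare these policies are learned during training. The layers and task names below are hypothetical stand-ins for real network blocks.

```python
# Modular sharing sketch: every task draws from the same ordered list of
# shared layers, and a per-task binary policy decides which layers to run.

shared_layers = [
    lambda x: x + 1.0,  # hypothetical shared layer 0
    lambda x: x * 2.0,  # hypothetical shared layer 1
    lambda x: x - 3.0,  # hypothetical shared layer 2
]

# Hypothetical select (True) / skip (False) policies, fixed by hand here
# instead of being learned as in AdaShare.
policies = {
    "segmentation": [True, True, True],
    "depth": [True, False, True],
}


def run_task(task, x):
    """Execute only the shared layers selected for this task; skipped layers act as identity."""
    for layer, use in zip(shared_layers, policies[task]):
        if use:
            x = layer(x)
    return x
```

Tasks with similar policies end up sharing most of their computation. Diverging policies give a task its own effective architecture without adding any new parameters.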

Conditional Architectures

Conditional Architectures select parts of a neural network for execution depending on the input to the network. These architectures vary between inputs as well as between tasks. However, the components of these dynamically instantiated architectures are shared, which encourages the components to generalize across various inputs and tasks. For example, in the Neural Module Network execution model proposed by (Andreas et al., 2016), the semantic structure of a given question is used to dynamically instantiate a network made of modules that correspond to the elements of the question.

Example Neural Module Network execution (Andreas et al., 2016)
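A toy sketch of conditional execution in the spirit of Neural Module Networks follows. It assumes the question has already been mapped to a layout, so the parsing step is hard-coded. The module names, the symbolic `SCENE`, and its contents are hypothetical replacements for modules that operate on image features.

```python
# Hypothetical symbolic scene standing in for image features.
SCENE = [
    {"name": "dog", "color": "brown"},
    {"name": "cat", "color": "black"},
    {"name": "dog", "color": "white"},
]

# Shared module inventory reused across every question (names are hypothetical).
MODULES = {
    # find[x]: attend to the objects matching a name
    "find": lambda name: [obj for obj in SCENE if obj["name"] == name],
    # count: reduce an attended set to a number
    "count": lambda objs: len(objs),
    # color: describe the colors of the attended objects
    "color": lambda objs: sorted(obj["color"] for obj in objs),
}


def execute(layout):
    """Recursively instantiate and run the modules named in a layout tree."""
    module, *args = layout
    evaluated = [execute(a) if isinstance(a, tuple) else a for a in args]
    return MODULES[module](*evaluated)


# "How many dogs are there?" -> count(find[dog])
answer = execute(("count", ("find", "dog")))
```

A different question, such as "What color are the dogs?", reuses the same `find` module inside a different layout, `("color", ("find", "dog"))`. This reuse is what pushes shared modules to generalize across inputs.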

What are different Multi-Task Optimization methods?

Multi-Task Optimization Methods

MTL Optimization Methods are a broader version of soft parameter sharing, which regularizes model parameters by penalizing their distance from the corresponding parameters of a model for a different but related task. Some common MTL Optimization Methods include Loss Weighting, Task Scheduling, Gradient Modulation, and Knowledge Distillation.

Loss Weighting

A common approach to ease multi-task optimization is to balance the individual loss functions for different tasks. When a model is to be trained on more than one task, the various task-specific loss functions must be combined into a single aggregated loss function which the model is trained to minimize.

One approach to learning loss weights is called Loss Weighting by Uncertainty. Proposed by (Kendall et al., 2017), this approach treats the multi-task network as a probabilistic model and derives a weighted multi-task loss function by maximizing the likelihood of the ground truth output.

Loss weighting by uncertainty (Kendall et al., 2017)
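Below is a minimal sketch contrasting a fixed weighted sum with the uncertainty-based weighting described above. It assumes the regression form of the loss. Each task i has a learnable s_i = log σ_i² and contributes 0.5·exp(−s_i)·L_i + 0.5·s_i. Parametrizing by the log-variance is a common choice for numerical stability. The function names are hypothetical.

```python
import math


def weighted_sum_loss(task_losses, weights):
    """Baseline aggregation: a fixed, hand-chosen weight for each task."""
    return sum(w * loss for w, loss in zip(weights, task_losses))


def uncertainty_weighted_loss(task_losses, log_vars):
    """Uncertainty-based weighting (regression form).

    log_vars[i] is s_i = log(sigma_i ** 2), a trainable parameter in practice.
    Each task contributes L_i / (2 * sigma_i**2) + log(sigma_i), written here as
    0.5 * exp(-s_i) * L_i + 0.5 * s_i.
    """
    return sum(
        0.5 * math.exp(-s) * loss + 0.5 * s
        for loss, s in zip(task_losses, log_vars)
    )
```

For a fixed task loss L_i, the term 0.5·exp(−s)·L_i + 0.5·s is minimized at s = log L_i. Tasks with large, noisy losses are therefore down-weighted automatically. Meanwhile, the 0.5·s term prevents the trivial solution of ignoring a task by making σ_i arbitrarily large.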

Task Scheduling

Task scheduling is the process of choosing which task or tasks to train on at each training step. Most MTL models make this decision in a very simple way: they either train on all tasks at each step or randomly sample a subset of tasks, with some variations on these simple schedulers. For example, in the task scheduler proposed by (Sharma et al., 2017), a meta task-decider is trained to sample tasks. Its training signal encourages tasks with worse relative performance to be chosen more frequently.

Task scheduling visualization from (Sharma et al., 2017)
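The fixed rule below is a simpler, non-learned stand-in for such a scheduler, shown only to make the idea of favoring lagging tasks concrete. It assumes each task has a target score, for example from a single-task baseline, and samples tasks with a softmax over their normalized performance gaps. Sharma et al. instead train a meta task-decider. The names and the `temperature` parameter here are hypothetical.

```python
import math
import random


def task_sampling_probs(current_scores, target_scores, temperature=1.0):
    """Softmax over normalized performance gaps: lagging tasks get higher probability.

    The gap for task i is (target_i - current_i) / target_i, clipped at 0 so every
    task that meets or exceeds its target gets the same minimal gap.
    """
    gaps = [max(0.0, (t - c) / t) for c, t in zip(current_scores, target_scores)]
    exps = [math.exp(g / temperature) for g in gaps]
    total = sum(exps)
    return [e / total for e in exps]


def sample_task(tasks, probs, rng=random):
    """Draw the task to train on at this step."""
    return rng.choices(tasks, weights=probs, k=1)[0]


# Hypothetical scores: "parsing" lags far behind its target, "tagging" does not.
probs = task_sampling_probs(current_scores=[0.9, 0.5], target_scores=[1.0, 1.0])
next_task = sample_task(["tagging", "parsing"], probs, rng=random.Random(0))
```

Lowering `temperature` makes the scheduler focus more sharply on the worst-performing task. Raising it moves the schedule toward uniform sampling.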

Gradient Modulation

One of the main challenges in MTL is negative transfer, which occurs when the joint training of tasks hurts learning instead of helping it. From an optimization perspective, negative transfer manifests as conflicting task gradients. When two tasks have gradients pointing in opposing directions, following the gradient for one task will decrease the performance of the other task. Following the average of the two gradients means that neither task sees the improvement it would get in a single-task training setting.
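The effect of conflicting gradients can be checked with a little arithmetic. The sketch below uses two hypothetical two-dimensional task gradients and measures their cosine similarity. It then uses the dot product between an update direction and a task's gradient as a first-order estimate of how much that task's loss would decrease.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cosine_similarity(u, v):
    """Negative values mean the gradients point in conflicting directions."""
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))


def average(u, v):
    return [(a + b) / 2 for a, b in zip(u, v)]


# Hypothetical gradients of two task losses with respect to shared weights.
grad_a = [1.0, 0.5]
grad_b = [-1.0, 0.5]

similarity = cosine_similarity(grad_a, grad_b)  # (-1 + 0.25) / 1.25
avg_step = average(grad_a, grad_b)              # the first components cancel

# First-order decrease in task A's loss per unit learning rate:
decrease_own = dot(grad_a, grad_a)    # stepping along task A's own gradient
decrease_avg = dot(avg_step, grad_a)  # stepping along the averaged gradient
```

Task B suffers symmetrically. With the averaged step, each task's loss falls by only 0.25 instead of 1.25 per unit learning rate, which matches the intuition described above.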

If a multi-task model is trained on a collection of related tasks, then ideally the gradients from these tasks should point in similar directions. Gradient Adversarial Training (GREAT), proposed by (Sinha et al., 2018), explicitly enforces this condition. Here, an auxiliary network takes the gradient vector of a single task’s loss and tries to classify which task the gradient came from. The network gradients are then modulated to minimize the performance of the auxiliary network. This enforces the condition that gradients from different task losses have statistically indistinguishable distributions.

Multi-task GREAT model (Sinha et al., 2018)

Knowledge Distillation

Knowledge Distillation is used to instill a single multi-task “student” network with the knowledge of many individual single-task “teacher” networks.
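A common way to implement this transfer is to train the student to match each teacher's softened output distribution. The sketch below sums a per-task KL divergence between temperature-scaled teacher and student outputs. It assumes each task is a classification problem with logits. It also omits details that practical setups often add, such as the usual temperature-squared scaling and a ground-truth loss term.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; a higher temperature gives softer distributions."""
    m = max(logits)
    exps = [math.exp((z - m) / temperature) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions with strictly positive q."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def multi_task_distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """Sum over tasks of KL(teacher || student) on softened outputs.

    student_logits: dict task -> logits from the single multi-task student.
    teacher_logits: dict task -> logits from that task's single-task teacher.
    """
    loss = 0.0
    for task, t_logits in teacher_logits.items():
        p_teacher = softmax(t_logits, temperature)
        p_student = softmax(student_logits[task], temperature)
        loss += kl_divergence(p_teacher, p_student)
    return loss
```

Information flows only from teachers to student here: the teachers' logits are fixed targets, and only the student would be updated.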

Most knowledge distillation algorithms have an asymmetric information flow between student and teacher, where information travels from teacher to student, but not the other way around. Interestingly, the performance of the student network has been shown to surpass that of the teacher networks in some domains, making knowledge distillation a desirable method not just for saving memory, but also for increasing performance.

This raises the question “Should teacher networks receive information from the distilled multi-task student network?”

The Distral framework for multi-task reinforcement learning by (Teh et al., 2017) provides a setting with symmetric information flow between student and teacher. Distral is based on two main ideas. First, the single-task policies are regularized by minimizing the KL divergence between each single-task policy and the shared multi-task policy as part of the training objective. Second, the policy for each task is formed by adding the output of the corresponding single-task policy to the output of the shared multi-task policy. Two architectures from the Distral framework are shown in the figure below. The one on the left employs both main ideas, while the one on the right employs only the KL regularization of the single-task policies.

Two architectures from the Distral framework for RL (Teh et al., 2017)
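The two ideas behind Distral can be sketched for a single state with discrete actions. This is a simplified illustration that makes a few assumptions:

- Policies are represented directly by logits.
- The task policy adds task-specific logits to the shared logits.
- Only the KL term is computed. The actual objective also includes the rewards and additional entropy and scaling terms, which are omitted here.

All names are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions with strictly positive q."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def task_policy(shared_logits, task_logits):
    """Idea 2: combine the shared (distilled) policy with a task-specific column."""
    return softmax([h + f for h, f in zip(shared_logits, task_logits)])


def distral_kl_penalty(shared_logits, task_logits_by_task, kl_weight):
    """Idea 1: penalize KL(pi_i || pi_0) between each task policy and the shared policy."""
    pi_0 = softmax(shared_logits)
    return kl_weight * sum(
        kl_divergence(task_policy(shared_logits, f), pi_0)
        for f in task_logits_by_task.values()
    )
```

Because the shared policy appears both in every task policy and in the KL penalty, gradients from each task flow back into it. At the same time, it pulls the task policies toward common behavior. This is the two-way information flow described above.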

What is Multi-Task Relationship Learning?

Multi-Task Relationship Learning

Multi-Task Relationship Learning covers methods that learn explicit representations of tasks or of relationships between tasks, for example by clustering tasks into groups by similarity. These methods then leverage the learned relationships to improve learning on the tasks at hand.

Two areas of research in Multi-Task Relationship Learning are Grouping Tasks and Transfer Relationship Learning. In Grouping Tasks, the goal is to partition a collection of tasks into groups such that training the tasks in each group simultaneously is beneficial. Transfer Relationship Learning includes methods that try to analyze and understand when transferring knowledge from one task to another benefits learning.

Grouping Tasks

Grouping Tasks offers an alternative solution to negative transfer in MTL: if two tasks exhibit negative transfer, their learning should be separated from the start. However, finding such groups can require significant computation, since networks must be trained jointly on various sets of tasks by trial and error. A popular scheme for partitioning a collection of tasks into clusters that each exhibit positive transfer between their members was proposed by (Standley et al., 2019).

An example partitioning of a group of tasks into clusters with positive transfer (Standley et al., 2019)
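As a toy illustration of the grouping objective, the sketch below enumerates every way to partition a small set of tasks. It keeps the partition with the highest total within-group affinity. This is not the method of Standley et al., which measures transfer by actually training networks on task subsets and selects groups under a compute budget. The affinity scores and task names here are invented for illustration, and exhaustive search is only feasible for a handful of tasks.

```python
from itertools import combinations


def set_partitions(items):
    """Yield every partition of a list into non-empty groups."""
    if not items:
        yield []
        return
    first, rest = items[0], items[1:]
    for partition in set_partitions(rest):
        # Put `first` into each existing group in turn...
        for i in range(len(partition)):
            yield partition[:i] + [[first] + partition[i]] + partition[i + 1:]
        # ...or give it a group of its own.
        yield [[first]] + partition


def group_score(group, affinity):
    """Sum of pairwise transfer scores within one group (0 for a singleton)."""
    return sum(affinity[frozenset(pair)] for pair in combinations(group, 2))


def best_grouping(tasks, affinity):
    return max(set_partitions(tasks), key=lambda p: sum(group_score(g, affinity) for g in p))


# Invented pairwise scores: positive = positive transfer, negative = negative transfer.
affinity = {
    frozenset({"seg", "depth"}): 0.3,
    frozenset({"depth", "normals"}): 0.5,
    frozenset({"seg", "normals"}): 0.2,
    frozenset({"seg", "keypoints"}): -0.4,
    frozenset({"depth", "keypoints"}): -0.4,
    frozenset({"normals", "keypoints"}): -0.4,
}
groups = best_grouping(["seg", "depth", "normals", "keypoints"], affinity)
```

With these invented scores, the three mutually helpful tasks end up together, and `keypoints` is trained on its own.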

Transfer Relationship Learning

Transfer learning already plays an important role in Computer Vision and Natural Language Processing. Instead of building a model from scratch, a model pre-trained on a similar task is reused for a new task. Along the same lines, the motivation behind Transfer Relationship Learning is to limit the number of tasks that have access to the full amount of supervised data (the source tasks). The remaining tasks are learned by transferring from the source tasks, using only a small amount of training data to train a decoder on top of the transferred feature extractor.

For example, (Zamir et al., 2018) introduce the Taskonomy dataset, with 4 million images labeled for 26 tasks, along with a computational method to construct a taxonomy of visual tasks based on the transfer relationships between them. To do this, they train a single-task network for each task. They then compute transfer relationships by answering questions such as “How well can we perform task i by training a decoder on top of a feature extractor that was trained on task j?”

Task taxonomies for a collection of computer vision tasks as computed in Taskonomy (Zamir et al., 2018)
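Once transfer scores like those in Taskonomy have been measured, a simple use of them is to pick the best source for each target task. The best source is the one whose frozen feature extractor transfers best. The sketch below does exactly that greedy selection on a hypothetical score table. Taskonomy itself normalizes the raw transfer results and solves a global optimization under a budget on source tasks, so treat this as a simplified illustration.

```python
def best_source_per_target(transfer_scores, sources):
    """Greedily choose the best source task for every target task.

    transfer_scores[(source, target)] is a hypothetical quality score for a
    decoder trained on a small amount of target data on top of a frozen
    feature extractor trained on the source task (higher is better).
    """
    targets = sorted({target for (_, target) in transfer_scores})
    choice = {}
    for target in targets:
        candidates = [s for s in sources if (s, target) in transfer_scores]
        choice[target] = max(candidates, key=lambda s: transfer_scores[(s, target)])
    return choice


# Hypothetical scores for two source tasks and two target tasks.
scores = {
    ("depth", "normals"): 0.82,
    ("segmentation", "normals"): 0.61,
    ("depth", "edges"): 0.55,
    ("segmentation", "edges"): 0.70,
}
chosen = best_source_per_target(scores, sources=["depth", "segmentation"])
```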

What are different Multi-Task Benchmarks?

The author also provides insights into benchmarks commonly used in various domains of MTL, including Computer Vision, Natural Language Processing, Reinforcement Learning, and Multi-Modal problems. Some of these benchmarks are captured below.

Computer Vision Benchmarks

Natural Language Processing Benchmarks

Reinforcement Learning Benchmarks

Multi-Modal Benchmarks


In the paper “Multi-Task Learning With Deep Neural Networks: A Survey”, the author, Michael Crawshaw, presents a review of the field of Multi-Task Learning. The review covers three broad directions (Architectures, Optimization Methods, and Task Relationship Learning) along with key ideas from research in these areas. The author also provides insights into benchmarks commonly used in various domains of Multi-Task Learning.

In addition, the author highlights that despite the progress in developing Multi-Task Learning for deep networks, one direction of research has seen less development than the others: theory. This makes theory an important area of future research for promoting a deeper understanding of MTL with deep neural networks.

The author concludes by reiterating that, to build machines that learn as quickly and robustly as humans, developments in Multi-Task Learning are an important step, since MTL focuses on developing artificial intelligence with more human-like qualities.


