Multi-Task Learning with Deep Neural Networks

A baby learns general motor skills while learning to walk, then augments them and uses them later in life to perform more complex tasks such as playing soccer

What is Multi-Task Learning?

Multi-Task Learning (MTL) is a branch of machine learning in which multiple tasks are trained simultaneously, using shared representations to learn the common ideas among a collection of related tasks. These shared representations offer advantages such as improved data efficiency, reduced overfitting, and faster learning of downstream tasks, helping to alleviate the large-scale data requirements and computational demands of deep learning. However, learning multiple tasks simultaneously presents new design and optimization challenges, and identifying which tasks should be learned jointly is itself a non-trivial problem.

What does the paper focus on?

Traditionally, MTL methods have been partitioned into two groups: hard parameter sharing and soft parameter sharing. Hard parameter sharing is the practice of sharing model weights between tasks, so that each shared weight is trained to jointly minimize multiple loss functions. Under soft parameter sharing, different tasks have individual task-specific models with separate weights, and the distance between the model parameters of different tasks is added to the joint objective function, encouraging those parameters to stay similar. The survey argues that this dichotomy no longer covers the variety of modern MTL methods, and instead organizes the field into an expanded set of categories.
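To make the difference concrete, here is a minimal Python sketch of the two objectives on toy linear models. It assumes squared-error losses and the extreme case where hard sharing shares every weight (in practice usually only a trunk is shared and each task keeps its own head); the function names and the penalty weight `lam` are illustrative, not taken from the survey.

```python
from typing import List, Sequence, Tuple

Example = Tuple[List[float], float]  # (features, target)


def dot(w: Sequence[float], x: Sequence[float]) -> float:
    return sum(wi * xi for wi, xi in zip(w, x))


def mse(w: Sequence[float], data: List[Example]) -> float:
    """Mean squared error of the linear model w on one task's data."""
    return sum((dot(w, x) - y) ** 2 for x, y in data) / len(data)


def hard_sharing_loss(shared_w: Sequence[float], datasets: List[List[Example]]) -> float:
    """Hard sharing: a single weight vector is scored on the sum of all task losses."""
    return sum(mse(shared_w, data) for data in datasets)


def soft_sharing_loss(task_ws: List[List[float]], datasets: List[List[Example]], lam: float) -> float:
    """Soft sharing: each task has its own weights, and the squared L2 distance
    between every pair of task weight vectors is added as a penalty."""
    task_term = sum(mse(w, data) for w, data in zip(task_ws, datasets))
    penalty = 0.0
    for i in range(len(task_ws)):
        for j in range(i + 1, len(task_ws)):
            penalty += sum((a - b) ** 2 for a, b in zip(task_ws[i], task_ws[j]))
    return task_term + lam * penalty


task_a = [([1.0, 0.0], 1.0), ([0.0, 1.0], 0.0)]
task_b = [([1.0, 0.0], 1.0), ([0.0, 1.0], 1.0)]
print(hard_sharing_loss([1.0, 0.5], [task_a, task_b]))  # 0.25
print(soft_sharing_loss([[1.0, 0.0], [1.0, 1.0]], [task_a, task_b], lam=0.1))  # 0.1
```

With one shared vector neither task is fit perfectly, whereas the soft-sharing version lets each task fit its own data and pays only the distance penalty; `lam` controls how strongly the task models are pulled together.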

Expanded Multi-Task Learning Methods

What are different Multi-Task Learning Architectures?

Multi-Task Architectures

Task Domain Architectures

Task Domain Architectures are focused on tasks for a single domain such as Computer Vision or Natural Language Processing (NLP).

Architecture for TCDCN (Zhang et al., 2014)
Network architecture from (Liu et al., 2015)

Multi-Modal Architectures

While Task Domain Architectures are focused on data from a single domain, Multi-Modal Architectures handle tasks using data from multiple domains, usually a combination of visual and linguistic data. Here, representations are shared across both tasks and modalities, providing another layer of abstraction through which learned representations must generalize. For example, in the OmniNet architecture proposed by Pramanik et al. (2019), each modality has a separate network to handle its inputs. The aggregated outputs are processed by an encoder-decoder called the Central Neural Processor, whose output is passed to several task-specific output heads.

OmniNet architecture proposed in (Pramanik et al., 2019)
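The data flow described for OmniNet can be sketched structurally in Python. Everything here is a hypothetical stand-in: the two "peripheral" encoders are trivial feature functions, the Central Neural Processor is replaced by simple averaging rather than an encoder-decoder, and the two task heads are toy functions. The point is only the routing: per-modality networks feed one shared processor, whose output is read by task-specific heads.

```python
from typing import Callable, Dict, List

Vector = List[float]


def text_encoder(tokens: List[str]) -> Vector:
    # Hypothetical stand-in for a language peripheral network.
    return [float(len(tokens)), float(sum(len(t) for t in tokens))]


def image_encoder(pixels: List[List[float]]) -> Vector:
    # Hypothetical stand-in for a vision peripheral network.
    flat = [p for row in pixels for p in row]
    return [sum(flat) / len(flat), max(flat)]


PERIPHERALS: Dict[str, Callable] = {"text": text_encoder, "image": image_encoder}


def central_processor(features: List[Vector]) -> Vector:
    # Placeholder for the shared encoder-decoder: element-wise average of the
    # modality features (assumes every encoder emits vectors of the same length).
    n = len(features)
    return [sum(f[i] for f in features) / n for i in range(len(features[0]))]


# Task-specific output heads read the shared representation (toy functions).
TASK_HEADS: Dict[str, Callable[[Vector], object]] = {
    "score": lambda h: sum(h),
    "flag": lambda h: h[0] > h[1],
}


def forward(inputs: Dict[str, object], task: str):
    feats = [PERIPHERALS[modality](x) for modality, x in inputs.items()]
    shared = central_processor(feats)
    return TASK_HEADS[task](shared)


sample = {"text": ["a", "bb"], "image": [[0.0, 1.0], [1.0, 2.0]]}
print(forward(sample, "score"))  # 4.0
print(forward(sample, "flag"))   # False
```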

Learned Architectures

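An alternative approach to MTL architecture design is to learn the architecture as well as the weights of the resulting model. Two popular ways to do this are Branched Sharing, in which tasks share early layers and split into task-specific branches at points determined during training, and Modular Sharing, in which each task learns which modules from a shared set to use.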

Branched sharing architecture proposed in (Lu et al., 2017)
AdaShare is a modular sharing scheme proposed by (Sun et al., 2019)
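Modular sharing in the spirit of AdaShare can be illustrated with a per-task select-or-skip policy over a shared stack of blocks. In this minimal sketch the blocks are toy arithmetic functions and the policies are fixed by hand; in AdaShare the policy is learned jointly with the network weights, and skipping a block means taking the residual identity path. The task names are hypothetical.

```python
from typing import Callable, Dict, List

Block = Callable[[float], float]

# Shared pool of blocks (toy functions standing in for residual blocks).
SHARED_BLOCKS: List[Block] = [
    lambda x: x + 1.0,
    lambda x: x * 2.0,
    lambda x: x - 3.0,
]

# Per-task binary policy: 1 = execute the block, 0 = skip it (identity path).
# Hypothetical, hand-set values; AdaShare learns these decisions during training.
POLICIES: Dict[str, List[int]] = {
    "segmentation": [1, 1, 0],
    "depth": [1, 0, 1],
}


def run(task: str, x: float) -> float:
    """Forward pass for one task: apply only the blocks its policy selects."""
    for block, use in zip(SHARED_BLOCKS, POLICIES[task]):
        if use:
            x = block(x)
    return x


print(run("segmentation", 1.0))  # 4.0
print(run("depth", 1.0))         # -1.0
```

Both tasks reuse the first block, while each skips a block the other uses, which is how a modular scheme lets related tasks share some computation without forcing them to share all of it.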

Conditional Architectures

Conditional Architectures select parts of a neural network for execution depending on the input to the network. These architectures vary across inputs as well as across tasks, but the components of these dynamically instantiated architectures are shared, which encourages the components to generalize across inputs and tasks. For example, in the Neural Module Network execution model proposed by Andreas et al. (2016), the semantic structure of a given question is used to dynamically instantiate a network made of modules that correspond to the elements of the question.

Example Neural Module Network execution (Andreas et al., 2016)
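The idea of assembling a network from shared modules for each question can be shown with a toy Python interpreter. The scene, the three modules (`find`, `and_`, `count`), and the nested-tuple layout format are hypothetical simplifications; in the original Neural Module Networks the modules are small neural networks operating on image features, and the layout is derived from a parse of the question rather than written by hand.

```python
from typing import Callable, Dict, List

Scene = List[Dict[str, str]]  # toy scene: each object is a dict of attributes
Attention = List[float]       # one weight per object


def find(attr: str) -> Callable[[Scene], Attention]:
    # Module that attends to objects having the given attribute value.
    return lambda scene: [1.0 if attr in obj.values() else 0.0 for obj in scene]


def and_(a: Attention, b: Attention) -> Attention:
    # Module that intersects two attention maps.
    return [x * y for x, y in zip(a, b)]


def count(a: Attention) -> int:
    # Answer module that counts the attended objects.
    return int(sum(a))


def execute(layout: tuple, scene: Scene):
    """Recursively instantiate and run the modules named in the layout."""
    op, *args = layout
    if op == "find":
        return find(args[0])(scene)
    if op == "and":
        return and_(execute(args[0], scene), execute(args[1], scene))
    if op == "count":
        return count(execute(args[0], scene))
    raise ValueError(f"unknown module: {op}")


scene = [
    {"color": "red", "shape": "cube"},
    {"color": "red", "shape": "ball"},
    {"color": "blue", "shape": "cube"},
]
# "How many red cubes are there?" -> count(and(find[red], find[cube]))
print(execute(("count", ("and", ("find", "red"), ("find", "cube"))), scene))  # 1
print(execute(("count", ("find", "red")), scene))  # 2
```

The same `find` module is reused in both layouts, which is the sharing that pushes modules toward generalizing across questions.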

What are different Multi-Task Optimization methods?

Multi-Task Optimization Methods

Loss Weighting

A common approach to ease multi-task optimization is to balance the individual loss functions for different tasks. When a model is to be trained on more than one task, the various task-specific loss functions must be combined into a single aggregated loss function which the model is trained to minimize.

Kendall et al., 2017
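The simplest aggregation is a fixed weighted sum of the task losses. Kendall et al. instead learn the weights through a per-task homoscedastic uncertainty term. Below is a minimal Python sketch of both, assuming regression-style losses and the commonly used parameterization s_i = log(sigma_i^2) for numerical stability; the function names are illustrative. In practice the s_i would be trained by gradient descent alongside the model weights.

```python
import math
from typing import List


def weighted_sum(losses: List[float], weights: List[float]) -> float:
    """Fixed-weight aggregation: L = sum_i w_i * L_i."""
    return sum(w * l for w, l in zip(weights, losses))


def uncertainty_weighted(losses: List[float], log_vars: List[float]) -> float:
    """Uncertainty weighting for regression-style losses:
    sum_i L_i / (2 * sigma_i^2) + log(sigma_i), written with s_i = log(sigma_i^2)."""
    return sum(0.5 * math.exp(-s) * l + 0.5 * s for l, s in zip(losses, log_vars))


def grad_log_vars(losses: List[float], log_vars: List[float]) -> List[float]:
    """Derivative of the objective above with respect to each s_i."""
    return [-0.5 * math.exp(-s) * l + 0.5 for l, s in zip(losses, log_vars)]


losses = [4.0, 0.25]
print(weighted_sum(losses, [0.5, 2.0]))          # 2.5
print(uncertainty_weighted(losses, [0.0, 0.0]))  # 2.125
```

Setting the derivative to zero gives exp(s_i) = L_i, so at the optimum each task's effective weight 0.5 * exp(-s_i) equals 0.5 / L_i: tasks with larger losses are automatically down-weighted, while the log term keeps the weights from collapsing to zero.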

Task Scheduling

Task scheduling is the process of choosing which task or tasks to train on at each training step. Most MTL models make this decision in a very simple way: they either train on all tasks at each step or randomly sample a subset of tasks, with some variations on these simple schedulers. For example, in the task scheduler proposed by Sharma et al. (2017), a meta task-decider is trained to sample tasks, using a training signal that encourages tasks with worse relative performance to be chosen more frequently.

Task scheduling visualization from (Sharma et al., 2017)
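Sharma et al. train a meta-learner to make this choice. As a much simpler stand-in (not their method), the sketch below turns each task's normalized gap to a target score into sampling probabilities with a softmax, so the tasks lagging furthest behind are picked more often. It assumes target scores are known, for example from single-task baselines; the task names, scores, and `temperature` value are hypothetical.

```python
import math
import random
from typing import Dict


def sampling_probs(current: Dict[str, float], target: Dict[str, float],
                   temperature: float = 0.1) -> Dict[str, float]:
    """Softmax over normalized performance gaps: tasks furthest below
    their target score receive the highest sampling probability."""
    gaps = {t: (target[t] - current[t]) / target[t] for t in current}
    exps = {t: math.exp(g / temperature) for t, g in gaps.items()}
    z = sum(exps.values())
    return {t: e / z for t, e in exps.items()}


def sample_task(probs: Dict[str, float], rng: random.Random) -> str:
    """Draw the task to train on at the next step."""
    tasks = list(probs)
    return rng.choices(tasks, weights=[probs[t] for t in tasks], k=1)[0]


current = {"pong": 15.0, "breakout": 100.0}
target = {"pong": 20.0, "breakout": 400.0}
probs = sampling_probs(current, target)  # "breakout" lags more, so it is favored
rng = random.Random(0)
schedule = [sample_task(probs, rng) for _ in range(5)]
```

Lowering `temperature` makes the scheduler focus more aggressively on the weakest task; raising it moves the schedule toward uniform random sampling.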

Gradient Modulation

One of the main challenges in MTL is negative transfer, where the joint training of tasks hurts learning instead of helping it. From an optimization perspective, negative transfer manifests as conflicting task gradients. When two tasks have gradients pointing in opposing directions, following the gradient for one task will decrease the performance of the other task, while following the average of the two gradients means that neither task sees the improvement it would in a single-task training setting.

Multi-task GREAT model (Sinha et al., 2018)
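GREAT addresses this by adversarially aligning task gradients. A different and easy-to-illustrate technique, swapped in here, is gradient projection in the style of PCGrad (Yu et al., 2020): when two task gradients conflict (negative dot product), each is projected onto the normal plane of the other before they are summed. The sketch below handles the two-task case with plain Python lists; it is an illustration of gradient modulation in general, not the GREAT method.

```python
import math
from typing import List

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine(a: Vec, b: Vec) -> float:
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


def project_conflicting(g: Vec, other: Vec) -> Vec:
    """If g conflicts with other (negative dot product), remove the component
    of g along other; otherwise return g unchanged."""
    d = dot(g, other)
    if d >= 0:
        return list(g)
    scale = d / dot(other, other)
    return [x - scale * y for x, y in zip(g, other)]


def combined_update(g1: Vec, g2: Vec) -> Vec:
    """Project each task gradient against the other, then sum them."""
    p1 = project_conflicting(g1, g2)
    p2 = project_conflicting(g2, g1)
    return [a + b for a, b in zip(p1, p2)]


g1, g2 = [1.0, 0.0], [-1.0, 1.0]
print(cosine(g1, g2) < 0)       # True: the gradients conflict
print(combined_update(g1, g2))  # [0.5, 1.5]
```

The projected update has a positive dot product with both task gradients, so a small step improves both tasks; the plain average `[0.0, 0.5]` is orthogonal to `g1` and gives the first task no first-order improvement at all.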

Knowledge Distillation

In the context of MTL, knowledge distillation is most commonly used to instill a single multi-task “student” network with the knowledge of many individual single-task “teacher” networks.

Two architectures from the Distral framework for RL (Teh et al., 2017)
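A common way to express the student's objective, borrowed from single-task distillation (Hinton et al., 2015), is a per-task KL divergence between temperature-softened teacher and student outputs. The sketch below assumes classification-style tasks with one teacher per task and sums the per-task terms; it is a generic formulation, not the Distral objective, which instead regularizes each task policy toward a shared distilled policy.

```python
import math
from typing import Dict, List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp((z - m) / temperature) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p: List[float], q: List[float]) -> float:
    """KL(p || q) for discrete distributions, assuming q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def multi_teacher_distill_loss(student_logits: Dict[str, List[float]],
                               teacher_logits: Dict[str, List[float]],
                               temperature: float = 2.0) -> float:
    """Sum over tasks of KL(teacher || student) on softened outputs,
    scaled by T^2 so gradient magnitudes stay comparable across temperatures."""
    loss = 0.0
    for task, t_logits in teacher_logits.items():
        p = softmax(t_logits, temperature)
        q = softmax(student_logits[task], temperature)
        loss += (temperature ** 2) * kl(p, q)
    return loss


teachers = {"sentiment": [1.0, 2.0], "topic": [0.0, 0.0, 3.0]}
student = {"sentiment": [2.0, 1.0], "topic": [0.0, 0.0, 0.0]}
loss = multi_teacher_distill_loss(student, teachers)  # positive until the student matches
```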

What is Multi-Task Relationship Learning?

Multi-Task Relationship Learning

Grouping Tasks

Grouping tasks provides an alternative solution to negative transfer in MTL: if two tasks exhibit negative transfer, their learning is separated from the start. However, finding such groupings requires significant computation, since networks must be trained jointly on many candidate sets of tasks through trial and error. A popular scheme for partitioning a group of tasks into clusters, each of which exhibits positive transfer among its tasks, was proposed by Standley et al. (2019).

An example partitioning of a group of tasks into clusters with positive transfer (Standley et al., 2019)
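The selection step can be sketched as a brute-force search, assuming the per-task losses of each candidate group have already been measured by training a network on that group. Following the spirit of Standley et al., each task takes its score from the best chosen network that contains it; as a simplification, the compute budget here is just a maximum number of networks. All task names and loss values below are made up for illustration, and exhaustive search is only feasible for a handful of tasks.

```python
from itertools import combinations
from typing import Dict, FrozenSet, List, Optional, Tuple

Group = FrozenSet[str]


def best_grouping(tasks: List[str], group_loss: Dict[Group, Dict[str, float]],
                  budget: int) -> Optional[Tuple[float, Tuple[Group, ...]]]:
    """Pick at most `budget` candidate groups covering every task, scoring each
    task by the lowest loss among the chosen groups that contain it."""
    best = None
    candidates = list(group_loss)
    for k in range(1, budget + 1):
        for combo in combinations(candidates, k):
            if set().union(*combo) != set(tasks):
                continue  # some task would be left without a network
            total = sum(min(group_loss[g][t] for g in combo if t in g) for t in tasks)
            if best is None or total < best[0]:
                best = (total, combo)
    return best


# Hypothetical measured losses (lower is better).
group_loss = {
    frozenset({"seg"}): {"seg": 0.50},
    frozenset({"depth"}): {"depth": 0.40},
    frozenset({"normals"}): {"normals": 0.30},
    frozenset({"depth", "normals"}): {"depth": 0.35, "normals": 0.25},  # positive transfer
    frozenset({"seg", "depth"}): {"seg": 0.60, "depth": 0.45},          # negative transfer
    frozenset({"seg", "depth", "normals"}): {"seg": 0.55, "depth": 0.42, "normals": 0.33},
}
total, groups = best_grouping(["seg", "depth", "normals"], group_loss, budget=2)
# groups: {"seg"} alone and {"depth", "normals"} together
```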

Transfer Relationship Learning

Transfer learning already plays an important role in Computer Vision and Natural Language Processing, where instead of building a model from scratch, a model pre-trained on a similar task is reused for a new task. Along the same lines, the motivation of Transfer Relationship Learning is to limit the number of tasks that have access to the full amount of supervised data (the source tasks) and to learn the remaining tasks by transferring from the source tasks, using only a small amount of training data to train a decoder on top of the transferred feature extractor.

Task taxonomies for a collection of computer vision tasks as computed in Taskonomy (Zamir et al., 2018)
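Choosing which tasks should be the fully supervised sources can likewise be sketched as a small search over a transfer-loss matrix, where `transfer_loss[s][t]` is the (hypothetical) loss on target `t` when its decoder is trained on features from source `s`. Taskonomy itself normalizes these affinities and solves a binary integer program; the exhaustive search below is a simplified stand-in that only works for small task sets.

```python
from itertools import combinations
from typing import Dict, List, Optional, Tuple


def choose_sources(transfer_loss: Dict[str, Dict[str, float]], targets: List[str],
                   k: int) -> Optional[Tuple[float, Tuple[str, ...]]]:
    """Pick k source tasks so that the summed loss over targets, each using
    its best available source, is as small as possible."""
    best = None
    for sources in combinations(sorted(transfer_loss), k):
        total = sum(min(transfer_loss[s][t] for s in sources) for t in targets)
        if best is None or total < best[0]:
            best = (total, sources)
    return best


# Hypothetical transfer losses (lower is better): rows are sources, columns targets.
transfer_loss = {
    "autoencoding": {"normals": 0.70, "segmentation": 0.65, "curvature": 0.60},
    "depth": {"normals": 0.20, "segmentation": 0.55, "curvature": 0.25},
    "edges": {"normals": 0.50, "segmentation": 0.30, "curvature": 0.45},
}
targets = ["normals", "segmentation", "curvature"]
print(choose_sources(transfer_loss, targets, k=1)[1])  # ('depth',)
print(choose_sources(transfer_loss, targets, k=2)[1])  # ('depth', 'edges')
```

With a budget of two sources, adding `edges` helps mainly because it transfers well to `segmentation`, which `depth` serves poorly.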

What are different Multi-Task Benchmarks?

Computer Vision Benchmarks

Natural Language Processing Benchmarks

Reinforcement Learning Benchmarks

Multi-Modal Benchmarks


In the paper “Multi-Task Learning With Deep Neural Networks: A Survey”, the author, Michael Crawshaw, presents a review of the field of Multi-Task Learning, covering the three broad directions of Architectures, Optimization Methods, and Task Relationship Learning, along with key ideas from research in these areas. The author also gives readers insight into the benchmarks commonly used in various domains of Multi-Task Learning.



Abhishek Bais

Seasoned R&D EDA, Data Science Enthusiast, Cultural Explorer