Kolmogorov–Arnold Networks
Umar Jamil
Downloaded from: https://github.com/hkproj/kan-notes
License: Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0):
https://creativecommons.org/licenses/by-nc/4.0/legalcode
Not for commercial use
Topics
• Review of the Multilayer Perceptron
• Introduction to data fitting
• Bézier Curves
• B-Splines
• Universal Approximation Theorem
• Kolmogorov-Arnold Representation Theorem
• MLPs vs KAN
• Properties
• Multi-layer KANs
• Parameters count: MLPs vs KANs
• Grid extension
• Interpretability
• Continual training

Prerequisites
• Basics of calculus (derivatives)
• Basics of deep learning (backpropagation)
The Multi-layer Perceptron (MLP)
A multilayer perceptron is a neural network made up of multiple layers of neurons, organized in a feed-forward way, with nonlinear activation functions in between.
How does it work?
[Figure: an MLP with an input layer, two hidden layers, and an output layer producing the logits for five classes (Class 1 – Class 5).]
The Linear layer in PyTorch
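A minimal sketch of a linear layer in PyTorch (the layer sizes here are chosen to match the example on the next slide):

import torch
import torch.nn as nn

# A linear layer mapping 3 input features to 5 output features.
linear = nn.Linear(in_features=3, out_features=5)

print(linear.weight.shape)  # torch.Size([5, 3])  -> the weight matrix W
print(linear.bias.shape)    # torch.Size([5])     -> the bias vector b

x = torch.randn(10, 3)      # a batch of 10 items, 3 features each
out = linear(x)             # computes x @ W.T + b
print(out.shape)            # torch.Size([10, 5])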
The Linear layer in detail
A linear layer in an MLP is made of a weight matrix and a bias vector.
The output is computed as O = XW^T + b, where:
• X has shape (10, 3): 10 items, each with 3 features (a_1, a_2, a_3).
• W^T has shape (3, 5), with rows w_1, w_2, w_3 and one column per output neuron (n_1 … n_5).
• XW^T and O both have shape (10, 5).
• The bias vector b, of shape (1, 5), is broadcast to every row of the XW^T table.
For a single item and the first output neuron:
z_1 = r_1 + b_1 = \sum_{i=1}^{3} a_i w_i + b_1
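A short sketch (with random values) checking that nn.Linear computes exactly XW^T + b, with the bias broadcast to every row:

import torch
import torch.nn as nn

linear = nn.Linear(in_features=3, out_features=5)
X = torch.randn(10, 3)                        # (10, 3): 10 items, 3 features each

O_layer = linear(X)                           # (10, 5)
# Same computation by hand: X @ W^T, then the (5,) bias broadcast to every row.
O_manual = X @ linear.weight.T + linear.bias  # (10, 5)

print(torch.allclose(O_layer, O_manual))      # True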
Why do we need activation functions?
After each Linear layer, we usually apply a nonlinear activation function. Why?
O_1 = x W_1^T + b_1
O_2 = O_1 W_2^T + b_2
O_2 = (x W_1^T + b_1) W_2^T + b_2
O_2 = x W_1^T W_2^T + b_1 W_2^T + b_2
As you can see, if we do not apply any activation function, the output is just a linear combination of the inputs, which means that our MLP will not be able to learn any non-linear mapping between the input and the output; and most real-world data is non-linear.
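A quick sketch (with illustrative sizes) showing that two stacked linear layers without an activation in between collapse into a single linear layer:

import torch
import torch.nn as nn

l1 = nn.Linear(3, 4)
l2 = nn.Linear(4, 5)

x = torch.randn(10, 3)
two_layers = l2(l1(x))                        # no activation in between

# Fold the two layers into one: W = W_2 W_1, b = b_1 W_2^T + b_2
W = l2.weight @ l1.weight                     # (5, 3)
b = l1.bias @ l2.weight.T + l2.bias           # (5,)
one_layer = x @ W.T + b

print(torch.allclose(two_layers, one_layer, atol=1e-6))  # True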
Introduction to data fitting
Imagine you’re making a 2D game and you want to animate your sprite (character) so that it passes through a series of points. One way would be to draw a straight line from one point to the next, but that wouldn’t look very good. What if you could create a smoother path, like the one below?
Smooth curves through polynomial curves
How to find the equation of such a smooth curve?
One way is to write the generic equation of a polynomial curve and force it to pass through the series of points to get the coefficients of the equation.
We have 4 points, so we can set up a system of 4 equations, which means we can solve for 4 unknowns: we get a polynomial of degree 3.
y = ax^3 + bx^2 + cx + d
We can write our system of equations as follows and solve to find the equation of the curve:
5 = a(0)^3 + b(0)^2 + c(0) + d
1 = a(1)^3 + b(1)^2 + c(1) + d
3 = a(2)^3 + b(2)^2 + c(2) + d
2 = a(5)^3 + b(5)^2 + c(5) + d
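A small sketch that solves this system with NumPy, using the four points implied by the equations above, (0, 5), (1, 1), (2, 3), (5, 2):

import numpy as np

# Points (x, y) taken from the system above.
xs = np.array([0.0, 1.0, 2.0, 5.0])
ys = np.array([5.0, 1.0, 3.0, 2.0])

# One row per point: [x^3, x^2, x, 1] @ [a, b, c, d] = y
A = np.stack([xs**3, xs**2, xs, np.ones_like(xs)], axis=1)
a, b, c, d = np.linalg.solve(A, ys)

print(a, b, c, d)
# The resulting cubic passes through all four points:
print(np.allclose(A @ np.array([a, b, c, d]), ys))  # True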
What if I have hundreds of points?
If you have N points, you need a polynomial of degree N – 1 if you want the curve to pass through all of them. But as you can see, when we have lots of points, the polynomial starts going crazy at the extremes. We wouldn’t want the character in our 2D game to go off the screen while we’re animating it, right?
Thankfully, someone took the time to solve this problem: we have Bézier curves!
Source: https://arachnoid.com/polysolve/
Bézier curves
A Bézier curve is a parametric curve (which means that all the coordinates of the curve depend on an independent variable t, between 0 and 1).
For example, given two points, we can calculate the linear Bézier curve as the following interpolation:
B(t) = P_0 + t(P_1 − P_0) = (1 − t)P_0 + tP_1
Given three points, we can calculate the quadratic Bézier curve that interpolates them.
Source: Wikipedia
Q_0(t) = (1 − t)P_0 + tP_1
Q_1(t) = (1 − t)P_1 + tP_2
B(t) = (1 − t)Q_0 + tQ_1
     = (1 − t)[(1 − t)P_0 + tP_1] + t[(1 − t)P_1 + tP_2]
     = (1 − t)^2 P_0 + 2(1 − t)t P_1 + t^2 P_2
With four points, we can proceed with a similar reasoning.
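A small sketch (with hypothetical control points) that evaluates the quadratic Bézier curve above:

import numpy as np

def quadratic_bezier(P0, P1, P2, t):
    """Evaluate the quadratic Bézier curve at parameters t in [0, 1]."""
    t = np.asarray(t)[:, None]           # column vector so it broadcasts against the 2D points
    return (1 - t) ** 2 * P0 + 2 * (1 - t) * t * P1 + t ** 2 * P2

# Hypothetical control points.
P0, P1, P2 = np.array([0.0, 0.0]), np.array([1.0, 2.0]), np.array([3.0, 0.0])
print(quadratic_bezier(P0, P1, P2, np.linspace(0.0, 1.0, 5)))
# At t = 0 the curve is at P0, at t = 1 it is at P2.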
Bézier curves: going deeper
Yes, we can go deeper! If we have n + 1 points, we can find the degree-n Bézier curve using the following formula:
B(t) = \sum_{i=0}^{n} \binom{n}{i} (1 − t)^{n−i} t^i P_i = \sum_{i=0}^{n} b_{i,n}(t) P_i
The b_{i,n}(t) are the Bernstein basis polynomials, and \binom{n}{i} = n! / (i! (n − i)!) is the binomial coefficient.
[Figure: the four cubic Bernstein basis polynomials; blue: b_{0,3}(t), green: b_{1,3}(t), red: b_{2,3}(t), cyan: b_{3,3}(t).]
Source: Wikipedia
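A sketch of the general formula in code (math.comb gives the binomial coefficient; the control points are hypothetical):

import math
import numpy as np

def bezier(points, t):
    """Evaluate a degree-n Bézier curve defined by n + 1 control points at parameters t."""
    n = len(points) - 1
    t = np.asarray(t)[:, None]
    curve = np.zeros((len(t), points.shape[1]))
    for i, P in enumerate(points):
        # Bernstein basis polynomial b_{i,n}(t)
        b = math.comb(n, i) * (1 - t) ** (n - i) * t ** i
        curve += b * P
    return curve

# Hypothetical control points for a cubic (n = 3) Bézier curve.
points = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 3.0], [4.0, 0.0]])
print(bezier(points, np.linspace(0.0, 1.0, 5)))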
From Bézier curves to B-Splines
If you have lots of points (say n), you need a Bézier curve of degree n – 1 to approximate them well, but that can be quite expensive to compute.
Someone wise thought: why don’t we stitch together many Bézier curves between all these points, instead of one big Bézier curve that interpolates all of
them?
Source: Wikipedia
B-splines in detail
A k-degree B-Spline curve defined by n control points will consist of n − k Bézier curves.
For example, if we want to use a quadratic Bézier curve and we have 6 points, we need 6 − 2 = 4 Bézier curves.
In this case we have n=6 and k=2
Source: Wolfram Alpha
B-splines in detail
The degree of our B-Spline also tells us what kind of continuity we get where the pieces join.
Source: MIT
Calculating B-splines: algorithm
Source: MIT
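The standard way to evaluate B-Spline basis functions is the Cox–de Boor recursion (I'm assuming this is the algorithm shown in the MIT notes). A minimal Python sketch:

def bspline_basis(i, k, t, knots):
    """Cox-de Boor recursion: value of the basis function N_{i,k} at parameter t."""
    if k == 0:
        return 1.0 if knots[i] <= t < knots[i + 1] else 0.0
    left, right = 0.0, 0.0
    if knots[i + k] != knots[i]:
        left = (t - knots[i]) / (knots[i + k] - knots[i]) * bspline_basis(i, k - 1, t, knots)
    if knots[i + k + 1] != knots[i + 1]:
        right = (knots[i + k + 1] - t) / (knots[i + k + 1] - knots[i + 1]) * bspline_basis(i + 1, k - 1, t, knots)
    return left + right

# The six quadratic (k = 2) basis functions N_{0,2} ... N_{5,2} on a uniform knot vector.
knots = [0, 1, 2, 3, 4, 5, 6, 7, 8]
print([bspline_basis(i, 2, 2.5, knots) for i in range(6)])
# Inside the valid parameter range, the basis function values sum to 1 at every t.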
B-Splines: basis functions
[Figure: the six quadratic basis functions N_{0,2}, N_{1,2}, N_{2,2}, N_{3,2}, N_{4,2}, N_{5,2}.]
B-splines: local control
Moving a control point only changes the curve locally (in the proximity of the control point), leaving the Bézier pieces far from it unchanged!
Universal Approximation Theorem
We can think of neural networks as function approximators. Usually, we have access to some data points generated by an ideal function to which we do not have access. The goal of training a neural network is to approximate this ideal function.
But how do we know if a neural network is powerful enough to model our ideal function? What can we say about the expressive power of neural networks?
This is what the universal approximation theorem is all about: it is a series of results that characterize the limits of what neural networks can learn.
It has been proven that neural networks with a certain width (number of neurons) and depth (number of layers) can approximate any continuous function when using specific non-linear activation functions, for example the ReLU function. Check Wikipedia for more theoretical results.
I want to emphasize what it means to be a universal approximator: it means that given an ideal function (or a family of functions) that models the training data, the network can learn to approximate it as well as we want, that is, given an error ε, we can always find an approximating function that is within this error of the ideal function.
This is however a theoretical result; it doesn’t tell us how to do it practically. On a practical level, we have many problems:
• Achieving good approximations may take enormous amounts of computational power
• We may need a very large quantity of training data
• Our hardware may not be able to represent certain weights in 32 bits
• Our optimizer may remain stuck in a local minimum
So as you can see, just because a neural network can learn anything doesn’t mean we are able to learn it in practice. But at least we know that the limits are practical.
Kolmogorov-Arnold representation theorem
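For reference, the theorem states that any multivariate continuous function on a bounded domain can be written as a finite sum of compositions of continuous single-variable functions and addition. In the form used by the KAN paper:

f(x_1, \dots, x_n) = \sum_{q=1}^{2n+1} \Phi_q \left( \sum_{p=1}^{n} \varphi_{q,p}(x_p) \right)

where the inner functions \varphi_{q,p} and the outer functions \Phi_q are continuous functions of a single variable.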
Kolmogorov-Arnold Networks
This network can be thought of as two layers applied in sequence:
• The first layer maps the n = 2 input features (x_1, x_2) into 2n + 1 = 5 intermediate features (h_1 … h_5), with one learnable function per edge (φ_{1,1} … φ_{5,1} applied to x_1, and φ_{1,2} … φ_{5,2} applied to x_2).
• The second layer maps the 5 intermediate features into 1 output feature (o_1) through the learnable functions φ_1 … φ_5.
Instead of having learnable weights, we have learnable functions, and each node simply sums the outputs of its incoming learnable functions.
MLP vs KAN
Multi-layer KAN
Layer 1: 2 input features, 5 output features, for a total of 2 × 5 = 10 functions to “learn”.
Layer 2: 5 input features, 1 output feature, for a total of 5 × 1 = 5 functions to “learn”.
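This is not the paper's implementation (which parameterizes each edge function with B-Splines plus a SiLU base term); the following is a simplified sketch of the idea, with each edge function written as a learnable combination of fixed Gaussian bumps on a grid:

import torch
import torch.nn as nn

class SimpleKANLayer(nn.Module):
    """Simplified KAN layer: one learnable 1D function per (input, output) edge,
    parameterized as a linear combination of fixed Gaussian basis functions."""
    def __init__(self, in_features, out_features, num_basis=8, grid_range=(-2.0, 2.0)):
        super().__init__()
        centers = torch.linspace(grid_range[0], grid_range[1], num_basis)
        self.register_buffer("centers", centers)
        self.width = (grid_range[1] - grid_range[0]) / (num_basis - 1)
        # One coefficient vector per edge: (out_features, in_features, num_basis).
        self.coeffs = nn.Parameter(torch.randn(out_features, in_features, num_basis) * 0.1)

    def forward(self, x):                          # x: (batch, in_features)
        # Evaluate every basis function at every input value: (batch, in_features, num_basis).
        basis = torch.exp(-((x.unsqueeze(-1) - self.centers) / self.width) ** 2)
        # phi[b, o, i] = value of the edge function from input i to output o.
        phi = torch.einsum("oik,bik->boi", self.coeffs, basis)
        # A KAN node has no weights or bias: it just sums its incoming edge functions.
        return phi.sum(dim=-1)                     # (batch, out_features)

# The 2 -> 5 -> 1 network from the slides: 10 + 5 = 15 learnable edge functions in total.
kan = nn.Sequential(SimpleKANLayer(2, 5), SimpleKANLayer(5, 1))
print(kan(torch.randn(16, 2)).shape)               # torch.Size([16, 1])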
Implementation details
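Roughly, in the KAN paper each learnable activation φ is parameterized as a base function plus a B-Spline with trainable coefficients (the exact parameterization varies slightly between versions of the paper):

φ(x) = w_b · silu(x) + w_s · \sum_i c_i B_i(x)

where the B_i are B-Spline basis functions on a grid and the coefficients c_i (together with the scales) are learned.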
Parameters count
Compared to MLP, we also have (G+k) parameters for each activation, because we need to learn where to put the control points for the B-Splines.
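A rough worked example of that count (the grid size G = 5 and spline degree k = 3 are just assumed values here): an MLP layer with n_in inputs and n_out outputs has n_in · n_out + n_out parameters, while a KAN layer has roughly n_in · n_out · (G + k) parameters, one set of spline coefficients per edge. For the 2 → 5 → 1 KAN above, that is 2 · 5 · 8 + 5 · 1 · 8 = 120 parameters, versus 2 · 5 + 5 + 5 · 1 + 1 = 21 for an MLP with the same shape.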
Grid extension
We can increase the number of “control points” in the B-Spline to give it more “degrees of freedom” to better approximate more complex functions, meaning
that we can extend the grid of an existing pre-trained network.
Interpretability
Continual learning
Thanks for watching!
Don’t forget to subscribe for
more amazing content on AI
and Machine Learning!