Differentiating Matrix Expressions
DiffMatic operates on expressions represented in Ricci-notation [1], which enables a precise treatment of tensor derivatives. After computing derivatives in Ricci-notation, the resulting expression can be converted back to standard notation. Derivatives written in standard matrix notation are compact and suitable for use with common linear algebra libraries. DiffMatic can also generate code for evaluating derivatives directly.
Matrices and Vectors as Tensors
Matrix expressions written in Julia syntax are automatically converted to Ricci-notation when using the types provided by DiffMatic. Matrices, vectors and scalars are created using the corresponding macros:
using DiffMatic
@matrix A B C
@vector x y z
@scalar a b c
All of the above variables are represented using the same type. The modes of a tensor can be seen by printing the variable:
A
# output
A¹₂
Matrices have one upper and one lower index. Column vectors have one upper index:
x
# output
x¹
Scalars have no indices.
In Ricci-notation, and also Einstein notation, indices are usually written with letters. We use numbers for convenience.
Creating Tensor Expressions
Many common functions are overloaded for the tensor type. This enables the creation of expressions in Ricci-notation using standard Julia syntax. Tensor indices are updated automatically as needed.
A * x # standard matrix-vector multiplication
# output
A¹₄x⁴
The same index appearing both as an upper and as a lower index indicates a contraction over that index. Multiplying with a scalar does not result in a contraction:
c * A
# output
cA¹₂
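To make the contraction concrete, the sketch below writes the sum over the shared index as an explicit loop on plain numeric arrays. It does not use DiffMatic; contract is a hypothetical helper, and An, xn and cn are illustrative values chosen here.

```julia
# Contraction over a shared index, written out as an explicit sum.
# `contract` is a hypothetical helper for illustration, not part of DiffMatic.
function contract(A::AbstractMatrix, x::AbstractVector)
    m, n = size(A)
    n == length(x) || throw(DimensionMismatch("A has $n columns, x has $(length(x)) entries"))
    y = zeros(promote_type(eltype(A), eltype(x)), m)
    for i in 1:m, j in 1:n
        y[i] += A[i, j] * x[j]   # sum over the repeated index j
    end
    return y
end

An = [1.0 2.0; 3.0 4.0]
xn = [5.0, 6.0]
cn = 2.0

y = contract(An, xn)
@assert y ≈ An * xn                  # agrees with the built-in matrix-vector product

# A scalar shares no index with the matrix, so nothing is summed:
# every entry is scaled and the shape is unchanged.
@assert size(cn * An) == size(An)
```

The contracted index disappears from the result, which is why the product of a matrix and a column vector is again a column vector with a single upper index.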
Supported operators and functions when creating expressions:
- Basic operators +, -, ', *, ^, abs, sin and cos
- Element-wise operators sin., cos., abs., .* and .^
- Vector 1-norm and 2-norm can be computed with norm1 and norm2.
- Sums of vectors can be computed with sum.
- Matrix traces can be computed with tr.
See Creating Expressions for more examples.
Computing Derivatives
Matrix derivatives are computed directly on expressions in Ricci-notation. DiffMatic provides specialized functions for computing gradients, Jacobians and Hessians.
gradient and hessian require a scalar as input:
expr = 2 * x' * A * B * x
g = gradient(expr, x) # The second argument denotes the differentiation variable.
# output
2A₄⁶B₆⁸x⁴ + 2B₆⁷x₇A⁸⁶
H = hessian(expr, x)
# output
2A₉⁶B₆⁸ + 2B₆₉A⁸⁶
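For orientation, the same gradient and Hessian can be derived by hand in standard matrix notation using the product rule. The derivation below is a sketch added for comparison, not DiffMatic output:

```latex
\begin{aligned}
f(x) &= 2\,x^\top A B x \\
\nabla_x f &= 2\,\bigl(AB + (AB)^\top\bigr)\,x = 2ABx + 2B^\top A^\top x \\
\nabla_x^2 f &= 2AB + 2B^\top A^\top
\end{aligned}
```

Each of the two terms in the Ricci-notation results corresponds to one of the two terms in these sums.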
jacobian requires a column vector as input:
J = jacobian(x' * x * A * x, x)
# output
x₃x³A¹₆ + 2A¹₅x⁵x₆
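Read in matrix notation, the product rule gives this Jacobian as (xᵀx)A + 2Axxᵀ. The sketch below checks that closed form against central finite differences using only the LinearAlgebra standard library. It does not call DiffMatic; fd_jacobian is a hypothetical helper, and An and xn are illustrative values.

```julia
using LinearAlgebra

# f(x) = (xᵀx) A x, the expression passed to `jacobian` above.
f(A, x) = dot(x, x) * (A * x)

# Closed form in matrix notation: (xᵀx) A + 2 A x xᵀ.
jac(A, x) = dot(x, x) * A + 2 * (A * x) * x'

# Central finite differences, one column per coordinate of x.
# `fd_jacobian` is a hypothetical helper for illustration.
function fd_jacobian(f, x; h = 1e-6)
    n = length(x)
    cols = map(1:n) do k
        e = zeros(n); e[k] = h
        (f(x + e) - f(x - e)) / (2h)
    end
    return reduce(hcat, cols)
end

An = [1.0 2.0; 3.0 4.0]
xn = [0.5, -1.5]

J_exact = jac(An, xn)
J_fd = fd_jacobian(x -> f(An, x), xn)
@assert isapprox(J_exact, J_fd; rtol = 1e-6)
```

A finite-difference comparison of this kind is a cheap sanity check for any symbolic derivative, although it only tests the formula at the chosen point.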
See Derivatives in Standard Notation for more examples.
Converting Tensor Expressions to Standard Notation
An expression in Ricci-notation can be converted back to standard notation using to_std. DiffMatic can output a standard expression in string format or as a Julia function.
String Output
An expression can be converted to standard matrix notation and retrieved as a string by passing StdStr() as the format keyword argument to to_std:
to_std(J; format = StdStr())
# output
"xᵀxA + 2Axxᵀ"
Special notation used in the output:
- "⊙": Element-wise multiplication.
- "diag(x)": Diagonal matrix with "x" on the diagonal.
- "vec(1)": Vector of ones.
- "sgn(x)": The signum function applied element-wise.
- "tr(X)": Trace of the matrix "X".
- "abs(x)": Element-wise absolute value of "x".
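When evaluating such a string by hand, the symbols above map naturally onto Julia's standard library. The correspondence below is a suggested reading using Base and LinearAlgebra, not necessarily the code DiffMatic itself generates; xn, yn and Xn are illustrative values.

```julia
using LinearAlgebra

xn = [1.0, -2.0, 3.0]
yn = [4.0, 5.0, 6.0]
Xn = [1.0 2.0; 3.0 4.0]

hadamard = xn .* yn          # "x ⊙ y": element-wise multiplication
D        = Diagonal(xn)      # "diag(x)": diagonal matrix with x on the diagonal
ones_vec = ones(length(xn))  # "vec(1)": vector of ones
signs    = sign.(xn)         # "sgn(x)": element-wise signum
trace    = tr(Xn)            # "tr(X)": trace of X
absx     = abs.(xn)          # "abs(x)": element-wise absolute value

# Two identities that connect the symbols:
@assert D * yn == hadamard       # diag(x) y = x ⊙ y
@assert signs .* xn == absx      # sgn(x) ⊙ x = abs(x)
```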
Julia Function
A Julia function is generated by passing JuliaFunc() as the format keyword argument to to_std. The function is returned in the form of an expression:
to_std(g; format = JuliaFunc())
# output
quote
#= ... =#
function (B, A, x)
#= ... =#
#= ... =#
return 2 * (transpose(B) * (transpose(A) * x)) + 2 * ((A * B) * x)
end
end
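The generated expression can be turned into a callable function using eval. Arguments are passed in the order of the generated signature, here (B, A, x):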
g_fun = eval(to_std(g; format = JuliaFunc()))
An = Float64[1 1 1; 2 2 2; 3 3 3]
Bn = Float64[4 4 4; 5 5 5; 6 6 6]
xn = Float64[1; 1; 1]
g_fun(Bn, An, xn)
# output
3-element Vector{Float64}:
270.0
360.0
450.0
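The numbers above can be cross-checked without DiffMatic. The sketch below evaluates the gradient in closed form, following the body of the generated function, and compares it with central finite differences of the original scalar expression. The step size h and the helper names f, g_exact and g_fd are choices made for this illustration.

```julia
using LinearAlgebra

An = Float64[1 1 1; 2 2 2; 3 3 3]
Bn = Float64[4 4 4; 5 5 5; 6 6 6]
xn = Float64[1; 1; 1]

# The scalar function that was differentiated: 2 xᵀ A B x.
f(x) = 2 * dot(x, An * (Bn * x))

# Gradient in closed form, matching the generated function body.
g_exact = 2 * (Bn' * (An' * xn)) + 2 * (An * (Bn * xn))

# Central finite differences along each coordinate direction.
h = 1e-6
g_fd = map(1:length(xn)) do k
    e = zeros(length(xn)); e[k] = h
    (f(xn + e) - f(xn - e)) / (2h)
end

@assert g_exact == [270.0, 360.0, 450.0]
@assert isapprox(g_exact, g_fd; rtol = 1e-6)
```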
[1] S. Laue, M. Mitterreiter and J. Giesen. Computing Higher Order Derivatives of Matrix and Tensor Expressions. In: Advances in Neural Information Processing Systems, Vol. 31 (Curran Associates, Inc., 2018).