-
Notifications
You must be signed in to change notification settings - Fork 9
/
tensorprod.m
109 lines (100 loc) · 3.23 KB
/
tensorprod.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
function [t, outind] = tensorprod(varargin)
% TENSORPROD - Multiply two or more arrays while summing over repeated
% indices. The indices for each array are indicated by strings.
%
% For example,
%
% C = tensorprod(A, 'ij', B, 'jk');
%
% computes a normal matrix product. The names of the indices are limited to
% a single character but are of no importance except to indicate repeated
% indices. The output indices come in the order they appear in the input
% arrays. Thus
%
% C = tensorprod(A, 'ki', B, 'ij');
%
% is exactly the same as the previous example. The output indices can be
% obtained by
%
% [C, ind] = tensorprod(A, 'ki', B, 'ij');
%
% and would be 'kj' in the latter example, whereas it would be 'ik' in the
% first example.
%
% Arguments are given as (array, index string) pairs. An array with a
% single column (column vector or scalar) takes exactly one index; any
% other array takes ndims(array) indices, so a row vector takes two.
%
% An error is raised if no argument is a string, if the number of
% arguments is odd, if an index argument is not a string, if the number of
% indices does not match the array, if an index character has a code
% outside 1-256, or if a repeated index has differing lengths.
%
% No memory is spent on possible intermediate results but the speed is not
% as fast as a special-written implementation could be.
%
% BUGS: Imaginary parts of arrays are ignored. A warning with identifier
% tensorprod:complexInput is issued when a complex array is passed.
%
% Author: Gunnar Farnebäck
% Medical Informatics
% Linköping University, Sweden

% Check for suspected call of old definition: a call without any string
% argument cannot be valid for the current calling convention.
if ~any(cellfun(@ischar, varargin))
    % sprintf is kept here on purpose: error() with a single input does
    % not expand the \n escape.
    error(sprintf(['tensorprod has been changed. Update your code for ' ...
                   'the new definition\nor convert to the new function' ...
                   ' outerprod.']))
end

if mod(length(varargin), 2) == 1
    error('There must be an even number of arguments.');
end

num_arrays = length(varargin) / 2;

% translation_table maps a character code to its index number (0 = not
% seen yet); index_names is the inverse map, index number -> character.
translation_table = zeros(256, 1);
index_names = '';
index_number = 1;
output_indices = [];
index_sizes = [];
arrays = cell(num_arrays, 1);
index_vectors = cell(num_arrays, 1);

for k = 1:2:length(varargin)
    x = varargin{k};
    ind = varargin{k + 1};
    if ~ischar(ind)
        error('Argument %d expected to be a string.', k + 1)
    end

    if ~isreal(x)
        % Per the BUGS note in the header, imaginary parts are dropped;
        % make that visible instead of silent.
        warning('tensorprod:complexInput', ...
                'Argument %d is complex; imaginary parts are ignored.', k);
    end

    % Arrays with a single column are treated as one-dimensional.
    n = ndims(x);
    if n == 2 && size(x, 2) == 1
        n = 1;
    end

    if n ~= length(ind)
        error(['Mismatch between array dimension and number of ' ...
               'indices, args %d and %d.'], k, k + 1)
    end

    indices = zeros(n, 1);
    for m = 1:n
        c = double(ind(m));
        % The character code is used directly as a subscript into the
        % 256-entry table, so it must lie within 1..256.
        if c < 1 || c > length(translation_table)
            error('tensorprod:invalidIndexName', ...
                  ['Index names must be characters with codes in the ' ...
                   'range 1-%d, argument %d.'], ...
                  length(translation_table), k + 1)
        end

        v = translation_table(c);
        if v == 0
            % First occurrence: allocate a new index number and, for now,
            % treat it as an output index.
            v = index_number;
            index_number = index_number + 1;
            translation_table(c) = v;
            index_names(v) = ind(m);
            output_indices = [output_indices v];
            index_sizes(v) = size(x, m);
        else
            % Repeated index: it is summed over, so drop it from the
            % output and make sure the lengths agree.
            output_indices = output_indices(output_indices ~= v);
            if index_sizes(v) ~= size(x, m)
                error('All repeated indices must have the same length.');
            end
        end
        indices(m) = v;
    end

    arrays{(k + 1) / 2} = x;
    index_vectors{(k + 1) / 2} = indices;
end

if nargout > 1
    % Translate the surviving index numbers back to their characters.
    outind = '';
    for k = 1:length(output_indices)
        outind = [outind index_names(output_indices(k))];
    end
end

% If you need to speed up repeated tensorprod computations you can call
% tensorprodc directly with the arguments assembled here, but beware that it
% doesn't do any kind of error checking and might crash the whole Matlab
% process if you mess it up.
t = tensorprodc(index_sizes, output_indices, arrays{:}, index_vectors{:});