ODIN
grappa_code.h
1 #include "grappa.h"
2 #include "data.h"
3 #include "controller.h"
4 
5 #include <odindata/linalg.h>
6 
7 
8 STD_string grappa_postlabel(recoDim dim) {
9  return STD_string("grappaweights_")+recoDimLabel[dim];
10 }
11 
12 
14 
15 ivector next_neighbours_offset(unsigned int reductionFactor, int numof_neighb, int ired) {
16  Log<Reco> odinlog("","next_neighbours_offset");
17 
18  ivector result(numof_neighb);
19 
20  int negfirst=(ired<int(reductionFactor/2)); // determine initial direction
21 
22  for(int i=0; i<numof_neighb; i++) {
23  int i2=i/2; // interleaved positive and negative direction
24  if((i+negfirst)%2) result[i]=-ired-1-i2*reductionFactor; // negative direction
25  else result[i]=reductionFactor-ired-1+i2*reductionFactor; // positive direction
26  }
27 
28  ODINLOG(odinlog,normalDebug) << "result[" << reductionFactor << "/" << ired << "]" << result << STD_endl;
29 
30  return result;
31 }
32 
34 
35 bool measured_line(const ivector& indexvec, int index) {
36  for(unsigned int i=0; i<indexvec.size(); i++) {
37  if(indexvec[i]==index) return true;
38  }
39  return false;
40 }
41 
43 
44 void set_separate_index(RecoCoord& coord, recoDim orthoDim, int iortho) {
45  coord.set_mode(RecoIndex::separate, orthoDim);
46  coord.index[orthoDim]=iortho;
47 }
48 
49 
51 
52 
53 template<recoDim interpolDim, recoDim orthoDim, int interpolIndex, int orthoIndex, class Ignore>
55 
56  reduction_factor=0;
57  reduction_factor.set_description("Reduction factor");
58  append_arg(reduction_factor,"reduction_factor");
59 
60  neighbours_read=5;
61  neighbours_read.set_cmdline_option("gr").set_description("Number of neigbours in read direction used for GRAPPA interpolation");
62  append_arg(neighbours_read,"neighbours_read");
63 
64  neighbours_phase=2;
65  neighbours_phase.set_cmdline_option("gp").set_description("Number of neigbouring measured k-space lines used for GRAPPA interpolation");
66  append_arg(neighbours_phase,"neighbours_phase");
67 
68  svd_trunc=0.001;
69  svd_trunc.set_cmdline_option("gs").set_description("Truncation value of SVD (i.e. regularization) when calculating GRAPPA weights");
70  append_arg(svd_trunc,"svd_trunc");
71 
72  discard_level=0.0; //0.01;
73  discard_level.set_cmdline_option("gd").set_description("Fraction of k-psace points to discard in auto-calibration lines because of high residuals");
74  append_arg(discard_level,"discard_level");
75 }
76 
77 
79 
80 
81 template<recoDim interpolDim, recoDim orthoDim, int interpolIndex, int orthoIndex, class Ignore>
83  Log<Reco> odinlog(c_label(),"process");
84 
85  RecoCoord coordcopy(rd.coord()); // local copy to modify coord according to ignored dims
86  Ignore::modify(RecoIndex::ignore, coordcopy);
87 
88  RecoData rdweights(coordcopy);
89  if(!calc_weights(rdweights.data(Rank<5>()), coordcopy, rd.data(Rank<4>()))) return false;
90 
91  controller.post_data(grappa_postlabel(interpolDim), rdweights);
92  ODINLOG(odinlog,normalDebug) << "Posted rdweights=" << rdweights.coord().print() << STD_endl;
93 
94  return execute_next_step(rd,controller);
95 }
96 
98 
99 
100 template<recoDim interpolDim, recoDim orthoDim, int interpolIndex, int orthoIndex, class Ignore>
102  Log<Reco> odinlog(c_label(),"query");
103  if(context.mode==RecoQueryContext::prep) {
104  context.controller.announce_data(grappa_postlabel(interpolDim));
105  RecoCoord coordcopy(context.coord);
106  Ignore::modify(RecoIndex::ignore, coordcopy);
107  if(!measlines.init(coordcopy, context.controller)) return false;
108  }
109  return RecoStep::query(context);
110 }
111 
113 
114 
115 template<recoDim interpolDim, recoDim orthoDim, int interpolIndex, int orthoIndex, class Ignore>
117  Log<Reco> odinlog(c_label(),"calc_weights");
118 
119  Range all=Range::all();
120 
121  ODINLOG(odinlog,normalDebug) << "neighbours_read/neighbours_phase=" << neighbours_read << "/" << neighbours_phase << STD_endl;
122 
123  TinyVector<int,4> inshape=trainingdata.shape();
124  int nChannels=inshape(0);
125  int sizeOrtho=inshape(1+orthoIndex);
126  int sizeInterpol=inshape(1+interpolIndex);
127  int sizeRead=inshape(3);
128  ODINLOG(odinlog,normalDebug) << "nChannels/sizeOrtho/sizeInterpol/sizeRead/reduction_factor=" << nChannels << "/" << sizeOrtho << "/" << sizeInterpol << "/" << sizeRead << "/" << reduction_factor << STD_endl;
129 
130  // Shape for 3D volumes
131  TinyVector<int,3> shape3d(inshape(1),inshape(2),inshape(3));
132 
133  // shape to calculate index vector within interpolation net
134  TinyVector<int,3> netshape(nChannels, neighbours_phase, neighbours_read);
135 
136  int ncols=product(netshape); // Size of the interpolation net
137  ComplexData<1> weights_net_vec(ncols); // Interpolation net around a particular root node
138 
139  weights.resize(nChannels, reduction_factor-1, nChannels, neighbours_phase, neighbours_read); // the result
140  weights=STD_complex(0.0);
141 
142  for(int ired=0; ired<(reduction_factor-1); ired++) { // loop over (R-1) missing lines
143 
144  // Mask for the root nodes available to calculate the weights
145  Data<int,4> acl_mask(inshape);
146  acl_mask=0;
147 
148  RecoCoord aclcoord(trainingcoord);
149  for(int iortho=0; iortho<sizeOrtho; iortho++) {
150 
151  // GRAPPA encoding might be different in each of the orthogonal phase encoding partitions
152  set_separate_index(aclcoord, orthoDim, iortho);
153 
154  ivector aclindices=acl_lines(aclcoord, neighbours_phase, ired);
155  ODINLOG(odinlog,normalDebug) << "aclindices(" << ired << ", " << aclcoord.print() << ")=" << aclindices << STD_endl;
156 
157  for(int iacl=0; iacl<int(aclindices.size()); iacl++) {
158  TinyVector<int,2> aclindex;
159  aclindex(orthoIndex)=iortho;
160  aclindex(interpolIndex)= aclindices[iacl];
161  acl_mask(all, aclindex(0), aclindex(1), Range( neighbours_read/2, sizeRead-1-neighbours_read/2 ))=1;
162  }
163  }
164 
165  TinyVector<int,3> acl_signal_shape(shape3d); // used only to calculate index vector of root nodes
166 
167  // Offset of neighbouring lines used to calculate weights for this particular reduction index
168  ivector neighboffset=next_neighbours_offset(reduction_factor, neighbours_phase, ired);
169  ODINLOG(odinlog,normalDebug) << "neighboffset(" << ired << ")=" << neighboffset.printbody() << STD_endl;
170 
171  for(int ichan_dst=0; ichan_dst<nChannels; ichan_dst++) { // calculate weights separately for each channel
172 
173  int nrows=sum(acl_mask(ichan_dst,all,all,all));
174  ODINLOG(odinlog,normalDebug) << "ncols/nrows=" << ncols << "/" << nrows << STD_endl;
175  if(ncols>nrows) {
176  ODINLOG(odinlog,errorLog) << "Not enough auto-calibration data for " << netshape << " interpolation net" << STD_endl;
177  return false;
178  }
179 
180  // fill acl_signal_vec with signal values from ACLs
181  ComplexData<1> acl_signal_vec(nrows); // Will hold root node data of a particular channel
182  Array<TinyVector<int,3>,1> maskindex(nrows); // cache for coordinates of signal values contributing to SVD
183  int irow=0;
184  for(int i=0; i<product(acl_signal_shape); i++) {
185  TinyVector<int,3> rowindex=index2extent(acl_signal_shape, i);
186  TinyVector<int,4> aclindex(ichan_dst, rowindex(0), rowindex(1), rowindex(2));
187  if(acl_mask(aclindex)) {
188  maskindex(irow)=rowindex;
189  STD_complex aclval=trainingdata(aclindex);
190  if(cabs(aclval)==0.0) {
191  ODINLOG(odinlog,warningLog) << "Zero acl at " << aclindex << STD_endl;
192  }
193  acl_signal_vec(irow)=aclval;
194  irow++;
195  }
196  }
197 
198  // fill Matrix for SVD
199  ComplexData<2> Matrix(nrows,ncols);
200  for(int irow=0; irow<nrows; irow++) { // loop over root nodes
201  TinyVector<int,3> aclindex=maskindex(irow);
202 
203  for(int icol=0; icol<ncols; icol++) { // loop over interpolation net, i.e. over neighbourhood of root node
204  TinyVector<int,3> windex=index2extent(netshape,icol);
205 
206  int ichan_src=windex(0);
207  int ipoloffset=neighboffset[windex(1)];
208  int readoffset=windex(2)-neighbours_read/2; // symmetrical about root node
209 
210  TinyVector<int,4> trainingindex;
211  trainingindex(0)=ichan_src;
212  trainingindex(1)=aclindex(0);
213  trainingindex(2)=aclindex(1);
214  trainingindex(3)=aclindex(2)+readoffset;
215 
216  trainingindex(1+interpolIndex)+=ipoloffset; // take neighbour in interpolation direction
217 
218  STD_complex trainingval=trainingdata(trainingindex);
219  if(cabs(trainingval)==0.0) {
220  ODINLOG(odinlog,warningLog) << "Zero trainingdata at " << trainingindex << STD_endl;
221  }
222 
223  Matrix(irow,icol)=trainingval;
224  }
225  }
226 
227  // solve system of linear equations using complex SVD
228  weights_net_vec=solve_linear(Matrix, acl_signal_vec, svd_trunc);
229 
230 
231  // Use residuals to eliminate noisy ACLs (Huo et al., JMRI 2008, 27:1412)
232  if(discard_level>0.0) {
233  Data<float,1> residuals(cabs(matrix_product(Matrix, weights_net_vec)-acl_signal_vec));
234  int ndiscard=int(discard_level*nrows+0.5);
235  if(ncols>(nrows-ndiscard)) {
236  ODINLOG(odinlog,warningLog) << "Limiting ndiscard to " << ndiscard << " for sufficient auto-calibration data" << STD_endl;
237  ndiscard=nrows-ncols;
238  }
239  for(int i=0; i<ndiscard; i++) {
240  int irow_max=maxIndex(residuals)(0);
241  acl_signal_vec(irow_max)=STD_complex(0.0);
242  Matrix(irow_max,all)=STD_complex(0.0);
243  residuals(irow_max)=0.0;
244  }
245 
246  // solve again
247  weights_net_vec=solve_linear(Matrix, acl_signal_vec, svd_trunc);
248  }
249 
250  for(int icol=0; icol<ncols; icol++) {
251  TinyVector<int,3> windex=index2extent(netshape,icol);
252  weights(ichan_dst, ired, windex(0), windex(1), windex(2))=weights_net_vec(icol);
253  }
254 
255  } // end loop over channels
256 
257  } // end loop over reduction inidices
258 
259  ODINLOG(odinlog,normalDebug) << "cabs(sum(weights" << trainingcoord.print() << "))=" << cabs(sum(weights)) << STD_endl;
260 
261  return true;
262 }
263 
265 
266 
267 
268 template<recoDim interpolDim, recoDim orthoDim, int interpolIndex, int orthoIndex, class Ignore>
270  Log<Reco> odinlog(c_label(),"acl_lines");
271 
272  ivector neighboffset=next_neighbours_offset(reduction_factor, numof_neighb, ired);
273  ODINLOG(odinlog,normalDebug) << "neighboffset=" << neighboffset.printbody() << STD_endl;
274 
275  ivector indexvec=measlines.get_indices(coord);
276  ODINLOG(odinlog,normalDebug) << "indexvec(" << coord.print() << ")=" << indexvec.printbody() << STD_endl;
277 
278  STD_list<int> indexlist;
279  for(unsigned int i=0; i<indexvec.size(); i++) {
280  int iline=indexvec[i];
281  bool acl=true;
282  for(int ineighb=0; ineighb<numof_neighb; ineighb++) { // check whether all neighbours are there
283  int neighbindex=iline+neighboffset[ineighb];
284  if(!measured_line( indexvec, neighbindex)) acl=false;
285  }
286  if(acl) indexlist.push_back(iline);
287  }
288 
289  return list2vector(indexlist);
290 }
291 
292 
293 
294 
295 
297 
298 
299 template<recoDim interpolDim, recoDim orthoDim, int interpolIndex, int orthoIndex>
301 
302  reduction_factor=0;
303  reduction_factor.set_description("Reduction factor");
304  append_arg(reduction_factor,"reduction_factor");
305 
306  keep_shape=false;
307  keep_shape.set_description("Do not correct shape of interpolated k-space according to image space size");
308  append_arg(keep_shape,"keep_shape");
309 
310 }
311 
312 
313 template<recoDim interpolDim, recoDim orthoDim, int interpolIndex, int orthoIndex>
315  Log<Reco> odinlog(c_label(),"process");
316 
317  if(!controller.data_announced(grappa_postlabel(interpolDim))) {
318  ODINLOG(odinlog,errorLog) << "GRAPPA weights not available" << STD_endl;
319  return false;
320  }
321 
322  ODINLOG(odinlog,normalDebug) << "reduction_factor=" << reduction_factor << STD_endl;
323 
324  if(!keep_shape) {
325  if(!correct_shape(rd, controller)) return false; // adjust shape to account for incomplete sampling at end of k-space
326  }
327 
328  ODINLOG(odinlog,normalDebug) << "Requesting weights for " << rd.coord().print() << STD_endl;
329  RecoData rdweights(rd.coord());
330  if(!controller.inquire_data(*this, grappa_postlabel(interpolDim), rdweights)) return false;
331  ODINLOG(odinlog,normalDebug) << "Retrieved weights with " << rdweights.coord().print() << STD_endl;
332 
333  const ComplexData<5>& weights=rdweights.data(Rank<5>());
334  ODINLOG(odinlog,normalDebug) << "cabs(sum(weights(" << rdweights.coord().print() << ")))=" << cabs(sum(weights)) << STD_endl;
335 
336 
337  ComplexData<4>& kspace=rd.data(Rank<4>());
338  TinyVector<int,4> kspaceshape=kspace.shape();
339  int nChannels=kspaceshape(0);
340  int sizeOrtho=kspaceshape(1+orthoIndex);
341  int sizeInterpol=kspaceshape(1+interpolIndex);
342  int sizeRead=kspaceshape(3);
343  ODINLOG(odinlog,normalDebug) << "nChannels/sizeOrtho/sizeInterpol/sizeRead/reduction_factor=" << nChannels << "/" << sizeOrtho << "/" << sizeInterpol << "/" << sizeRead << "/" << reduction_factor << STD_endl;
344 
345 
346  int neighbours_phase=weights.extent(3);
347  int neighbours_read=weights.extent(4);
348  ODINLOG(odinlog,normalDebug) << "neighbours_phase/neighbours_read=" << neighbours_phase << "/" << neighbours_read << STD_endl;
349 
350  CoordCountMap interpolated_coords;
351  if(rd.interpolated) {
352  measlines.update(*(rd.interpolated)); // take previously interpolated coords into account
353  interpolated_coords=(*(rd.interpolated)); // merge previously interpolated coords
354  }
355 
356  // shape to calculate index vector within interpolation net
357  TinyVector<int,3> netshape(nChannels, neighbours_phase, neighbours_read);
358 
359  // separate array for interpolated values to avoid recursive summation
360  ComplexData<4> kspace_interpol(kspaceshape);
361  kspace_interpol=STD_complex(0.0);
362 
363 
364  for(int iortho=0; iortho<sizeOrtho; iortho++) {
365 
366  // GRAPPA encoding might be different in each of the orthogonal phase encoding partitions
367  RecoCoord meascoord(rd.coord());
368  set_separate_index(meascoord, orthoDim, iortho);
369 
370  ivector indexvec=measlines.get_indices(meascoord);
371  ODINLOG(odinlog,normalDebug) << "indexvec(" << meascoord.print() << ")=" << indexvec.printbody() << STD_endl;
372 
373  if(indexvec.size()) { // discard empty lines in orthogonal direction
374 
375  for(int ipol=0; ipol<sizeInterpol; ipol++) {
376 
377  if(!measured_line(indexvec, ipol)) {
378 
379  int ired=reduction_index(indexvec, sizeInterpol, ipol);
380  ODINLOG(odinlog,normalDebug) << "ired(" << ipol << ")=" << ired << STD_endl;
381 
382  if(ired>=0 && ired<(reduction_factor-1)) { // omit lines which are discarded due to partial Fourier
383 
384  set_separate_index(meascoord, interpolDim, ipol);
385  interpolated_coords[meascoord]++;
386  ODINLOG(odinlog,normalDebug) << "Interpolating coord " << meascoord.print() << STD_endl;
387 
388  ivector neighboffset=next_neighbours_offset(reduction_factor, neighbours_phase, ired);
389  ODINLOG(odinlog,normalDebug) << "neighboffset" << neighboffset << STD_endl;
390 
391  for(int iread=neighbours_read/2; iread<(sizeRead-neighbours_read/2); iread++) { // avoid interpolation of edges without neighbours
392  for(int ichan=0; ichan<nChannels; ichan++) {
393 
394  TinyVector<int,4> kspaceindex;
395  kspaceindex(0)=ichan;
396  kspaceindex(1+orthoIndex)=iortho;
397  kspaceindex(1+interpolIndex)=ipol;
398  kspaceindex(3)=iread;
399 
400  STD_complex interpolval(0.0);
401 
402  // iterate over 3D interpolation net to accumulate signal value at root node
403  for(int i=0; i<product(netshape); i++) {
404  TinyVector<int,3> netindex=index2extent(netshape,i); // index within net
405  int jchan=netindex(0);
406  int jphase=netindex(1);
407  int jread=netindex(2);
408 
409  int offset_ipol=neighboffset[jphase];
410  int offset_read=jread-neighbours_read/2; // symmetrical about root node
411 
412  int src_ipol =ipol+offset_ipol;
413  int src_iread =iread+offset_read;
414 
415  if(src_ipol>=0 && src_ipol<sizeInterpol && src_iread>=0 && src_iread<sizeRead) { // check if src index is within k-space
416 
417  TinyVector<int,4> srcindex;
418  srcindex(0)=jchan;
419  srcindex(1+orthoIndex)=iortho; // Use only src points from the same partition in orthogonal direction
420  srcindex(1+interpolIndex)=src_ipol;
421  srcindex(3)=src_iread;
422 
423  STD_complex srcval=kspace(srcindex);
424 /*
425  if(cabs(srcval)==0.0) {
426  ODINLOG(odinlog,warningLog) << "Zero srcval at " << srcindex << " while interpolating " << kspaceindex << STD_endl;
427  }
428 */
429 
430  interpolval += srcval * weights(ichan,ired,jchan,jphase,jread); // linear interpolation of kspace
431  }
432  }
433 
434  kspace_interpol(kspaceindex)=interpolval;
435  }
436  }
437  }
438  }
439  }
440  }
441  }
442 
443  kspace=kspace+kspace_interpol;
444 
445  rd.interpolated=&interpolated_coords; // Inform subsequent steps about the interpolated coordinates
446 
447  return execute_next_step(rd,controller);
448 }
449 
451 
452 template<recoDim interpolDim, recoDim orthoDim, int interpolIndex, int orthoIndex>
454  Log<Reco> odinlog(c_label(),"correct_shape");
455 
456  Range all=Range::all();
457 
458  ComplexData<4>& inkspace=rdkspace.data(Rank<4>());
459  TinyVector<int,4> inshape=inkspace.shape();
460 
461  int nlines_dst=controller.image_size()(interpolIndex);
462  int nlines_src=inshape(1+interpolIndex);
463  if(nlines_dst==nlines_src) return true;
464 
465  if(nlines_dst<nlines_src) nlines_dst=nlines_src;
466  TinyVector<int,4> outshape(inshape);
467  outshape(1+interpolIndex)=nlines_dst;
468 
469  ODINLOG(odinlog,normalDebug) << "inshape/outshape=" << inshape << "/" << outshape << STD_endl;
470 
471  ComplexData<4> outkspace(outshape);
472  outkspace=STD_complex(0.0);
473 
474  Range srcrange(0,nlines_src-1);
475  if(interpolDim==line) outkspace(all,all,srcrange,all)=inkspace(all,all,srcrange,all);
476  if(interpolDim==line3d) outkspace(all,srcrange,all,all)=inkspace(all,srcrange,all,all);
477  rdkspace.data(Rank<4>()).reference(outkspace);
478 
479  return true;
480 }
481 
482 
483 
485 
486 
487 template<recoDim interpolDim, recoDim orthoDim, int interpolIndex, int orthoIndex>
489  Log<Reco> odinlog(c_label(),"query");
490  if(context.mode==RecoQueryContext::prep) {
491  if(!measlines.init(context.coord, context.controller)) return false;
492  }
493 
494  return RecoStep::query(context);
495 }
496 
498 
499 template<recoDim interpolDim, recoDim orthoDim, int interpolIndex, int orthoIndex>
500 int RecoGrappa<interpolDim,orthoDim,interpolIndex,orthoIndex>::reduction_index(const ivector& indexvec, int sizePhase, int iphase) const {
501  Log<Reco> odinlog(c_label(),"reduction_index");
502 
503  ODINLOG(odinlog,normalDebug) << "reduction_factor/sizePhase/iphase=" << reduction_factor << "/" << sizePhase << "/" << iphase << STD_endl;
504 
505  int posnext=-1;
506  int negnext=-1;
507 
508  // search for next scanned line in positive direction
509  for(int i=iphase; i<sizePhase; i++) {
510  if(measured_line(indexvec, i)) {
511  posnext=i-iphase;
512  break;
513  }
514  }
515 
516  // search for next scanned line in negative direction
517  for(int i=iphase; i>=0; i--) {
518  if(measured_line(indexvec, i)) {
519  negnext=iphase-i;
520  break;
521  }
522  }
523 
524  ODINLOG(odinlog,normalDebug) << "posnext/negnext(" << iphase << ")=" << posnext << "/" << negnext << STD_endl;
525 
526  if(posnext<0 || negnext<0) return -1; // Measured line is missing in one or more directions -> Line is not surrounded by measured lines
527 
528  return negnext-1;
529 }
Definition: tjlog.h:218
void announce_data(const STD_string &label)
Definition: controller.h:111
bool inquire_data(const RecoStep &caller, const STD_string &label, RecoData &data)
void post_data(const STD_string &label, const RecoData &data)
Definition: controller.h:116
bool data_announced(const STD_string &label)
Definition: controller.h:121
TinyVector< int, 3 > image_size() const
const CoordCountMap * interpolated
ComplexData< 1 > & data(Rank< 1 >) const
Definition: odinreco/data.h:91
RecoCoord & coord()
grappaweights, grappaweights3d, grappaweightstempl, grappaweightstempl3d: Calculate GRAPPA weights in...
Definition: grappa.h:35
grappa, grappa3d: Perform GRAPPA interpolation in dimension 'line3d'
Definition: grappa.h:76
virtual bool query(RecoQueryContext &context)
STD_string printbody() const
Data< float, 1 > solve_linear(const Data< float, 2 > &A, const Data< float, 1 > &b, float sv_truncation=0.0)
Array< T, 1 > matrix_product(const Array< T, 2 > &matrix, const Array< T, 1 > &vector)
Definition: utils.h:42
recoDim
Definition: reco.h:74
STD_map< RecoCoord, UInt > CoordCountMap
Definition: index.h:364
STD_vector< T > list2vector(const STD_list< T > &src)
Definition: tjvector.h:500
RecoIndex index[n_recoDims]
Definition: index.h:269
RecoCoord & set_mode(RecoIndex::indexMode m, recoDim d1=n_recoDims, recoDim d2=n_recoDims, recoDim d3=n_recoDims, recoDim d4=n_recoDims, recoDim d5=n_recoDims, recoDim d6=n_recoDims, recoDim d7=n_recoDims, recoDim d8=n_recoDims)
Definition: index.h:238
STD_string print(RecoIndex::printMode printmode=RecoIndex::brief) const
RecoController & controller
Definition: odinreco/step.h:55