eXpress “1.5”

src/frequencymatrix.h

00001 #ifndef FREQUENCYMATRIX_H
00002 #define FREQUENCYMATRIX_H
00003 
00012 #include <cassert>
00013 #include <vector>
00014 #include "main.h"
00015 
/**
 * An M x N matrix of values stored in row-major order, together with a running
 * sum for every row so that entries can be read back either raw or normalized
 * by their row sum. Values may be kept in log space (see _logged), in which
 * case increments are combined with log_add and normalization is a
 * subtraction of the log row sum instead of a division.
 * @tparam T numeric type of the stored values (presumably floating point,
 *           since log/exp are applied to it — TODO confirm no integral use).
 */
template <class T>
class FrequencyMatrix {
  /**
   * The M*N entries, row-major: entry (i, j) lives at index i*_N + j.
   * Stored as logarithms when _logged is true.
   */
  std::vector<T> _array;
  /**
   * The running sum of each row's entries (as a logarithm when _logged is
   * true). Reset to the normalized total by fix().
   */
  std::vector<T> _rowsums;
  /**
   * Number of rows.
   */
  size_t _M;
  /**
   * Number of columns.
   */
  size_t _N;
  /**
   * True when _array and _rowsums hold logarithms of the values.
   */
  bool _logged;
  /**
   * True once fix() has normalized the entries in place; increments and
   * set_logged() are then ignored.
   */
  bool _fixed;

public:
  /**
   * Constructs an empty 0 x 0 matrix. Every scalar member is explicitly
   * initialized so that accessors such as is_fixed() are well-defined; the
   * logged flag defaults to true to match the other constructor's default.
   */
  FrequencyMatrix() : _M(0), _N(0), _logged(true), _fixed(false) {}
  /**
   * Constructs an m x n matrix whose entries all start at alpha, so each row
   * sum starts at n*alpha.
   * @param m number of rows.
   * @param n number of columns.
   * @param alpha initial value of every entry (given in linear space).
   * @param logged if true, values are stored (and later incremented) in log
   *        space.
   */
  FrequencyMatrix(size_t m, size_t n, T alpha, bool logged = true);
  /**
   * Returns entry (i, j).
   * @param i row index.
   * @param j column index.
   * @param normalized if true (the default) and the matrix is not fixed, the
   *        entry is divided by its row sum (log space: minus the log row sum).
   * @return the raw or row-normalized entry, in log space if logged.
   */
  T operator()(size_t i, size_t j, bool normalized=true) const;
  /**
   * Returns entry k of row 0. Presumably used when the matrix is a 1 x N
   * vector.
   * @param k column index within row 0.
   * @param normalized see operator()(i, j, normalized).
   */
  T operator()(size_t k, bool normalized=true) const;
  /**
   * Adds incr_amt to entry (i, j) and to row i's sum. In log space the
   * amount is combined with log_add, so it must be given as a log value.
   * Has no effect once the matrix is fixed.
   */
  void increment(size_t i, size_t j, T incr_amt);
  /**
   * Adds incr_amt to entry k of row 0 (see increment(i, j, incr_amt)).
   */
  void increment(size_t k, T incr_amt);
  /**
   * Returns the stored sum of row i (a logarithm if the matrix is logged).
   * No bounds check is performed.
   */
  T sum(size_t i) const { return _rowsums[i]; }
  /**
   * Returns the column index of the largest stored entry in row i.
   */
  size_t argmax(size_t i) const;
  /**
   * Converts the stored values between linear and log space. Does nothing if
   * the matrix is already in the requested space or has been fixed.
   */
  void set_logged(bool logged);
  /**
   * Normalizes every row in place and freezes the matrix.
   */
  void fix();
  /**
   * Returns true once fix() has been called.
   */
  bool is_fixed() const { return _fixed; }
};
00141 
00142 template <class T>
00143 FrequencyMatrix<T>::FrequencyMatrix(size_t m, size_t n, T alpha, bool logged)
00144     : _array(m*n, logged ? log(alpha):alpha),
00145       _rowsums(m, logged ? log(n*alpha):n*alpha),
00146       _M(m),
00147       _N(n),
00148       _logged(logged),
00149       _fixed(false){
00150 }
00151 
00152 template <class T>
00153 T FrequencyMatrix<T>::operator()(size_t i, size_t j, bool normalized) const {
00154   assert(i*_N+j < _M*_N);
00155   if (_fixed || !normalized) {
00156       return _array[i*_N+j];
00157   }
00158   if (_logged) {
00159     return _array[i*_N+j]-_rowsums[i];
00160   } else {
00161     return _array[i*_N+j]/_rowsums[i];
00162   }
00163 }
00164 
00165 template <class T>
00166 T FrequencyMatrix<T>::operator()(size_t k, bool normalized) const {
00167   return operator()(0, k, normalized);
00168 }
00169 
00170 template <class T>
00171 void FrequencyMatrix<T>::increment(size_t i, size_t j, T incr_amt) {
00172   if (_fixed) {
00173     return;
00174   }
00175 
00176   assert(i < _M && j < _N);
00177   size_t k = i*_N+j;
00178   if (_logged) {
00179     _array[k] = log_add(_array[k], incr_amt);
00180     _rowsums[i] = log_add(_rowsums[i], incr_amt);
00181   } else {
00182     _array[k] += incr_amt;
00183     _rowsums[i] += incr_amt;
00184   }
00185 //    assert(!std::isnan(_rowsums[i]) && !std::isinf(_rowsums[i]));
00186 }
00187 
00188 template <class T>
00189 void FrequencyMatrix<T>::increment(size_t k, T incr_amt) {
00190   increment(0, k, incr_amt);
00191 }
00192 
00193 template <class T>
00194 void FrequencyMatrix<T>::set_logged(bool logged) {
00195   if (logged == _logged || _fixed) {
00196     return;
00197   }
00198   if (logged) {
00199     for (size_t i = 0; i < _M*_N; ++i) {
00200       _array[i] = log(_array[i]);
00201     }
00202     for(size_t i = 0; i < _M; ++i) {
00203       _rowsums[i] = log(_rowsums[i]);
00204     }
00205   } else {
00206     for (size_t i = 0; i < _M*_N; ++i) {
00207        _array[i] = sexp(_array[i]);
00208     }
00209     for (size_t i = 0; i < _M; ++i) {
00210       _rowsums[i] = sexp(_rowsums[i]);
00211     }
00212   }
00213   _logged = logged;
00214 }
00215 
00216 template <class T>
00217 size_t FrequencyMatrix<T>::argmax(size_t i) const {
00218   assert(i < _M);
00219   size_t k = i*_N;
00220   size_t arg = 0;
00221   T val = _array[k];
00222   for (size_t j = 1; j < _N; j++) {
00223     if (_array[k+j] > val) {
00224       val = _array[k+j];
00225       arg = j;
00226     }
00227   }
00228   return arg;
00229 }
00230 
00231 template <class T>
00232 void FrequencyMatrix<T>::fix() {
00233   if (_fixed) {
00234     return;
00235   }
00236   for (size_t i = 0; i < _M; ++i) {
00237     for (size_t j = 0; j < _N; ++j) {
00238       _array[i*_N+j] = operator()(i,j);
00239     }
00240     _rowsums[i] = (_logged) ? 0 : 1;
00241   }
00242   _fixed = true;
00243 }
00244 
00245 #endif
 All Classes Functions Variables