eXpress “1.5”
|
00001 #ifndef FREQUENCYMATRIX_H 00002 #define FREQUENCYMATRIX_H 00003 00012 #include <cassert> 00013 #include <vector> 00014 #include "main.h" 00015 00026 template <class T> 00027 class FrequencyMatrix { 00032 std::vector<T> _array; 00037 std::vector<T> _rowsums; 00041 size_t _M; 00045 size_t _N; 00049 bool _logged; 00054 bool _fixed; 00055 00056 public: 00060 FrequencyMatrix(){}; 00070 FrequencyMatrix(size_t m, size_t n, T alpha, bool logged = true); 00081 T operator()(size_t i, size_t j, bool normalized=true) const; 00091 T operator()(size_t k, bool normalized=true) const; 00099 void increment(size_t i, size_t j, T incr_amt); 00107 void increment(size_t k, T incr_amt); 00114 T sum(size_t i) const { return _rowsums[i]; } 00121 size_t argmax(size_t i) const; 00128 void set_logged(bool logged); 00134 void fix(); 00139 bool is_fixed() const { return _fixed; } 00140 }; 00141 00142 template <class T> 00143 FrequencyMatrix<T>::FrequencyMatrix(size_t m, size_t n, T alpha, bool logged) 00144 : _array(m*n, logged ? log(alpha):alpha), 00145 _rowsums(m, logged ? log(n*alpha):n*alpha), 00146 _M(m), 00147 _N(n), 00148 _logged(logged), 00149 _fixed(false){ 00150 } 00151 00152 template <class T> 00153 T FrequencyMatrix<T>::operator()(size_t i, size_t j, bool normalized) const { 00154 assert(i*_N+j < _M*_N); 00155 if (_fixed || !normalized) { 00156 return _array[i*_N+j]; 00157 } 00158 if (_logged) { 00159 return _array[i*_N+j]-_rowsums[i]; 00160 } else { 00161 return _array[i*_N+j]/_rowsums[i]; 00162 } 00163 } 00164 00165 template <class T> 00166 T FrequencyMatrix<T>::operator()(size_t k, bool normalized) const { 00167 return operator()(0, k, normalized); 00168 } 00169 00170 template <class T> 00171 void FrequencyMatrix<T>::increment(size_t i, size_t j, T incr_amt) { 00172 if (_fixed) { 00173 return; 00174 } 00175 00176 assert(i < _M && j < _N); 00177 size_t k = i*_N+j; 00178 if (_logged) { 00179 _array[k] = log_add(_array[k], incr_amt); 00180 _rowsums[i] = log_add(_rowsums[i], incr_amt); 00181 } else { 00182 _array[k] += incr_amt; 00183 _rowsums[i] += incr_amt; 00184 } 00185 // assert(!std::isnan(_rowsums[i]) && !std::isinf(_rowsums[i])); 00186 } 00187 00188 template <class T> 00189 void FrequencyMatrix<T>::increment(size_t k, T incr_amt) { 00190 increment(0, k, incr_amt); 00191 } 00192 00193 template <class T> 00194 void FrequencyMatrix<T>::set_logged(bool logged) { 00195 if (logged == _logged || _fixed) { 00196 return; 00197 } 00198 if (logged) { 00199 for (size_t i = 0; i < _M*_N; ++i) { 00200 _array[i] = log(_array[i]); 00201 } 00202 for(size_t i = 0; i < _M; ++i) { 00203 _rowsums[i] = log(_rowsums[i]); 00204 } 00205 } else { 00206 for (size_t i = 0; i < _M*_N; ++i) { 00207 _array[i] = sexp(_array[i]); 00208 } 00209 for (size_t i = 0; i < _M; ++i) { 00210 _rowsums[i] = sexp(_rowsums[i]); 00211 } 00212 } 00213 _logged = logged; 00214 } 00215 00216 template <class T> 00217 size_t FrequencyMatrix<T>::argmax(size_t i) const { 00218 assert(i < _M); 00219 size_t k = i*_N; 00220 size_t arg = 0; 00221 T val = _array[k]; 00222 for (size_t j = 1; j < _N; j++) { 00223 if (_array[k+j] > val) { 00224 val = _array[k+j]; 00225 arg = j; 00226 } 00227 } 00228 return arg; 00229 } 00230 00231 template <class T> 00232 void FrequencyMatrix<T>::fix() { 00233 if (_fixed) { 00234 return; 00235 } 00236 for (size_t i = 0; i < _M; ++i) { 00237 for (size_t j = 0; j < _N; ++j) { 00238 _array[i*_N+j] = operator()(i,j); 00239 } 00240 _rowsums[i] = (_logged) ? 0 : 1; 00241 } 00242 _fixed = true; 00243 } 00244 00245 #endif