eXpress “1.5”

src/mismatchmodel.cpp

00001 //
00002 //  mismatchmodel.cpp
00003 //  express
00004 //
00005 //  Created by Adam Roberts on 4/23/11.
00006 //  Copyright 2011 Adam Roberts. All rights reserved.
00007 //
00008 
00009 #include "main.h"
00010 #include "mismatchmodel.h"
00011 #include "targets.h"
00012 #include "fragments.h"
00013 #include "sequence.h"
00014 #include <iostream>
00015 #include <fstream>
00016 
00017 using namespace std;
00018 
00019 MismatchTable::MismatchTable(double alpha)
00020     : _first_read_mm(max_read_len, FrequencyMatrix<double>(16, 4, alpha)),
00021       _second_read_mm(max_read_len, FrequencyMatrix<double>(16, 4, alpha)),
00022       _insert_params(1, max_indel_size + 1, 0),
00023       _delete_params(1, max_indel_size + 1, 0),
00024       _max_len(0),
00025       _active(false){
00026   // Set indel priors
00027   double no_indel_p = 0.99;
00028   double pm = no_indel_p;
00029   for(size_t i = 0 ; i <= max_indel_size; ++i) {
00030     _insert_params.increment(i, log(alpha * pm));
00031     _delete_params.increment(i, log(alpha * pm));
00032     pm *= (1 - no_indel_p);
00033   }
00034   assert(approx_eq(sexp(_insert_params.sum(0)), alpha));
00035   assert(approx_eq(sexp(_delete_params.sum(0)), alpha));
00036 }
00037 
00038 MismatchTable::MismatchTable(string param_file_name)
00039     : _first_read_mm(max_read_len, FrequencyMatrix<double>(16, 4, 0)),
00040       _second_read_mm(max_read_len, FrequencyMatrix<double>(16, 4, 0)),
00041       _insert_params(1, max_indel_size + 1, 0),
00042       _delete_params(1, max_indel_size + 1, 0),
00043       _max_len(0),
00044       _active(true){
00045   ifstream infile (param_file_name.c_str());
00046   const size_t BUFF_SIZE = 99999;
00047   char line_buff[BUFF_SIZE];
00048   if (!infile.is_open()) {
00049     logger.severe("Unable to open parameter file '%s'.",
00050                   param_file_name.c_str());
00051   }
00052 
00053   infile.getline (line_buff, BUFF_SIZE, '\n');
00054   size_t pos = 0;
00055   
00056   while (infile.good()) {
00057     infile.getline (line_buff, BUFF_SIZE, '\n');
00058     if (!strcmp(line_buff, ">First Read Mismatch")) {
00059       break;
00060     }
00061   }
00062   
00063   infile.getline (line_buff, BUFF_SIZE, '\n');
00064   while (infile.good()) {
00065     infile.getline (line_buff, BUFF_SIZE, '\n');
00066     if (!strcmp(line_buff, ">Second Read Mismatch")) {
00067       break;
00068     }
00069     
00070     if (pos >= max_read_len) {
00071       pos++;
00072       continue;
00073     }
00074     
00075     _first_read_mm[pos] = FrequencyMatrix<double>(16, 4, 0);
00076     char *p = strtok(line_buff, "\t");
00077     for (size_t i = 0; i < 16; ++i) {
00078       for(size_t j = 0; j < 4; ++j) {
00079         p = strtok(NULL, "\t");
00080         _first_read_mm[pos].increment(i, j, log(strtod(p,NULL)));
00081       }
00082     }
00083     pos++;
00084   }
00085   _max_len = min(pos, max_read_len);
00086   if (pos >= max_read_len) {
00087     logger.warn("First read error distribution of %d bases in '%s' truncated "
00088                 "after %d bases.",
00089                 pos-1, param_file_name.c_str(), max_read_len);
00090   }
00091   
00092   pos = 0;
00093   infile.getline (line_buff, BUFF_SIZE, '\n');
00094   
00095   while ( infile.good() ) {
00096     infile.getline (line_buff, BUFF_SIZE, '\n');
00097     
00098     if (!strncmp(line_buff, ">Insertion Length", 17)) {
00099       break;
00100     }
00101     
00102     if (pos >= max_read_len) {
00103       pos++;
00104       continue;
00105     }
00106     
00107     _second_read_mm[pos] = FrequencyMatrix<double>(16, 4, 0);
00108     char *p = strtok(line_buff, "\t");
00109     for (size_t i = 0; i < 16; ++i) {
00110       for(size_t j = 0; j < 4; ++j) {
00111         p = strtok(NULL, "\t");
00112         _second_read_mm[pos].increment(i, j, log(strtod(p,NULL)));
00113       }
00114     }
00115     pos++;
00116   }
00117   _max_len = max(_max_len, min(pos, max_read_len));
00118   if (pos >= max_read_len) {
00119     logger.warn("Second read error distribution of %d bases in '%s' truncated "
00120                 "after %d bases.",
00121                 pos-1, param_file_name.c_str(), max_read_len);
00122   }
00123   
00124   infile.getline (line_buff, BUFF_SIZE, '\n');
00125   char *p = strtok(line_buff, "\t");
00126   size_t k = 0;
00127   do {
00128     if (k > max_indel_size) {
00129       logger.warn("Paramater file '%s' insertion distribution is being "
00130                   "truncated at max indel length of %d.",
00131                   param_file_name.c_str(), max_indel_size);
00132       break;
00133     }
00134     _insert_params.increment(k, log(strtod(p,NULL)));
00135     p = strtok(NULL, "\t");
00136     k++;
00137 
00138   } while (p);
00139 
00140   infile.getline (line_buff, BUFF_SIZE, '\n');
00141   infile.getline (line_buff, BUFF_SIZE, '\n');
00142   p = strtok(line_buff, "\t");
00143   k = 0;
00144   do {
00145     if (k > max_indel_size) {
00146       logger.warn("Paramater file '%s' deletion distribution is being "
00147                   "truncated at max indel length of %d.",
00148                   param_file_name.c_str(), max_indel_size);
00149       break;
00150     }
00151     _delete_params.increment(k, log(strtod(p,NULL)));
00152     p = strtok(NULL, "\t");
00153     k++;
00154   } while (p);
00155   
00156   fix();
00157 }
00158 
00159 void MismatchTable::get_indices(const FragHit& f,
00160                            vector<char>& left_indices,
00161                            vector<char>& left_seq,
00162                            vector<char>& left_ref,
00163                            vector<char>& right_indices,
00164                            vector<char>& right_seq,
00165                            vector<char>& right_ref) const {
00166 
00167   const Target& targ = *f.target();
00168   const Sequence& t_seq_fwd = targ.seq(0);
00169   const Sequence& t_seq_rev = targ.seq(1);
00170   
00171   if (f.left_read()) {
00172     const ReadHit& read_l = *f.left_read();
00173     
00174     left_indices = vector<char>();
00175     left_seq = vector<char>();
00176     
00177     size_t i = 0;  // read index
00178     size_t j = read_l.left;  // genomic index
00179     
00180     vector<Indel>::const_iterator ins = read_l.inserts.begin();
00181     vector<Indel>::const_iterator del = read_l.deletes.begin();
00182     
00183     if (read_l.inserts.size() || read_l.deletes.size()) {
00184       logger.severe("Indels are not currently supported for eXpress-D.");
00185     }
00186     
00187     size_t cur_seq_bit = 0;
00188     
00189     while (i < read_l.seq.length()) {
00190       if (del != read_l.deletes.end() && del->pos == i) {
00191         j += del->len;
00192         del++;
00193       } else if (ins != read_l.inserts.end() && ins->pos == i) {
00194         i += ins->len;
00195         ins++;
00196       } else {
00197         size_t cur = read_l.seq[i];
00198         size_t ref = t_seq_fwd[j];
00199         if (cur != ref) {
00200           left_indices.push_back(i);
00201           if (cur_seq_bit / 8 == left_seq.size()) {
00202             left_seq.push_back(0);
00203             left_ref.push_back(0);
00204           }
00205           left_seq.back() += cur << (cur_seq_bit % 8);
00206           left_ref.back() += ref << (cur_seq_bit % 8);
00207           cur_seq_bit += 2;
00208         }
00209         
00210         i++;
00211         j++;
00212       }
00213     }
00214   }
00215   
00216   if (f.right_read()) {
00217     const ReadHit& read_r = *f.right_read();
00218     
00219     right_indices = vector<char>();
00220     right_seq = vector<char>();
00221 
00222     size_t r_len = read_r.seq.length();
00223     size_t i = 0;
00224     size_t j = targ.length() - read_r.right;
00225         
00226     vector<Indel>::const_iterator ins = read_r.inserts.end()-1;
00227     vector<Indel>::const_iterator del = read_r.deletes.end()-1;
00228 
00229     if (read_r.inserts.size() || read_r.deletes.size()) {
00230       logger.severe("Indels are not currently supported for eXpress-D.");
00231     }
00232     
00233     size_t cur_seq_bit = 0;
00234     while (i < r_len) {
00235       if (del != read_r.deletes.begin() - 1 && del->pos == r_len - i) {
00236         j += del->len;
00237         del--;
00238       } else if (ins != read_r.inserts.begin() - 1
00239                  && ins->pos + ins->len == r_len - i) {
00240         i += ins->len;
00241         ins--;
00242       } else {
00243         size_t cur = read_r.seq[i];
00244         size_t ref = t_seq_rev[j];
00245         
00246         if (cur != ref) {
00247           right_indices.push_back(i);
00248           if (cur_seq_bit / 8 == right_seq.size()) {
00249             right_seq.push_back(0);
00250             right_ref.push_back(0);
00251           }
00252           right_seq.back() += cur << (cur_seq_bit % 8);
00253           right_ref.back() += ref << (cur_seq_bit % 8);
00254           cur_seq_bit += 2;
00255         }
00256         
00257         i++;
00258         j++;
00259       }
00260     }
00261   }
00262 }
00263 
00264 double MismatchTable::log_likelihood(const FragHit& f) const {
00265   if (!_active) {
00266     return 0;
00267   }
00268   
00269   const Target& targ = *f.target();
00270   const Sequence& t_seq_fwd = targ.seq(0);
00271   const Sequence& t_seq_rev = targ.seq(1);
00272 
00273   double ll = 0;
00274 
00275   if (f.left_read()) {
00276     const ReadHit& read_l = *f.left_read();
00277     const vector<FrequencyMatrix<double> >& left_mm = (read_l.first) ?
00278                                                       _first_read_mm :
00279                                                       _second_read_mm;
00280     size_t i = 0;  // read index
00281     size_t j = read_l.left;  // genomic index
00282 
00283     bool insertion = false;
00284     bool deletion = false;
00285     
00286     vector<Indel>::const_iterator ins = read_l.inserts.begin();
00287     vector<Indel>::const_iterator del = read_l.deletes.begin();
00288     
00289     while (i < read_l.seq.length()) {
00290 
00291       if (del != read_l.deletes.end() && del->pos == i) {
00292         // Deletion at this position
00293         ll += _delete_params(del->len);
00294         j += del->len;
00295         del++;
00296         deletion = true;
00297       } else if (ins != read_l.inserts.end() && ins->pos == i) {
00298         // Insertion at this position
00299         ll += _insert_params(ins->len);
00300         i += ins->len;
00301         ins++;
00302         insertion = true;
00303       } else {
00304         ll += !insertion * _insert_params(0);
00305         ll += !deletion * _delete_params(0);
00306         insertion = false;
00307         deletion = false;
00308         
00309         size_t cur = read_l.seq[i];
00310         size_t prev = (i) ? (read_l.seq[i-1] << 2) : 0;
00311 
00312         if (t_seq_fwd.prob()) {
00313           double trans_prob = LOG_0;
00314           for (size_t nuc = 0; nuc < NUM_NUCS; nuc++) {
00315             size_t index = ((prev + nuc) << 2) + cur;
00316             
00317             trans_prob = log_add(trans_prob, t_seq_fwd.get_prob(j, nuc) +
00318                                              left_mm[i](index));
00319           }
00320           ll += trans_prob;
00321         } else {
00322           size_t ref = t_seq_fwd[j];
00323           size_t index = prev + ref;
00324           ll += left_mm[i](index, cur);
00325         }
00326         i++;
00327         j++;
00328       }
00329     }
00330   }
00331   
00332   if (f.right_read()) {
00333     const ReadHit& read_r = *f.right_read();
00334     
00335     const vector<FrequencyMatrix<double> >& right_mm = (read_r.first) ?
00336                                                         _first_read_mm :
00337                                                         _second_read_mm;
00338     
00339     size_t r_len = read_r.seq.length();
00340     size_t i = 0;
00341     size_t j = targ.length() - read_r.right;
00342 
00343     bool insertion = false;
00344     bool deletion = false;
00345     
00346     vector<Indel>::const_iterator ins = read_r.inserts.end()-1;
00347     vector<Indel>::const_iterator del = read_r.deletes.end()-1;
00348 
00349     while (i < r_len) {
00350       if (del != read_r.deletes.begin() - 1 && del->pos == r_len - i) {
00351         ll += _delete_params(del->len);
00352         j += del->len;
00353         del--;
00354         deletion = true;
00355       } else if (ins != read_r.inserts.begin() - 1
00356                  && ins->pos + ins->len == r_len - i) {
00357         ll += _insert_params(ins->len);
00358         i += ins->len;
00359         ins--;
00360         insertion = true;
00361       } else {
00362         ll += !insertion * _insert_params(0);
00363         ll += !deletion * _delete_params(0);
00364         insertion = false;
00365         deletion = false;
00366         
00367         size_t cur = read_r.seq[i];
00368         size_t prev = (i) ? (read_r.seq[i-1] << 2) : 0;
00369 
00370         if (t_seq_rev.prob()) {
00371           double trans_prob = LOG_0;
00372           for (size_t nuc = 0; nuc < NUM_NUCS; nuc++) {
00373             size_t index = ((prev + nuc) << 2) + cur;
00374             trans_prob = log_add(trans_prob, t_seq_rev.get_prob(j, nuc) +
00375                                              right_mm[i](index));
00376           }
00377           ll += trans_prob;
00378         } else {
00379           size_t ref = t_seq_rev[j];
00380           size_t index = prev + ref;
00381           ll += right_mm[i](index, cur);
00382         }
00383         i++;
00384         j++;
00385       }
00386     }
00387   }
00388   
00389   assert(!(isnan(ll)||isinf(ll)));
00390   return ll;
00391 }
00392 
00393 void MismatchTable::update(const FragHit& f, double p, double mass) {
00394   if (mass == LOG_0) {
00395     return;
00396   }
00397 
00398   Target& targ = *f.target();
00399   Sequence& t_seq_fwd = targ.seq(0);
00400   Sequence& t_seq_rev = targ.seq(1);
00401 
00402   if (f.left_read()) {
00403     const ReadHit& read_l = *f.left_read();
00404     vector<FrequencyMatrix<double> >& left_mm = (read_l.first) ?
00405                                                 _first_read_mm :
00406                                                 _second_read_mm;
00407     size_t i = 0;  // read index
00408     size_t j = read_l.left;  // genomic index
00409 
00410     bool insertion = false;
00411     bool deletion = false;
00412     
00413     vector<Indel>::const_iterator ins = read_l.inserts.begin();
00414     vector<Indel>::const_iterator del = read_l.deletes.begin();
00415 
00416     vector<double> joint_probs(NUM_NUCS);
00417 
00418     assert(targ.length() >= f.right());
00419     while (i < read_l.seq.length()) {
00420       if (del != read_l.deletes.end() && del->pos == i) {
00421         _delete_params.increment(del->len, mass);
00422         j += del->len;
00423         del++;
00424         deletion = true;
00425       } else if (ins != read_l.inserts.end() && ins->pos == i) {
00426         _insert_params.increment(ins->len, mass);
00427         i += ins->len;
00428         ins++;
00429         insertion = true;
00430       } else {
00431         if (!insertion) {
00432           _insert_params.increment(0, mass);
00433         }
00434         if (!deletion) {
00435           _delete_params.increment(0, mass);
00436         }
00437         insertion = false;
00438         deletion = false;
00439         
00440         size_t cur = read_l.seq[i];
00441         size_t prev = (i) ? (read_l.seq[i-1] << 2) : 0;
00442         // Update the seq parameters only after burn-in (active)
00443         if (t_seq_fwd.prob() && _active) {
00444           
00445           t_seq_fwd.update_obs(j, cur, p);
00446 
00447           double Z = LOG_0;
00448 
00449           size_t ref_index = (i) ? (t_seq_fwd.get_ref(j-1)<<2) +
00450                                     t_seq_fwd.get_ref(j) : t_seq_fwd.get_ref(j);
00451           for (size_t nuc = 0; nuc < NUM_NUCS; nuc++) {
00452             // Update expected
00453             t_seq_fwd.update_exp(j, nuc, p+left_mm[i](ref_index, nuc));
00454 
00455             // Update posterior
00456             size_t index = prev + nuc;
00457             joint_probs[nuc] = t_seq_fwd.get_prob(j, nuc) +
00458                                left_mm[i](index, cur);
00459             Z = log_add(Z, joint_probs[nuc]);
00460           }
00461 
00462           for (size_t nuc = 0; !left_mm[i].is_fixed() && nuc < NUM_NUCS; nuc++) {
00463             size_t index = prev + nuc;
00464             left_mm[i].increment(index, cur,
00465                                  mass + p + t_seq_fwd.get_prob(j, nuc));
00466           }
00467 
00468           for (size_t nuc=0; nuc < NUM_NUCS; nuc++) {
00469             t_seq_fwd.update_est(j, nuc, p + joint_probs[nuc] - Z);
00470           }
00471         } else {
00472           size_t ref = t_seq_fwd[j];
00473           size_t index = prev + ref;
00474           left_mm[i].increment(index, cur, mass + p);
00475         }
00476 
00477         i++;
00478         j++;
00479       }
00480     }
00481     _max_len = max(_max_len, read_l.seq.length());
00482   }
00483   
00484   if (f.right_read()) {
00485     const ReadHit& read_r = *f.right_read();
00486     vector<FrequencyMatrix<double> >& right_mm = (read_r.first) ?
00487                                                   _first_read_mm :
00488                                                   _second_read_mm;
00489     
00490     size_t r_len = read_r.seq.length();
00491     size_t i = 0;
00492     size_t j = targ.length() - read_r.right;
00493 
00494     bool insertion = false;
00495     bool deletion = false;
00496     
00497     vector<Indel>::const_iterator ins = read_r.inserts.end() - 1;
00498     vector<Indel>::const_iterator del = read_r.deletes.end() - 1;
00499 
00500     vector<double> joint_probs(NUM_NUCS);
00501 
00502     while (i < r_len) {
00503       if (del != read_r.deletes.begin()-1 && del->pos == r_len-i ) {
00504         _delete_params.increment(del->len, mass);
00505         j += del->len;
00506         del--;
00507         deletion = true;
00508       } else if (ins != read_r.inserts.begin() - 1 &&
00509                  ins->pos + ins->len == r_len-i) {
00510         _insert_params.increment(ins->len, mass);
00511         i += ins->len;
00512         ins--;
00513         insertion = true;
00514       } else {
00515         if (!insertion) {
00516           _delete_params.increment(0, mass);
00517         }
00518         if (!deletion) {
00519           _insert_params.increment(0, mass);
00520         }
00521         insertion = false;
00522         deletion = false;
00523 
00524         size_t cur = read_r.seq[i];
00525         size_t prev = (i) ? (read_r.seq[i-1] << 2) : 0;
00526 
00527         if (t_seq_rev.prob() && _active) {          
00528           t_seq_rev.update_obs(j, cur, p);
00529 
00530           double Z = LOG_0;
00531 
00532           size_t ref_index = (i) ? (t_seq_rev.get_ref(j-1)<<2) +
00533                                    t_seq_rev.get_ref(j) : t_seq_rev.get_ref(j);
00534           for (size_t nuc = 0; nuc < NUM_NUCS; nuc++) {
00535             // Update expected
00536             t_seq_rev.update_exp(j, nuc, p+right_mm[i](ref_index, nuc));
00537 
00538             // Update posterior
00539             size_t index = prev + nuc;
00540             joint_probs[nuc] = t_seq_rev.get_prob(j, nuc) +
00541                                right_mm[i](index, cur);
00542             Z = log_add(Z, joint_probs[nuc]);
00543           }
00544 
00545           for (size_t nuc = 0; !right_mm[i].is_fixed() && nuc < NUM_NUCS; nuc++) {
00546             size_t index = prev + nuc;
00547             right_mm[i].increment(index, cur, mass+p+t_seq_rev.get_prob(j, nuc));
00548           }
00549 
00550           for (size_t nuc=0; nuc < NUM_NUCS; nuc++) {
00551             t_seq_rev.update_est(j, nuc, p + joint_probs[nuc] - Z);
00552           }
00553         } else {
00554           size_t ref = t_seq_rev[j];
00555           size_t index = prev + ref;
00556           right_mm[i].increment(index, cur, mass+p);
00557         }
00558 
00559         i++;
00560         j++;
00561       }
00562       _max_len = max(_max_len, read_r.seq.length());
00563     }
00564   }
00565 }
00566 
00567 void MismatchTable::fix() {
00568   for (size_t i = 0; i < max_read_len; i++) {
00569     _first_read_mm[i].fix();
00570     _second_read_mm[i].fix();
00571   }
00572   _insert_params.fix();
00573   _delete_params.fix();
00574 }
00575 
00576 void MismatchTable::append_output(ofstream& outfile) const {
00577   string col_header =  "\t";
00578   for (size_t i = 0; i < 64; i++) {
00579     col_header += NUCS[i>>4];
00580     col_header += NUCS[i>>2 & 3];
00581     col_header += "->*";
00582     col_header += NUCS[i & 3];
00583     col_header += '\t';
00584   }
00585   col_header[col_header.length()-1] = '\n';
00586 
00587   outfile << ">First Read Mismatch\n" << col_header;
00588   for (size_t k = 0; k < _max_len; k++) {
00589     outfile << k+1 << ":\t";
00590     for (size_t i = 0; i < 16; i++) {
00591       for (size_t j = 0; j < 4; j++) {
00592         if (k || i < 4) {
00593           outfile << scientific << sexp(_first_read_mm[k](i,j))<<"\t";
00594         } else {
00595           outfile << scientific << 0.0 << "\t";
00596         }
00597       }
00598     }
00599     outfile<<endl;
00600   }
00601   outfile<<">Second Read Mismatch\n" << col_header;
00602   for (size_t k = 0; k < _max_len; k++) {
00603     outfile << k+1 << ":\t";
00604     for (size_t i = 0; i < 16; i++) {
00605       for (size_t j = 0; j < 4; j++) {
00606         if (k || i < 4) {
00607           outfile << scientific << sexp(_second_read_mm[k](i,j))<<"\t";
00608         } else {
00609           outfile << scientific << 0.0 << "\t";
00610         }
00611       }
00612     }
00613     outfile<<endl;
00614   }
00615   outfile << ">Insertion Length (0-" << max_indel_size << ")\n";
00616   for (size_t i = 0; i <= max_indel_size; i++) {
00617     outfile << scientific << sexp(_insert_params(i))<<"\t";
00618   }
00619   outfile<<endl;
00620   outfile << ">Deletion Length (0-" << max_indel_size << ")\n";
00621   for (size_t i = 0; i <= max_indel_size; i++) {
00622     outfile << scientific << sexp(_delete_params(i))<<"\t";
00623   }
00624   outfile<<endl;
00625 }
 All Classes Functions Variables