Welcome to ShenZhenJia Knowledge Sharing Community for programmer and developer-Open, Learning and Share

In one function I very often need to use code like:

which(x==1)[1]
which(x>1)[1]
x[x>10][1]

where x is a numeric vector. summaryRprof() shows that I spend >80% of the time on relational operators. Is there a function that stops comparing once the first TRUE value is reached? That would speed up my code. A for-loop is slower than the options above.

1 Answer

I don't know of a pure R way to do this, so I wrote a C function for the quantstrat package. It was written with a specific purpose in mind, so it's not as general as I would like. For example, it only works on real/double/numeric data, so be sure to coerce Data to double before calling .firstCross.

#include <R.h>
#include <Rinternals.h>

SEXP firstCross(SEXP x, SEXP th, SEXP rel, SEXP start)
{
    int i, n, int_rel, int_start;
    double *real_x, real_th;

    if(ncols(x) > 1)
        error("only univariate data allowed");

    /* this currently only works for real x and th arguments;
     * support for other types may be added later */
    if(TYPEOF(x) != REALSXP)
        error("'x' must be of type double");

    real_x  = REAL(x);
    n       = nrows(x);
    real_th = asReal(th);
    int_rel = asInteger(rel);

    /* 'start' is a 1-based R index: reject NA and values < 1 (which would
     * index before the array), then convert it to a 0-based C index */
    int_start = asInteger(start);
    if(int_start == NA_INTEGER || int_start < 1)
        error("'start' must be a positive integer");
    int_start--;

    /* scan forward from 'start'; return the 1-based index of the first match */
    switch(int_rel) {
        case 1:  /* >  */
            for(i=int_start; i<n; i++)
                if(real_x[i] >  real_th)
                    return(ScalarInteger(i+1));
            break;
        case 2:  /* <  */
            for(i=int_start; i<n; i++)
                if(real_x[i] <  real_th)
                    return(ScalarInteger(i+1));
            break;
        case 3:  /* == */
            for(i=int_start; i<n; i++)
                if(real_x[i] == real_th)
                    return(ScalarInteger(i+1));
            break;
        case 4:  /* >= */
            for(i=int_start; i<n; i++)
                if(real_x[i] >= real_th)
                    return(ScalarInteger(i+1));
            break;
        case 5:  /* <= */
            for(i=int_start; i<n; i++)
                if(real_x[i] <= real_th)
                    return(ScalarInteger(i+1));
            break;
        default:
            error("unsupported relationship operator");
    }
    /* return number of observations if relationship is never TRUE */
    return(ScalarInteger(n));
}

And here's the R function that calls it:

.firstCross <- function(Data, threshold=0, relationship, start=1) {
    rel <- switch(relationship[1],
            '>'    =  ,
            'gt'   = 1,
            '<'    =  ,
            'lt'   = 2,
            '=='   =  ,
            'eq'   = 3,
            '>='   =  ,
            'gte'  =  ,
            'gteq' =  ,
            'ge'   = 4,
            '<='   =  ,
            'lte'  =  ,
            'lteq' =  ,
            'le'   = 5)
    .Call('firstCross', Data, threshold, rel, start)
}
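A note on the `switch()` above: in R, an alternative with an empty value falls through to the next non-empty one, so '>' and 'gt' both give 1, and '>=', 'gte', 'gteq' and 'ge' all give 4. If nothing matches, `switch()` returns NULL, and as far as I can tell `asInteger(NULL)` in the C code yields NA_INTEGER, which lands in the `default:` branch and raises "unsupported relationship operator". The same alias table written out flat, as a hypothetical C helper `rel_code()` (not part of quantstrat; it returns 0 instead of NULL for unknown input):

```c
#include <string.h>

/* Hypothetical mirror of the lookup in .firstCross: map an operator name or
 * alias to the integer code firstCross() switches on; 0 means "unknown". */
int rel_code(const char *rel)
{
    static const struct { const char *name; int code; } table[] = {
        {">",  1}, {"gt",  1},
        {"<",  2}, {"lt",  2},
        {"==", 3}, {"eq",  3},
        {">=", 4}, {"gte", 4}, {"gteq", 4}, {"ge", 4},
        {"<=", 5}, {"lte", 5}, {"lteq", 5}, {"le", 5},
    };
    for (size_t i = 0; i < sizeof table / sizeof table[0]; i++)
        if (strcmp(rel, table[i].name) == 0)
            return table[i].code;
    return 0;   /* no alias matched */
}
```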

Some benchmarks, just for fun.

> library(quantstrat)
> library(microbenchmark)
> firstCross <- quantstrat:::.firstCross
> set.seed(21)
> x <- rnorm(1e6)
> microbenchmark(which(x > 3)[1], firstCross(x,3,">"), times=10)
Unit: microseconds
                  expr      min       lq    median       uq      max neval
       which(x > 3)[1] 9482.081 9578.072 9597.3870 9690.448 9820.176    10
 firstCross(x, 3, ">")   11.370   11.675   31.9135   34.443   38.614    10
> which(x>3)[1]
[1] 919
> firstCross(x,3,">")
[1] 919

Note that firstCross yields a larger relative speedup the larger Data is, because R's relational operators always compare the entire vector (and which() then scans the entire logical result), while firstCross stops at the first TRUE.
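A rough operation count behind this, where n = length(Data) and k is the 1-based position of the first TRUE (k = n if there is none); it ignores memory allocation and the fixed cost of each call:

```latex
\begin{aligned}
\texttt{which}(x > t)[1] &:\; n \text{ comparisons} + n \text{ logicals scanned by } \texttt{which()} = \Theta(n) \\
\texttt{firstCross}(x, t, \texttt{">"}) &:\; k \text{ comparisons} = \Theta(k) \\
\text{work ratio} &\approx n / k
\end{aligned}
```

For the benchmark below, n = 10^7 and k = 97, so firstCross does roughly 10^7 / 97 ≈ 1.03 × 10^5 times fewer comparisons. The measured median ratio (95799.857 / 25.845 ≈ 3700) is much smaller, presumably because with only 97 comparisons the fixed per-call overhead of .firstCross and .Call dominates.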

> x <- rnorm(1e7)
> microbenchmark(which(x > 3)[1], firstCross(x,3,">"), times=10)
Unit: microseconds
                  expr      min        lq    median        uq        max neval
       which(x > 3)[1] 94536.21 94851.944 95799.857 96154.756 113962.794    10
 firstCross(x, 3, ">")     5.08     5.507    25.845    32.164     34.183    10
> which(x>3)[1]
[1] 97
> firstCross(x,3,">")
[1] 97

...and it won't be appreciably faster if the first TRUE value is near the end of the vector.
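To make that concrete, here is a simplified model (plain C, illustration only; both function names are hypothetical) that counts element comparisons for a full scan, which is roughly what `which(x == th)[1]` does, versus an early-exit scan like firstCross. When the only match is the last element, both do n comparisons; when it is the first, the early exit does one:

```c
#include <stddef.h>

/* Full scan: examine every element, remembering the first hit.
 * Returns the number of elements examined; *first gets the 1-based
 * index of the first match, or 0 if there is none. */
size_t full_scan_comparisons(const double *x, size_t n, double th, size_t *first)
{
    size_t count = 0;
    *first = 0;
    for (size_t i = 0; i < n; i++) {
        count++;                          /* one comparison per element */
        if (x[i] == th && *first == 0)
            *first = i + 1;
    }
    return count;
}

/* Early exit: stop at the first TRUE, as firstCross() does. */
size_t early_exit_comparisons(const double *x, size_t n, double th, size_t *first)
{
    size_t count = 0;
    *first = 0;
    for (size_t i = 0; i < n; i++) {
        count++;
        if (x[i] == th) {
            *first = i + 1;
            break;                        /* nothing after this is examined */
        }
    }
    return count;
}
```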

> microbenchmark(which(x==last(x))[1], firstCross(x,last(x),"eq"),times=10)
Unit: milliseconds
                         expr      min       lq   median       uq       max neval
       which(x == last(x))[1] 92.56311 93.85415 94.38338 98.18422 106.35253    10
 firstCross(x, last(x), "eq") 86.55415 86.70980 86.98269 88.32168  92.97403    10
> which(x==last(x))[1]
[1] 10000000
> firstCross(x,last(x),"eq")
[1] 10000000
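Both calls print 10000000 here because the match really is on the last element. If no element matched, firstCross would still return nrows(x) (see the comment at the end of the C code), whereas `which(...)[1]` returns NA, so the two results are not interchangeable. One way to make "no match" distinguishable is sketched below as a hypothetical standalone helper `first_gt()` (plain C, no R API, only the `>` case for brevity) that returns 0 when nothing qualifies; an R wrapper could then turn 0 into NA_integer_:

```c
#include <stddef.h>

/* Hypothetical helper (not part of quantstrat): 1-based index of the first
 * element of x[0..n-1] greater than th, scanning from the 1-based position
 * 'start'. Returns 0 if no element qualifies, so "no match" cannot be
 * confused with "match on the last element". */
size_t first_gt(const double *x, size_t n, double th, size_t start)
{
    for (size_t i = (start > 0 ? start - 1 : 0); i < n; i++) {
        if (x[i] > th)
            return i + 1;   /* stop at the first TRUE */
    }
    return 0;               /* relationship never TRUE */
}
```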
