Search code examples
Tags: java, machine-learning, vowpalwabbit

Vowpal Wabbit Java API: how to get raw predictions


I am using the Java API of Vowpal Wabbit to get predictions. I need the raw predictions (the same output that `-r output.txt` writes), but I couldn't find any such method in the VWMulticlassLearner class. I train my model from Python by invoking the vw command line with the following arguments:

vw -f model_filepath -c --cache_file cache_filepath -k --csoaa 40 -b 24 -q cd -q .... -q n: --ignore a --ignore x

In Java, we use the code below to get predictions:

VWLearners.create("-i ./data/train.model  -t --quiet"); // VWMulticlassLearner
VWLearners.create("-i ./data/train.model  -t --quiet --csoaa_ldf=mc --loss_function=logistic --probabilities"); //VWProbLearner

Neither class has a method that returns the raw predictions.

I want the same raw prediction values that the command below produces:

$ echo ' .. sample string .. ' | vw -i data/train.model -t -r test -p /dev/stdout
creating quadratic features for pairs: cd ce cu cw de du dw eu ew uw n:
ignoring namespaces beginning with: a x
only testing
predictions = /dev/stdout
raw predictions = test
Num weight bits = 24
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile =
num sources = 1
average  since         example        example  current  current  current
loss     last          counter         weight    label  predict features
39
0.000000 0.000000            1            1.0    known       39      171

finished run
number of examples per pass = 1
passes used = 1
weighted example sum = 1.000000
weighted label sum = 0.000000
average loss = 0.000000
total feature number = 171

$ cat test
0:1.05645 1:0.83437 2:-0.210798 3:-2.81048 4:-4.47558 5:-4.45883 6:-3.65177 7:-3.71191 8:-2.96008 9:-2.82846 10:-2.31816 11:0.925984 12:3.28547 13:5.20375 14:6.34244 15:6.13525 16:1.65726 17:1.22801 18:1.35034 19:3.27091 20:2.94066 21:-0.0276409 22:0.391437 23:1.267 24:-0.689573 25:0.0171876 26:3.12935 27:3.95045 28:3.86978 29:1.18468 30:0.0921049 31:0.436564 32:0.98946 33:1.00963 34:-0.265355 35:-3.02128 36:-2.52846 37:-2.8066 38:-3.50639 39:-4.6184

How can I get the values that end up in the file `test` as the return value of a Java method? I don't want to read the file from Java to get them, because that would be slow.


Solution

  • I ended up using the code from an abandoned pull request. Here is my working git patch file:

    diff --git a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
    index 6b51c4d30..f3ccb6621 100644
    --- a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
    +++ b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
    @@ -11,3 +11,27 @@ JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predict(JNI
     JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predictMultiline(JNIEnv *env, jobject obj, jobjectArray example_strings, jboolean learn, jlong vwPtr)
     { return base_predict<jint>(env, example_strings, learn, vwPtr, multiclass_predictor);
     }
    +
    +// Copies the partial_prediction of every entry in the example's cost-sensitive
    +// label (vec->l.cs.costs) into a new Java float[], in the same order as the costs.
    +// NOTE(review): intended to match what `vw -r` writes for --csoaa; confirm.
    +jfloatArray multiclass_raw_predictor(example* vec, JNIEnv *env){
    +  const size_t num_values = vec->l.cs.costs.size();
    +  jfloatArray j_labels = env->NewFloatArray((jsize)num_values);
    +  if (j_labels == NULL)
    +    return NULL;  // allocation failed; an OutOfMemoryError is already pending
    +  if (num_values > 0) {
    +    // Fill through one buffer instead of making one JNI call per element.
    +    jfloat* buf = env->GetFloatArrayElements(j_labels, NULL);
    +    if (buf == NULL)
    +      return NULL;  // an OutOfMemoryError is already pending
    +    for (size_t i = 0; i < num_values; i++)
    +      buf[i] = vec->l.cs.costs[i].partial_prediction;
    +    env->ReleaseFloatArrayElements(j_labels, buf, 0);  // mode 0: copy back and free the buffer
    +  }
    +  return j_labels;
    +}
    +
    +JNIEXPORT jfloatArray JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_rawPredict(JNIEnv *env, jobject obj, jstring example_string, jboolean learn, jlong vwPtr){
    +  return base_predict<jfloatArray>(env, example_string, learn, vwPtr, multiclass_raw_predictor);
    +}
    diff --git a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h
    index 05204d53e..5610704fa 100644
    --- a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h
    +++ b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h
    @@ -24,6 +24,14 @@ JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predict
     JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predictMultiline
     (JNIEnv *, jobject, jobjectArray, jboolean, jlong);
     
    +/*
    + * Class:     vowpalWabbit_learner_VWMulticlassLearner
    + * Method:    rawPredict
    + * Signature: (Ljava/lang/String;ZJ)[F
    + */
    +JNIEXPORT jfloatArray JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_rawPredict
    +  (JNIEnv *, jobject, jstring, jboolean, jlong);
    +
     #ifdef __cplusplus
     }
     #endif
    diff --git a/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java b/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java
    index b506cfb25..bb3156351 100644
    --- a/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java
    +++ b/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java
    @@ -13,4 +13,41 @@ final public class VWMulticlassLearner extends VWIntLearner {
     
         @Override
         protected native int predictMultiline(String[] example, boolean learn, long nativePointer);
    +
    +    /**
    +     * Native implementation behind {@link #rawPredict(String)}.
    +     *
    +     * @param example a single vw example string; must not be null
    +     * @param learn learn flag forwarded to the native predictor
    +     * @param nativePointer native vw handle (received as {@code vwPtr} on the C++ side)
    +     * @return one raw value per entry of the example's cost-sensitive label
    +     */
    +    protected native float[] rawPredict(String example, boolean learn, long nativePointer);
    +
    +    /**
    +     * Gets the raw prediction output for a single example, calling the native side with learn = false.
    +     *
    +     * <p>The native side returns the partial prediction of every entry in the example's
    +     * cost-sensitive label, which is intended to match the values {@code vw -r} writes.
    +     *
    +     * @param example a single vw example string
    +     * @return the raw prediction, one value per cost-sensitive label entry
    +     * @throws NullPointerException if {@code example} is null
    +     * @throws IllegalStateException if this learner has already been closed
    +     */
    +    public float[] rawPredict(final String example) {
    +        // Fail fast in Java: a null jstring must never be handed to the native code.
    +        if (example == null) {
    +            throw new NullPointerException("example");
    +        }
    +        lock.lock();
    +        try {
    +            if (isOpen()) {
    +                return rawPredict(example, false, nativePointer);
    +            }
    +            throw new IllegalStateException("Already closed.");
    +        } finally {
    +            lock.unlock();
    +        }
    +    }
     }