Search code examples
Tags: sql, postgresql, case, sampling

A PostgreSQL query to perform sampling by choosing every 3rd row based on a condition


I have a table containing the image id and score.

Table

A positive sample is an image id whose score is closer to 1, and a negative sample is an image id whose score is closer to 0.

To draw the positive samples, we order the image ids by decreasing score and pick every 3rd image id. To draw the negative samples, we order the image ids by increasing score and again pick every 3rd image id.

Posting the schema for reference:

-- Predictions for unlabeled images: one row per image id with its score.
-- image_id is the primary key so that loading the data twice cannot create
-- duplicate rows (duplicates would shift every row_number()-based sample).
CREATE TABLE IF NOT EXISTS unlabeled_image_predictions (
    image_id int PRIMARY KEY,
    score    float NOT NULL
);

-- Seed rows. Values are numeric literals rather than quoted strings, so no
-- implicit text-to-number coercion is involved.
-- ON CONFLICT DO NOTHING keeps the script re-runnable: rows that are already
-- present are skipped instead of raising a duplicate-key error.
INSERT INTO unlabeled_image_predictions (image_id, score)
VALUES
    (828, 0.3149), (705, 0.9892), (46, 0.5616), (594, 0.7670), (232, 0.1598),
    (524, 0.9876), (306, 0.6487), (132, 0.8823), (906, 0.8394), (272, 0.9778),
    (616, 0.1003), (161, 0.7113), (715, 0.8921), (109, 0.1151), (424, 0.7790),
    (609, 0.5241), (63, 0.2552), (276, 0.2672), (701, 0.0758), (554, 0.4418),
    (998, 0.0379), (809, 0.1058), (219, 0.7143), (402, 0.7655), (363, 0.2661),
    (624, 0.8270), (640, 0.8790), (913, 0.2421), (439, 0.3387), (464, 0.3674),
    (405, 0.6929), (986, 0.8931), (344, 0.3761), (847, 0.4889), (482, 0.5023),
    (823, 0.3361), (617, 0.0218), (47, 0.0072), (867, 0.4050), (96, 0.4498),
    (126, 0.3564), (943, 0.0452), (115, 0.5309), (417, 0.7168), (706, 0.9649),
    (166, 0.2507), (991, 0.4191), (465, 0.0895), (53, 0.8169), (971, 0.9871)
ON CONFLICT DO NOTHING;

This is what the expected output should look like:

Expected Output

I tried the following query

-- For every row of unlabeled_image_predictions, emit its threshold-based
-- weak_label together with a JSON array of every 3rd image_id taken over the
-- WHOLE table (ranked by descending score for label 1, ascending otherwise).
WITH every_third_desc AS (
    -- Every 3rd image_id when all rows are ranked from highest to lowest score.
    SELECT json_agg(ranked.image_id) AS image_ids
    FROM (
        SELECT
            image_id,
            row_number() OVER (ORDER BY score DESC) AS rn
        FROM unlabeled_image_predictions
    ) AS ranked
    WHERE ranked.rn % 3 = 0
),
every_third_asc AS (
    -- Every 3rd image_id when all rows are ranked from lowest to highest score.
    SELECT json_agg(ranked.image_id) AS image_ids
    FROM (
        SELECT
            image_id,
            row_number() OVER (ORDER BY score ASC) AS rn
        FROM unlabeled_image_predictions
    ) AS ranked
    WHERE ranked.rn % 3 = 0
),
labelled AS (
    -- One row per source row: 0 below the 0.5100 threshold, 1 otherwise.
    SELECT
        CASE WHEN score < 0.5100 THEN 0 ELSE 1 END AS weak_label
    FROM unlabeled_image_predictions
)
SELECT
    weak_label,
    CASE
        WHEN weak_label = 1 THEN (SELECT image_ids FROM every_third_desc)
        ELSE (SELECT image_ids FROM every_third_asc)
    END AS label
FROM labelled

It yields the following output

Schema (PostgreSQL v15)

-- Predictions for unlabeled images: one row per image id with its score.
-- image_id is the primary key so that loading the data twice cannot create
-- duplicate rows (duplicates would shift every row_number()-based sample).
CREATE TABLE IF NOT EXISTS unlabeled_image_predictions (
    image_id int PRIMARY KEY,
    score    float NOT NULL
);

-- Seed rows. Values are numeric literals rather than quoted strings, so no
-- implicit text-to-number coercion is involved.
-- ON CONFLICT DO NOTHING keeps the script re-runnable: rows that are already
-- present are skipped instead of raising a duplicate-key error.
INSERT INTO unlabeled_image_predictions (image_id, score)
VALUES
    (828, 0.3149), (705, 0.9892), (46, 0.5616), (594, 0.7670), (232, 0.1598),
    (524, 0.9876), (306, 0.6487), (132, 0.8823), (906, 0.8394), (272, 0.9778),
    (616, 0.1003), (161, 0.7113), (715, 0.8921), (109, 0.1151), (424, 0.7790),
    (609, 0.5241), (63, 0.2552), (276, 0.2672), (701, 0.0758), (554, 0.4418),
    (998, 0.0379), (809, 0.1058), (219, 0.7143), (402, 0.7655), (363, 0.2661),
    (624, 0.8270), (640, 0.8790), (913, 0.2421), (439, 0.3387), (464, 0.3674),
    (405, 0.6929), (986, 0.8931), (344, 0.3761), (847, 0.4889), (482, 0.5023),
    (823, 0.3361), (617, 0.0218), (47, 0.0072), (867, 0.4050), (96, 0.4498),
    (126, 0.3564), (943, 0.0452), (115, 0.5309), (417, 0.7168), (706, 0.9649),
    (166, 0.2507), (991, 0.4191), (465, 0.0895), (53, 0.8169), (971, 0.9871)
ON CONFLICT DO NOTHING;

Query #1

-- For every row of unlabeled_image_predictions, emit its threshold-based
-- weak_label together with a JSON array of every 3rd image_id taken over the
-- WHOLE table (ranked by descending score for label 1, ascending otherwise).
WITH every_third_desc AS (
    -- Every 3rd image_id when all rows are ranked from highest to lowest score.
    SELECT json_agg(ranked.image_id) AS image_ids
    FROM (
        SELECT
            image_id,
            row_number() OVER (ORDER BY score DESC) AS rn
        FROM unlabeled_image_predictions
    ) AS ranked
    WHERE ranked.rn % 3 = 0
),
every_third_asc AS (
    -- Every 3rd image_id when all rows are ranked from lowest to highest score.
    SELECT json_agg(ranked.image_id) AS image_ids
    FROM (
        SELECT
            image_id,
            row_number() OVER (ORDER BY score ASC) AS rn
        FROM unlabeled_image_predictions
    ) AS ranked
    WHERE ranked.rn % 3 = 0
),
labelled AS (
    -- One row per source row: 0 below the 0.5100 threshold, 1 otherwise.
    SELECT
        CASE WHEN score < 0.5100 THEN 0 ELSE 1 END AS weak_label
    FROM unlabeled_image_predictions
)
SELECT
    weak_label,
    CASE
        WHEN weak_label = 1 THEN (SELECT image_ids FROM every_third_desc)
        ELSE (SELECT image_ids FROM every_third_asc)
    END AS label
FROM labelled;
weak_label label
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998

View on DB Fiddle

However, this is not what the expected output should look like.

Posting the expected output for reference:

Expected Output


Solution

  • It is not entirely clear how you would arrive at that expected result. Going by your description:

    -- Label each image by the 0.5100 score threshold, rank it within its own
    -- class, and keep every 3rd image of each class (ranks 1, 4, 7, ...).
    with negatives as (
        -- Scores below the threshold, ranked from the lowest score upwards.
        -- image_id breaks ties so the ranking is deterministic on equal scores.
        select
            image_id,
            0 as weak_label,
            row_number() over (order by score asc, image_id) as row_no
        from unlabeled_image_predictions
        where score < 0.5100
    ),
    positives as (
        -- Scores at or above the threshold, ranked from the highest score downwards.
        select
            image_id,
            1 as weak_label,
            row_number() over (order by score desc, image_id) as row_no
        from unlabeled_image_predictions
        where score >= 0.5100
    ),
    combined as (
        -- The two classes are disjoint (different weak_label), so union all
        -- gives the same rows as union without the extra de-duplication step.
        select image_id, weak_label, row_no from negatives
        union all
        select image_id, weak_label, row_no from positives
    )
    select
        image_id,
        weak_label
    from combined
    where row_no % 3 = 1
    order by image_id;
    

    Would yield:

    image_id weak_label
    47 0
    63 0
    96 0
    115 1
    126 0
    232 0
    272 1
    405 1
    417 1
    424 1
    616 0
    705 1
    715 1
    828 0
    867 0
    906 1
    943 0