Search code examples
postgresql hibernate jpa vector-database pgvector

What JPA + Hibernate data type should I use to support the vector extension in a PostgreSQL database?


What JPA + Hibernate data type should I use to support the vector extension in a PostgreSQL database, so that it allows me to create embeddings using a JPA Entity?

-- Table of items, each row carrying a 3-dimensional vector embedding.
CREATE TABLE items (
    id        bigserial PRIMARY KEY,
    embedding vector(3)
);

pgvector


Solution

  • You can use Vlad Mihalcea's Hypersistence Utils (Hibernate Types) library to map the vector column to a List<Double>, which makes it possible to save and query embeddings through a JpaRepository.

    1. Add a dependency to the pom.xml file:

      <!-- Hypersistence Utils: supplies io.hypersistence.utils.hibernate.type.json.JsonType,
           which the Item entity uses to map its embedding field.
           NOTE(review): the "-55" artifact suffix presumably targets the Hibernate 5.5 line;
           confirm it matches the Hibernate version on your classpath. -->
      <dependency>
        <groupId>io.hypersistence</groupId>
        <artifactId>hypersistence-utils-hibernate-55</artifactId>
        <version>3.5.0</version>
      </dependency>
      
    2. Create the Item class:

      import com.fasterxml.jackson.annotation.JsonInclude;
      import io.hypersistence.utils.hibernate.type.json.JsonType;
      import lombok.Data;
      import lombok.NoArgsConstructor;
      import org.hibernate.annotations.Type;
      import org.hibernate.annotations.TypeDef;
      
      import javax.persistence.*;
      import java.util.List;
      
      /**
       * JPA entity mapped to the {@code items} table.
       *
       * <p>Accessors, {@code equals}/{@code hashCode} and {@code toString} are generated by
       * Lombok's {@code @Data}; a no-argument constructor is generated by
       * {@code @NoArgsConstructor}. The class-level {@code @TypeDef} registers
       * {@link JsonType} under the name {@code "json"} so fields can refer to it via
       * {@code @Type(type = "json")}.
       */
      @Data
      @NoArgsConstructor
      @Entity
      @Table(name = "items")
      @JsonInclude(JsonInclude.Include.NON_NULL)
      @TypeDef(name = "json", typeClass = JsonType.class)
      public class Item {
          /** Primary key; the value is produced by the database (IDENTITY strategy). */
          @Id
          @GeneratedValue(strategy = GenerationType.IDENTITY)
          private Long id;

          /**
           * Embedding stored in a column declared as {@code vector}, read and written
           * through the {@code "json"} Hibernate type.
           */
          // NOTE(review): presumably this works because the JSON array text form
          // (e.g. [0.1,0.2,0.3]) matches the text form the vector column accepts - confirm.
          @Type(type = "json")
          @Column(columnDefinition = "vector")
          private List<Double> embedding;
      }
      
    3. Create a JpaRepository interface that supports save and find. You can also write custom findNearestNeighbors methods using native SQL:

      import org.springframework.data.jpa.repository.JpaRepository;
      import org.springframework.data.jpa.repository.Query;

      import java.util.List;
      
      /**
       * Spring Data JPA repository for {@link Item} entities.
       *
       * <p>In addition to the operations inherited from {@link JpaRepository}, it declares
       * native-SQL nearest-neighbour lookups that order rows by the {@code <->} operator
       * applied to the {@code embedding} column.
       */
      public interface ItemRepository extends JpaRepository<Item, Long> {

          /**
           * Finds the five items whose embedding is nearest to the given vector.
           *
           * <p>The PostgreSQL {@code ::} cast is written as {@code \\:\\:} so that the query
           * parser does not read the colons as the start of a named parameter. Writing
           * {@code cast(? as vector)} instead is an equivalent spelling.
           *
           * @param value the query vector in text form, for example {@code "[1,2,3]"}
           * @return at most five items, nearest first
           */
          @Query(nativeQuery = true, value = "SELECT * FROM items ORDER BY embedding <-> ? \\:\\:vector LIMIT 5")
          List<Item> findNearestNeighbors(String value);

          /**
           * Finds the five items nearest to the embedding of an existing row, excluding
           * that row itself.
           *
           * <p>The query binds its argument through the positional placeholder {@code ?1}
           * rather than a named {@code :id} parameter: a named parameter without
           * {@code @Param} only resolves when the code is compiled with {@code -parameters},
           * whereas positional binding does not depend on compiler settings.
           *
           * @param id primary key of the reference item
           * @return at most five other items, nearest first
           */
          @Query(nativeQuery = true, value = "SELECT * FROM items WHERE id != ?1 ORDER BY embedding <-> (SELECT embedding FROM items WHERE id = ?1) LIMIT 5")
          List<Item> findNearestNeighbors(Long id);
      }
      
    4. Test create, query and findNearestNeighbors:

      // Repository exercised by the tests below; injected via @Autowired.
      @Autowired
      private ItemRepository itemRepository;
      
      /**
       * Saves one {@link Item} whose embedding consists of three random doubles.
       *
       * <p>Runs in a transaction that is marked {@code @Rollback(false)}, so the inserted
       * row is not rolled back when the test finishes.
       */
      @Test
      @Rollback(false)
      @Transactional
      public void createItem() {
          final Random random = new Random();
          final List<Double> vector = new ArrayList<>();
          // Three components, matching the vector(3) column of the items table.
          while (vector.size() < 3) {
              vector.add(random.nextDouble());
          }

          final Item entity = new Item();
          entity.setEmbedding(vector);
          itemRepository.save(entity);
      }
      
      /** Loads every stored item through the repository and prints the resulting list. */
      @Test
      public void loadItems() {
          System.out.println(itemRepository.findAll());
      }
      
      @Test
      public void findNearestNeighbors() {
          final String value = "[0.1, 0.2, 0.3]";
          final List<Item> items = itemRepository.findNearestNeighbors(value);
          System.out.println(items);
      }