Search code examples
java string trie suffix-tree

Suffix Trie matching, problem with matching operation


I am facing a problem with suffix-trie matching. I designed the suffix trie as a 26-way tree: each node has one child slot per character plus an associated value. The value of a node is the index at which the suffix starts in the main string (if the path to that node spells a suffix), or -1 otherwise. I am now trying to get the matching operation to work, but it does not report every match and I have not been able to find the bug. For more clarification, refer to the second question in this PDF. Help, please.

import java.util.*;

class node{
    public int val;
    public node ptrs[];
    node(){
        this.val =0;
        ptrs = new node[26];
        for (node ptr : ptrs) {
            ptr = null;
        }
    }    
}
class Tree{
    public node root = new node();
    int pass =0;
    void insert(String s,int indx) {
        node trv = root;
        for (int i = 0; i < s.length(); i++) {
            if (trv.ptrs[s.charAt(i) - 'A'] == null) {
                trv.ptrs[s.charAt(i) - 'A'] = new node();
                if(i==s.length()-1){
                    trv.ptrs[s.charAt(i)-'A'].val = indx;
                }else{
                    trv.ptrs[s.charAt(i)-'A'].val = -1;
                }
            }
            trv = trv.ptrs[s.charAt(i) - 'A'];
        }
    }
    
    private void visit(node trv){
        for(int i =0;i<26;i++){
            if(trv.ptrs[i]!=null){
                System.out.println(trv.ptrs[i].val+":"+((char)(i+'A')));
                visit(trv.ptrs[i]);
            }
        }
    }

    void visit(){
        this.visit(root);
    }
    void leaf(node trv){
        if(trv.val>=0){
            System.out.println(trv.val);
        }else{
            for(int i=0;i<26;i++){
                if(trv.ptrs[i]!=null){
                    if(trv.ptrs[i].val>=0){
                        System.out.println(trv.ptrs[i].val);
                    }else{
                        leaf(trv.ptrs[i]);
                    }
                }
            }
        }
    }
    private void search(node trv,String s,int i){
        if(i<=s.length()-1){
            if(trv.ptrs[s.charAt(i)-'A']!=null){
                if(i==s.length()-1){
                    leaf(trv.ptrs[s.charAt(i)-'A']);
                }else{
                    search(trv.ptrs[s.charAt(i)-'A'], s, i+1);
                }
            }
        }
    }
    void query(String s){
        this.search(root, s, 0);
    }
}
/**
 * Command-line driver: reads a text, builds its suffix trie, then answers pattern queries.
 *
 * <p>Input format on standard input: the text, the number of queries, then one pattern
 * per query. For each pattern the start indices of its occurrences are printed by
 * {@code Tree.query}.
 */
public class Trie {
    public static void main(String[] args) {
        // try-with-resources closes the scanner even if the input is malformed
        // (e.g. nextInt() throwing InputMismatchException).
        try (Scanner sc = new Scanner(System.in)) {
            Tree t = new Tree();
            String txt = sc.next();
            int txtLen = txt.length();
            // Insert every suffix together with the index at which it starts.
            for (int i = 0; i < txtLen; i++) {
                t.insert(txt.substring(i), i);
            }
            int q = sc.nextInt();
            while (q-- > 0) {
                String m = sc.next();
                t.query(m);
            }
        }
    }
}

Expected :

Input:
AATCGGGTTCAATCGGGGT
2
ATCG
GGGT
Output:
1 4 11 15

My Output :

AATCGGGTTCAATCGGGGT
2
ATCG
11
1
GGGT
4

I am not getting 15 as the answer as you can see.


Solution

  • I tried to investigate your program, but it is not well written and unfortunately I could not find the problem. I tried printing the trie with visit, but the output gave no helpful information.

    However, the following code finds a pattern using a suffix tree, as described in "Fast Pattern Matching of Strings Using Suffix Tree". It may be helpful:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.stream.Collectors;
    
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    class Node {
        private String text;
        private List<Node> children;
        private int position;
    
        public Node(String word, int position) {
            this.text = word;
            this.position = position;
            this.children = new ArrayList<>();
        }
    
        public String getText() {
            return text;
        }
    
        public void setText(String text) {
            this.text = text;
        }
    
        public int getPosition() {
            return position;
        }
    
        public void setPosition(int position) {
            this.position = position;
        }
    
        public List<Node> getChildren() {
            return children;
        }
    
        public void setChildren(List<Node> children) {
            this.children = children;
        }
    
        public String printTree(String depthIndicator) {
            String str = "";
            String positionStr = position > -1 ? "[" + String.valueOf(position) + "]" : "";
            str += depthIndicator + text + positionStr + "\n";
    
            for (int i = 0; i < children.size(); i++) {
                str += children.get(i)
                        .printTree(depthIndicator + "\t");
            }
            return str;
        }
    
        @Override
        public String toString() {
            return printTree("");
        }
    }
    
    /**
     * Suffix tree with multi-character edge labels, built by inserting every suffix of a
     * text (each with {@code "$"} appended) and recording the suffix's start index on the
     * node where it ends.
     *
     * <p>{@link #searchText(String)} returns, for each occurrence of a pattern, the full
     * text with that occurrence wrapped in square brackets.
     *
     * <p>The whole tree is logged at INFO level after every suffix insertion.
     * NOTE(review): the code does not guard against a text or pattern that itself contains
     * {@code "$"}; such input looks unsupported and should be confirmed.
     */
    public class SuffixTree {
        private static final Logger LOGGER = LoggerFactory.getLogger(SuffixTree.class);
        private static final String WORD_TERMINATION = "$";
        private static final int POSITION_UNDEFINED = -1;
        private Node root;
        private String fullText;

        /**
         * Builds the tree for {@code text} by adding each suffix followed by the terminator.
         *
         * @param text the text to index
         */
        public SuffixTree(String text) {
            root = new Node("", POSITION_UNDEFINED);
            for (int i = 0; i < text.length(); i++) {
                addSuffix(text.substring(i) + WORD_TERMINATION, i);
            }
            fullText = text;
        }

        /**
         * Finds every occurrence of {@code pattern} in the indexed text.
         *
         * @param pattern the text to look for
         * @return one string per occurrence, in ascending start position, each being the
         *         full text with the occurrence enclosed in {@code [} and {@code ]};
         *         an empty list when there is no match or the pattern is null or empty
         */
        public List<String> searchText(String pattern) {
            LOGGER.info("Searching for pattern \"{}\"", pattern);
            List<String> result = new ArrayList<>();
            // An empty pattern used to fail on pattern.charAt(0) during traversal.
            if (pattern == null || pattern.isEmpty()) {
                return result;
            }

            List<Node> nodes = getAllNodesInTraversePath(pattern, root, false);
            if (nodes.isEmpty()) {
                return result;
            }

            // A trailing null is the traversal's marker for "pattern ran off the tree".
            Node lastNode = nodes.get(nodes.size() - 1);
            if (lastNode == null) {
                return result;
            }

            List<Integer> positions = getPositions(lastNode).stream()
                    .sorted()
                    .collect(Collectors.toList());
            for (Integer start : positions) {
                result.add(markPatternInText(start, pattern));
            }
            return result;
        }

        /**
         * Inserts one terminated suffix, either as a new child of the root or by extending
         * (and, if needed, splitting) the deepest node that shares a prefix with it.
         */
        private void addSuffix(String suffix, int position) {
            LOGGER.info(">>>>>>>>>>>> Adding new suffix {}", suffix);
            List<Node> nodes = getAllNodesInTraversePath(suffix, root, true);
            if (nodes.size() == 0) {
                addChildNode(root, suffix, position);
                LOGGER.info("{}", printTree());
            } else {
                Node lastNode = nodes.remove(nodes.size() - 1);
                String newText = suffix;
                if (nodes.size() > 0) {
                    String existingSuffixUptoLastNode = nodes.stream()
                            .map(Node::getText)
                            .reduce("", String::concat);

                    // Remove prefix from newText already included in parent
                    newText = newText.substring(existingSuffixUptoLastNode.length());
                }
                extendNode(lastNode, newText, position);
                LOGGER.info("{}", printTree());
            }
        }

        /**
         * Collects the positions of all nodes in the subtree of {@code node} (inclusive)
         * whose text ends with the terminator, i.e. the nodes where a suffix ends.
         */
        private List<Integer> getPositions(Node node) {
            List<Integer> positions = new ArrayList<>();
            if (node.getText()
                    .endsWith(WORD_TERMINATION)) {
                positions.add(node.getPosition());
            }
            for (Node child : node.getChildren()) {
                positions.addAll(getPositions(child));
            }
            return positions;
        }

        /**
         * Returns the full text with the {@code pattern.length()} characters starting at
         * {@code startPosition} wrapped in square brackets.
         */
        private String markPatternInText(Integer startPosition, String pattern) {
            String matchingTextLHS = fullText.substring(0, startPosition);
            String matchingText = fullText.substring(startPosition, startPosition + pattern.length());
            String matchingTextRHS = fullText.substring(startPosition + pattern.length());
            return matchingTextLHS + "[" + matchingText + "]" + matchingTextRHS;
        }

        /** Appends a new childless node with the given text and position to {@code parentNode}. */
        private void addChildNode(Node parentNode, String text, int position) {
            parentNode.getChildren()
                    .add(new Node(text, position));
        }

        /**
         * Attaches {@code newText} below {@code node}.
         *
         * <p>If {@code newText} only shares a proper prefix with the node's text, the node is
         * first split so that it keeps the shared prefix and a new child takes the rest.
         * The part of {@code newText} after the shared prefix is then added as a child.
         */
        private void extendNode(Node node, String newText, int position) {
            String currentText = node.getText();
            String commonPrefix = getLongestCommonPrefix(currentText, newText);

            // Compare by value: the previous reference comparison (!=) only worked because
            // substring(0, length) happens to return the same String instance.
            if (!commonPrefix.equals(currentText)) {
                String parentText = currentText.substring(0, commonPrefix.length());
                String childText = currentText.substring(commonPrefix.length());
                splitNodeToParentAndChild(node, parentText, childText);
            }

            String remainingText = newText.substring(commonPrefix.length());
            addChildNode(node, remainingText, position);
        }

        /**
         * Splits {@code parentNode} in place: a new child receives {@code childNewText}, the
         * node's former position and all of its former children; the node itself keeps
         * {@code parentNewText}, an undefined position, and the new child as its only child.
         */
        private void splitNodeToParentAndChild(Node parentNode, String parentNewText, String childNewText) {
            Node childNode = new Node(childNewText, parentNode.getPosition());

            // Move every existing child, in order, under the new child node.
            List<Node> formerChildren = parentNode.getChildren();
            childNode.getChildren()
                    .addAll(formerChildren);
            formerChildren.clear();

            parentNode.getChildren()
                    .add(childNode);
            parentNode.setText(parentNewText);
            parentNode.setPosition(POSITION_UNDEFINED);
        }

        /** Returns the longest string that is a prefix of both {@code str1} and {@code str2}. */
        private String getLongestCommonPrefix(String str1, String str2) {
            int compareLength = Math.min(str1.length(), str2.length());
            for (int i = 0; i < compareLength; i++) {
                if (str1.charAt(i) != str2.charAt(i)) {
                    return str1.substring(0, i);
                }
            }
            return str1.substring(0, compareLength);
        }

        /**
         * Walks down from {@code startNode} along {@code pattern} and returns the nodes
         * visited, in order.
         *
         * <p>Only the first child whose text starts with the pattern's first character is
         * followed. With {@code isAllowPartialMatch} (used when inserting) the walk stops at,
         * and includes, the first node that the pattern does not fully cover or diverges
         * from. Without it (used when searching) a divergence inside a node returns the
         * nodes gathered so far without that node, and a {@code null} element is appended
         * when the recursive step for the remaining pattern finds nothing.
         *
         * @param pattern             the non-empty text to follow
         * @param startNode           the node whose children are examined
         * @param isAllowPartialMatch {@code true} for insertion, {@code false} for search
         * @return the traversed nodes; empty when no child of {@code startNode} matches
         */
        private List<Node> getAllNodesInTraversePath(String pattern, Node startNode, boolean isAllowPartialMatch) {
            List<Node> nodes = new ArrayList<>();
            for (Node currentNode : startNode.getChildren()) {
                String nodeText = currentNode.getText();
                if (pattern.charAt(0) == nodeText.charAt(0)) {
                    if (isAllowPartialMatch && pattern.length() <= nodeText.length()) {
                        nodes.add(currentNode);
                        return nodes;
                    }

                    int compareLength = Math.min(nodeText.length(), pattern.length());
                    for (int j = 1; j < compareLength; j++) {
                        if (pattern.charAt(j) != nodeText.charAt(j)) {
                            if (isAllowPartialMatch) {
                                nodes.add(currentNode);
                            }
                            return nodes;
                        }
                    }

                    nodes.add(currentNode);
                    if (pattern.length() > compareLength) {
                        List<Node> nodes2 = getAllNodesInTraversePath(pattern.substring(compareLength), currentNode, isAllowPartialMatch);
                        if (nodes2.size() > 0) {
                            nodes.addAll(nodes2);
                        } else if (!isAllowPartialMatch) {
                            // Exact search ran out of tree before the pattern ended.
                            nodes.add(null);
                        }
                    }
                    return nodes;
                }
            }
            return nodes;
        }

        /** @return a tab-indented, multi-line rendering of the whole tree */
        public String printTree() {
            return root.printTree("");
        }

        /** Demo: indexes a sample text and prints the highlighted matches of two patterns. */
        public static void main(String[] args) {
            SuffixTree suffixTree = new SuffixTree("AATCGGGTTCAATCGGGGT");
            List<String> matches = suffixTree.searchText("ATCG");
            matches.forEach(m -> System.out.println(m));
            List<String> matches2 = suffixTree.searchText("GGGT");
            matches2.forEach(m -> System.out.println(m));
        }
    }