Search code examples
python, glove

Category detection


I have used this code for category detection:

import numpy as np

# Reverse index: example word -> name of the category it is listed under.
# NOTE(review): `data` is not defined in this snippet; presumably a dict of
# category name -> list of example words. A word listed under several
# categories ends up mapped to the last one.
categories = {example: category_name
              for category_name, examples in data.items()
              for example in examples}

# Read the whole GloVe file into memory: on each line the first
# whitespace-separated token is the word, the rest are its vector components.
embeddings_index = {}
with open('glove.6B.100d.txt', encoding="utf8") as glove_file:
    for raw_line in glove_file:
        pieces = raw_line.split()
        token, components = pieces[0], pieces[1:]
        embeddings_index[token] = np.array(components, dtype=np.float32)
print('Loaded %s word vectors.' % len(embeddings_index))

# Keep only the vectors of words that appear as category examples
# (iteration order follows the embedding file).
data_embeddings = {token: vector
                   for token, vector in embeddings_index.items()
                   if token in categories}

# Processing the query
def process(query):
    """Score each category by how similar its example words are to ``query``.

    A category's score is the mean dot product between the embedding of
    ``query`` and the embeddings of that category's example words. Only
    example words that actually have an embedding (i.e. are keys of the
    module-level ``data_embeddings``) are counted. Categories whose example
    lists contain out-of-vocabulary words are therefore not penalised.
    Categories with no embedded example word do not appear in the result.

    Reads the module-level ``embeddings_index``, ``data_embeddings`` and
    ``categories`` mappings.

    Args:
        query: Word to classify; must be a key of ``embeddings_index``.

    Returns:
        Dict mapping category name -> mean similarity score, in order of
        first appearance in ``data_embeddings``.

    Raises:
        KeyError: If ``query`` has no embedding.
    """
    try:
        query_embed = embeddings_index[query]
    except KeyError:
        raise KeyError('%r has no embedding in embeddings_index' % (query,)) from None

    # Keep sums and counts separate so each category is averaged over the
    # words that were actually found. len(data[category]) would also count
    # example words missing from the embedding file.
    totals = {}
    counts = {}
    for word, embed in data_embeddings.items():
        category = categories[word]
        # NOTE(review): raw dot product, not cosine similarity, so vectors
        # with larger norms weigh more -- confirm this is intended.
        totals[category] = totals.get(category, 0) + query_embed.dot(embed)
        counts[category] = counts.get(category, 0) + 1
    return {category: totals[category] / counts[category] for category in totals}


# Smoke test: print the per-category scores for the word 'pizza'.
print(process(query='pizza'))

Output:

{'service': 6.385544379552205, 'ambiance': 3.5752111077308655, 'Food': 12.912149047851562}

Is there a way to get only the highest-scoring category (here, Food)?


Solution

  def process(query):
      """Return the name of the category most similar to ``query``.

      Each category is scored by the mean dot product between the embedding
      of ``query`` and the embeddings of its example words. Only words that
      are keys of the module-level ``data_embeddings`` are counted. The
      highest-scoring category is returned; ties go to the category seen
      first in ``data_embeddings``.

      Raises:
          KeyError: If ``query`` has no embedding in ``embeddings_index``.
          ValueError: If no category has any example word with an embedding.
      """
      try:
          query_embed = embeddings_index[query]
      except KeyError:
          raise KeyError('%r has no embedding in embeddings_index' % (query,)) from None

      # Average over the words actually found, not len(data[category]), so
      # categories with out-of-vocabulary examples are not biased downward.
      totals = {}
      counts = {}
      for word, embed in data_embeddings.items():
          category = categories[word]
          totals[category] = totals.get(category, 0) + query_embed.dot(embed)
          counts[category] = counts.get(category, 0) + 1
      if not totals:
          raise ValueError('no category has any example word with an embedding')
      scores = {category: totals[category] / counts[category] for category in totals}
      # max() over a dict iterates its keys; key=scores.get ranks them by score.
      return max(scores, key=scores.get)
    

    You can use `max()` with `key=scores.get` for this: it iterates over the dictionary's keys and returns the key (the category name) with the highest score.