main.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. from sklearn.cluster import DBSCAN
  2. from sklearn import metrics
  3. from collections import namedtuple
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from itertools import groupby
  7. from math import pi, sin, cos
  8. import csv_parser
  9. import recog
  10. import metric
  11. def cluster_statistic(xy):
  12. center = np.array([np.average(xy[:, 0]), np.average(xy[:, 1])])
  13. center_p = namedtuple("point", "lon lat")(center[0], center[1])
  14. rad = 0
  15. for loc in xy:
  16. p = namedtuple("point", "lon lat")(loc[0], loc[1])
  17. rad = max(rad, metric.spherical_distance(center_p, p))
  18. rad *= 6400 * 1000
  19. return (center, rad)
  20. if __name__ == '__main__':
  21. data = csv_parser.parse_data_from_csv('pitchtest0730.csv')
  22. groups = groupby(data, key = lambda x: x.hwid)
  23. entries = []
  24. for k, grp in groups:
  25. data1 = list(grp)
  26. data1.sort(key = lambda x: x.timestamp)
  27. entries1 = recog.recognize_entries(data1)
  28. for i in entries1:
  29. entries.append(i)
  30. for ent in entries:
  31. print(ent)
  32. x = []
  33. yaws = []
  34. for e in entries:
  35. x.append([e.lon, e.lat])
  36. yaws.append(e.yaw / 180 * pi)
  37. x = np.array(x)
  38. yaws = np.array(yaws)
  39. db = DBSCAN(eps = 5/6400000, min_samples = 20,
  40. metric = metric.spherical_distance).fit(x)
  41. labels = db.labels_
  42. core_samples_mask = np.zeros_like(db.labels_, dtype = bool)
  43. core_samples_mask[db.core_sample_indices_] = True
  44. n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
  45. n_noise_ = list(labels).count(-1)
  46. print('Estimated number of clusters: %d' % n_clusters_)
  47. print('Estimated number of noise points: %d' % n_noise_)
  48. if n_clusters_ == 0:
  49. print('can not get any clusters')
  50. plt.plot(x[:,0], x[:,1], 'o')
  51. plt.show()
  52. exit(0)
  53. print("Silhouette Coefficient: %0.3f"
  54. % metrics.silhouette_score(x, labels))
  55. unique_labels = set(labels)
  56. colors = [plt.cm.Spectral(each)
  57. for each in np.linspace(0, 1, len(unique_labels))]
  58. for k, col in zip(unique_labels, colors):
  59. if k == -1:
  60. # Black used for noise.
  61. col = [0, 0, 0, 1]
  62. class_member_mask = (labels == k)
  63. xy = x[class_member_mask & core_samples_mask]
  64. plt.plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(col),
  65. markeredgecolor='k', markersize=14)
  66. xy = x[class_member_mask & ~core_samples_mask]
  67. plt.plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(col),
  68. markeredgecolor='k', markersize=6)
  69. xy = x[class_member_mask]
  70. if k != -1:
  71. center, rad = cluster_statistic(xy)
  72. print("cluster %d:" % k)
  73. print("size =", len(xy))
  74. print("center = %f, %f" % (center[0], center[1]))
  75. print("radius = %f m" % rad)
  76. yaw = yaws[class_member_mask]
  77. db_yaw = DBSCAN(eps = 0.05, min_samples=30,
  78. metric = metric.ang_distance).fit(yaw.reshape(-1, 1))
  79. lbs = set(db_yaw.labels_)
  80. arrow_colors = [plt.cm.Wistia(each)
  81. for each in np.linspace(0, 1, len(lbs))]
  82. for l in lbs:
  83. if l != -1:
  84. mask = (db_yaw.labels_ == l)
  85. center, rad = cluster_statistic(xy[mask])
  86. print(" sub-cluster %d:" % l)
  87. print(" size =", np.sum(mask))
  88. print(" center = %f, %f" % (center[0], center[1]))
  89. print(" radius = %f m" % rad)
  90. print(" avg yaw =",
  91. np.average(np.fmod(yaw[mask], 2*pi)) / pi * 180)
  92. kwargs = {'width': 1e-5}
  93. for i in range(len(xy)):
  94. col = arrow_colors[db_yaw.labels_[i]]
  95. if db_yaw.labels_[i] == -1:
  96. continue
  97. col = [0, 0, 0, 1]
  98. '''
  99. plt.arrow(xy[i, 0], xy[i, 1],
  100. 5e-5 * cos(pi/2 - yaw[i]),
  101. 5e-5 * sin(pi/2 - yaw[i]),
  102. **dict(width=1e-7, color=col))
  103. '''
  104. plt.annotate("", xytext=(xy[i, 0], xy[i, 1]),
  105. xy=(xy[i, 0] + 2e-5 * sin(pi/2 - yaw[i]),
  106. xy[i, 1] + 2e-5 * cos(pi/2 - yaw[i])),
  107. arrowprops=dict(arrowstyle="->",
  108. color=col))
  109. plt.show()