Coverage for src/plotly_gtk/_chart.py: 9%

245 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-08 21:22 +0000

1"""Contains a private class to handle plotting for 

2:class:`plotly_gtk.chart.PlotlyGTK`.""" 

3 

4import gi 

5import numpy as np 

6 

7from plotly_gtk.utils import * # pylint: disable=wildcard-import,unused-wildcard-import 

8 

9gi.require_version("Gtk", "4.0") 

10from gi.repository import ( # pylint: disable=wrong-import-position,wrong-import-order 

11 Gtk, 

12 Pango, 

13 PangoCairo, 

14) 

15 

# Set to True to paint debug overlays: the margin box in pink and axis lines in green.
DEBUG: bool = False

17 

18 

class _PlotlyGtk(Gtk.DrawingArea):
    """Gtk drawing area that renders a plotly figure dictionary with Cairo/Pango.

    The drawing code relies on private keys prepared outside this class:
    ``layout["_margin"]`` (``l``/``r``/``t``/``b``) and, per axis, ``_range``,
    ``_domain``, ``_position``, ``_shift``, ``_tickvals``, ``_ticktext`` and
    ``_ticksobject``. Each trace in ``data`` must carry ``type`` and ``visible``.
    """

    def __init__(self, fig: dict):
        super().__init__()
        # Figure parts read by the draw callbacks below.
        self.data = fig["data"]
        self.layout = fig["layout"]

        self.set_draw_func(self._on_draw)

    # String annotation: dict takes exactly two type arguments (key, value);
    # quoting avoids evaluating ``|`` on the plotly type objects at import time.
    def update(self, fig: "dict[str, plotly_types.Data | plotly_types.Layout]"):
        """Update the plot with a new figure.

        Parameters
        ----------
        fig: dict[str, plotly_types.Data | plotly_types.Layout]
            A dictionary representing a plotly figure
        """
        self.data = fig["data"]
        self.layout = fig["layout"]
        self.queue_draw()

    def _on_draw(self, area, context, x, y):  # pylint: disable=unused-argument
        """Draw callback: paint the whole figure onto *context*.

        Layers are painted back to front (background, grid, traces, axis
        lines, tick labels), so later steps cover earlier ones.
        """
        # NOTE(review): presumably recomputes layout["_margin"] before drawing;
        # automargin() is defined on the parent widget elsewhere — confirm.
        self.get_parent().automargin()

        width = area.get_size(Gtk.Orientation.HORIZONTAL)
        height = area.get_size(Gtk.Orientation.VERTICAL)

        self._draw_bg(context, width, height)
        self._draw_grid(context, width, height)
        self._plot(context, width, height)
        self._draw_axes(context, width, height)
        self._draw_all_ticks(context, width, height)

    def _axis_names(self):
        """Return the layout keys drawn as axes (every key containing "axis")."""
        return [key for key in self.layout if "axis" in key]

    def _anchor_name(self, axis):
        """Return the layout key of the axis that *axis* is anchored to.

        An anchor of ``"y2"`` becomes ``"yaxis2"``. Returns ``None`` when the
        axis has no ``anchor`` or is anchored ``"free"``.
        """
        axis_layout = self.layout[axis]
        if "anchor" not in axis_layout or axis_layout["anchor"] == "free":
            return None
        anchor = axis_layout["anchor"]
        return anchor[0] + "axis" + anchor[1:]

    def _plot_area(self, width, height):
        """Return ``(plot_width, plot_height)``: the widget size minus the margins."""
        margin = self.layout["_margin"]
        return (
            width - margin["l"] - margin["r"],
            height - margin["t"] - margin["b"],
        )

    def _draw_bg(self, context, width, height):
        """Fill the paper background, then the plot area of each cartesian subplot."""
        context.set_source_rgb(*parse_color(self.layout["paper_bgcolor"]))
        context.rectangle(0, 0, width, height)
        context.fill()

        if DEBUG:
            # Visualise the box left inside the margins.
            margin = self.layout["_margin"]
            plot_width, plot_height = self._plot_area(width, height)
            context.set_source_rgb(*parse_color("pink"))
            context.rectangle(margin["l"], margin["t"], plot_width, plot_height)
            context.fill()

        context.set_source_rgb(*parse_color(self.layout["plot_bgcolor"]))
        for xaxis, yaxis in get_cartesian_subplots(self.data):
            x_pos, y_pos = self._calc_pos(
                self.layout[xaxis]["_range"],
                self.layout[yaxis]["_range"],
                width,
                height,
                xaxis,
                yaxis,
                ignore_log_x=True,
                ignore_log_y=True,
            )
            context.rectangle(
                x_pos[0], y_pos[0], x_pos[-1] - x_pos[0], y_pos[-1] - y_pos[0]
            )
            context.fill()

    def _draw_grid(self, context, width, height):
        """Draw the gridlines of every axis in the layout."""
        for axis in self._axis_names():
            self._draw_gridlines(context, width, height, axis)

    def _draw_gridlines(self, context, width, height, axis):
        """Refresh the ticks of *axis* and draw one gridline per tick value.

        The tick object's ``length`` is set to the axis's pixel length before
        ``calculate()`` runs; ``_tickvals`` is read afterwards (presumably
        refreshed by ``calculate()``). Each gridline spans the ``_range`` of
        the axis's anchor. Axes without ``_range`` are skipped.
        """
        axis_layout = self.layout[axis]
        if "_range" not in axis_layout:
            return
        is_x = axis.startswith("x")
        plot_width, plot_height = self._plot_area(width, height)
        # NOTE(review): the length uses the user-facing "domain" while axes are
        # positioned with "_domain" elsewhere — confirm this is intended.
        domain = axis_layout["domain"]
        ticks = axis_layout["_ticksobject"]
        ticks.length = (plot_width if is_x else plot_height) * (domain[-1] - domain[0])
        ticks.calculate()

        anchor = self._anchor_name(axis)
        if anchor is None:
            # Free/unanchored axis: span the highest-sorted counterpart axis of
            # the cartesian subplots this axis appears in.
            subplots = [
                subplot
                for subplot in get_cartesian_subplots(self.data)
                if axis in subplot
            ]
            anchor = max(
                (subplot[-1] if is_x else subplot[0] for subplot in subplots),
                default=None,
            )
            if anchor is None:
                # The axis is in no cartesian subplot, so there is nothing to span.
                return
        anchor_range = self.layout[anchor]["_range"]

        context.set_source_rgb(*parse_color(axis_layout["gridcolor"]))
        if is_x:
            x_pos, y_pos = self._calc_pos(
                axis_layout["_tickvals"],
                anchor_range,
                width,
                height,
                axis,
                anchor,
                ignore_log_y=True,
            )
            for tick in x_pos:
                context.move_to(tick, y_pos[0])
                context.line_to(tick, y_pos[1])
                context.stroke()
        else:
            x_pos, y_pos = self._calc_pos(
                anchor_range,
                axis_layout["_tickvals"],
                width,
                height,
                anchor,
                axis,
                ignore_log_x=True,
            )
            for tick in y_pos:
                context.move_to(x_pos[0], tick)
                context.line_to(x_pos[1], tick)
                context.stroke()

    def _draw_all_ticks(self, context, width, height):
        """Draw the tick labels of every axis in the layout."""
        for axis in self._axis_names():
            self._draw_ticks(context, width, height, axis)

    def _draw_ticks(self, context, width, height, axis):  # pylint:disable=too-many-locals
        """Draw the tick labels of *axis* with Pango.

        Nothing is drawn unless the axis has ``_tickvals``, ``_ticktext`` and
        ``_ticksobject`` and ``showticklabels`` is not false. Labels use the
        layout font with colour ``#444``, overridden by the axis ``font`` and
        then ``tickfont``.
        """
        axis_layout = self.layout[axis]
        required = ("_tickvals", "_ticktext", "_ticksobject")
        if any(key not in axis_layout for key in required):
            return
        if "showticklabels" in axis_layout and not axis_layout["showticklabels"]:
            return
        tickvals = axis_layout["_tickvals"]
        ticktext = axis_layout["_ticktext"]

        # Copy first: forcing the default tick colour must not overwrite the
        # figure-wide layout["font"] that other drawing code reads.
        font_dict = dict(self.layout["font"])
        font_dict["color"] = "#444"
        if "font" in axis_layout:
            font_dict = update_dict(font_dict, axis_layout["font"])
        if "tickfont" in axis_layout:
            font_dict = update_dict(font_dict, axis_layout["tickfont"])

        context.set_source_rgb(*parse_color(font_dict["color"]))
        layout = PangoCairo.create_layout(context)
        layout.set_font_description(parse_font(font_dict))

        margin = self.layout["_margin"]
        anchor = self._anchor_name(axis)

        if axis.startswith("x"):
            if anchor is not None:
                # Labels sit at the start of the anchor axis's range.
                x_pos, y_pos = self._calc_pos(
                    tickvals,
                    self.layout[anchor]["_range"][0],
                    width,
                    height,
                    axis,
                    anchor,
                    ignore_log_y=True,
                )
            else:
                _, plot_height = self._plot_area(width, height)
                y_pos = margin["t"] + (1 - axis_layout["_position"]) * plot_height
                x_pos, _ = self._calc_pos(tickvals, [], width, height, axis, None)

            for tick, text in zip(x_pos, ticktext):
                context.move_to(tick, y_pos)
                layout.set_markup(text)
                text_width, _ = layout.get_pixel_size()
                # Centre the label horizontally on the tick.
                context.rel_move_to(-text_width / 2, 0)
                PangoCairo.show_layout(context, layout)
        else:
            right_side = "side" in axis_layout and axis_layout["side"] == "right"
            if anchor is not None:
                anchor_range = self.layout[anchor]["_range"]
                x = anchor_range[-1] if right_side else anchor_range[0]
                x_pos, y_pos = self._calc_pos(
                    x, tickvals, width, height, anchor, axis, ignore_log_x=True
                )
                x_pos += axis_layout["_shift"]
            else:
                plot_width, _ = self._plot_area(width, height)
                x_pos = (
                    margin["l"]
                    + axis_layout["_position"] * plot_width
                    + axis_layout["_shift"]
                )
                _, y_pos = self._calc_pos([], tickvals, width, height, None, axis)

            for tick, text in zip(y_pos, ticktext):
                context.move_to(x_pos, tick)
                layout.set_markup(text)
                text_width, text_height = layout.get_pixel_size()
                # Centre vertically on the tick. Right-side labels start at the
                # axis; left-side labels end at it.
                if right_side:
                    context.rel_move_to(0, -text_height / 2)
                else:
                    context.rel_move_to(-text_width, -text_height / 2)
                PangoCairo.show_layout(context, layout)

    def _draw_axes(self, context, width, height):
        """Draw the axis line of every axis in the layout."""
        for axis in self._axis_names():
            self._draw_axis(context, width, height, axis)

    def _draw_axis(self, context, width, height, axis):
        """Stroke the line of *axis* across its ``_domain`` at its ``_position``.

        Axes without ``linecolor`` are skipped. Keys whose prefix before
        "axis" is neither "x" nor "y" draw nothing.
        """
        axis_layout = self.layout[axis]
        if "linecolor" not in axis_layout:
            return
        context.set_source_rgb(*parse_color(axis_layout["linecolor"]))
        if DEBUG:
            context.set_source_rgb(*parse_color("green"))

        axis_letter = axis[: axis.find("axis")]
        domain = axis_layout["_domain"]
        position = axis_layout["_position"]
        margin = self.layout["_margin"]
        plot_width, plot_height = self._plot_area(width, height)

        if axis_letter == "x":
            # Horizontal line at the axis position (0 = bottom, 1 = top).
            y_line = margin["t"] + (1 - position) * plot_height
            context.move_to(margin["l"] + domain[0] * plot_width, y_line)
            context.line_to(margin["l"] + domain[-1] * plot_width, y_line)
        elif axis_letter == "y":
            # Vertical line at the axis position, offset by the axis shift.
            x_line = margin["l"] + position * plot_width + axis_layout["_shift"]
            context.move_to(x_line, margin["t"] + (1 - domain[0]) * plot_height)
            context.line_to(x_line, margin["t"] + (1 - domain[-1]) * plot_height)
        context.stroke()

    def _calc_pos(
        self,
        x,
        y,
        width,
        height,
        xaxis=None,
        yaxis=None,
        ignore_log_x=False,
        ignore_log_y=False,
    ):  # pylint: disable=too-many-arguments,too-many-locals
        """Map data coordinates to widget pixel coordinates.

        Parameters
        ----------
        x, y:
            Data value(s) along each axis (scalar or sequence).
        width, height:
            Size of the drawing area in pixels.
        xaxis, yaxis:
            Layout key (e.g. ``"xaxis2"``) or axis dict. ``None``, or a key
            missing from the layout, skips that coordinate.
        ignore_log_x, ignore_log_y:
            Skip the ``log10`` transform for a ``type == "log"`` axis. Callers
            set these when passing ``_range`` values, presumably because
            ranges are already in log space.

        Returns
        -------
        tuple
            ``(x_pos, y_pos)``. A coordinate whose axis is skipped is ``[]``.
        """
        if isinstance(xaxis, str):
            xaxis = self.layout[xaxis] if xaxis in self.layout else None
        if isinstance(yaxis, str):
            yaxis = self.layout[yaxis] if yaxis in self.layout else None
        log_x = xaxis is not None and "type" in xaxis and xaxis["type"] == "log"
        log_y = yaxis is not None and "type" in yaxis and yaxis["type"] == "log"

        if log_x and not ignore_log_x:
            x = np.log10(x)
        if log_y and not ignore_log_y:
            y = np.log10(y)

        x_pos = []
        y_pos = []

        if xaxis is not None:
            plot_width, _ = self._plot_area(width, height)
            left = self.layout["_margin"]["l"]
            xaxis_start = xaxis["_domain"][0] * plot_width + left
            xaxis_end = xaxis["_domain"][-1] * plot_width + left
            x_min, x_max = xaxis["_range"][0], xaxis["_range"][1]
            # asarray lets plain lists and scalars take part in the arithmetic.
            x_pos = (np.asarray(x) - x_min) / (x_max - x_min) * (
                xaxis_end - xaxis_start
            ) + xaxis_start

        if yaxis is not None:
            _, plot_height = self._plot_area(width, height)
            bottom = self.layout["_margin"]["b"]
            # Pixel y grows downwards, so domain 0 maps to the bottom margin.
            yaxis_start = -yaxis["_domain"][0] * plot_height + height - bottom
            yaxis_end = -yaxis["_domain"][-1] * plot_height + height - bottom
            y_min, y_max = yaxis["_range"][0], yaxis["_range"][1]
            y_pos = (np.asarray(y) - y_min) / (y_max - y_min) * (
                yaxis_end - yaxis_start
            ) + yaxis_start

        return x_pos, y_pos

    def _set_trace_color(self, context, plot, index):
        """Set the source colour for the trace drawn at position *index*.

        An explicit ``marker.color`` wins. Otherwise the template colorway is
        used, wrapping around when there are more traces than colours.
        """
        if "marker" in plot and "color" in plot["marker"]:
            color = plot["marker"]["color"]
        else:
            colorway = self.layout["template"]["layout"]["colorway"]
            color = colorway[index % len(colorway)]
        context.set_source_rgb(*parse_color(color))

    def _plot(self, context, width, height):
        """Draw every visible scatter/scattergl/histogram trace.

        Colorway colours are assigned by a running trace index.
        """
        index = 0
        for plot in self.data:
            plot_type = plot["type"]
            if not plot["visible"]:
                # NOTE(review): unlike "_visible" below, this does not consume a
                # colour slot — confirm the asymmetry is intended.
                continue

            if "_visible" in plot and not plot["_visible"]:
                index += 1
                continue

            # Isolate per-trace cairo state (source colour, line width) so it
            # does not leak into the axes and tick labels drawn afterwards.
            context.save()
            self._set_trace_color(context, plot, index)
            if plot_type in ("scatter", "scattergl"):
                self._plot_scatter(context, width, height, plot, index)
            elif plot_type == "histogram":
                self._plot_histogram(context, width, height, plot, index)
            context.restore()
            index += 1

    def _plot_histogram(
        self, context, width, height, plot, index
    ):  # pylint: disable=too-many-arguments,unused-argument
        """Fill one bar per ``y`` value.

        Bar ``i`` spans ``x[i]`` to ``x[i + 1]``, so ``x`` must have at least
        ``len(y) + 1`` entries.
        """
        xaxis, yaxis = self._get_axes(plot)
        edges = plot["x"]
        for i, count in enumerate(plot["y"]):
            left, right = edges[i], edges[i + 1]
            x_pos, y_pos = self._calc_pos(
                [left, left, right, right],
                [0, count, count, 0],
                width,
                height,
                xaxis,
                yaxis,
            )
            context.new_path()
            for corner_x, corner_y in zip(x_pos, y_pos):
                context.line_to(corner_x, corner_y)
            context.close_path()
            context.fill()

    def _get_axes(self, plot):
        """Return the layout dicts of the trace's x and y axes.

        Trace references ``"x"``/``"x2"`` map to ``"xaxis"``/``"xaxis2"``
        (likewise for y). Defaults are ``"x"`` and ``"y"``; an axis missing
        from the layout yields ``{}``.
        """
        xaxis = plot["xaxis"] if "xaxis" in plot else "x"
        yaxis = plot["yaxis"] if "yaxis" in plot else "y"
        xaxis = xaxis.replace("x", "xaxis")
        yaxis = yaxis.replace("y", "yaxis")

        xaxis = self.layout[xaxis] if xaxis in self.layout else {}
        yaxis = self.layout[yaxis] if yaxis in self.layout else {}

        return xaxis, yaxis

    def _plot_scatter(
        self, context, width, height, plot, index
    ):  # pylint: disable=too-many-locals,too-many-arguments,unused-argument
        """Draw a scatter trace as markers and/or a polyline.

        Without an explicit ``mode``, fewer than 20 points get
        "lines+markers" and larger traces get "lines", matching plotly's
        default. Points with a NaN coordinate are skipped, and they break
        the line.
        """
        if "mode" in plot:
            mode = plot["mode"]
        elif len(plot["x"]) < 20:
            mode = "lines+markers"
        else:
            mode = "lines"
        modes = mode.split("+")

        xaxis, yaxis = self._get_axes(plot)

        x_pos, y_pos = self._calc_pos(plot["x"], plot["y"], width, height, xaxis, yaxis)

        if "markers" in modes:
            marker = plot["marker"]
            context.new_path()
            if isinstance(marker["size"], (list, np.ndarray)):
                # Only per-point (array) sizes are scaled by sizeref.
                size = np.array(marker["size"]) / marker["sizeref"]
            else:
                size = np.array([marker["size"]] * len(plot["x"]))
            radius = (
                size / 2 if marker["sizemode"] == "diameter" else np.sqrt(size / np.pi)
            )

            for x, y, r in zip(x_pos, y_pos, radius):
                if np.isnan(x) or np.isnan(y):
                    continue
                context.arc(x, y, r, 0, 2 * np.pi)
                context.fill()
        if "lines" in modes:
            context.set_line_width(plot["line"]["width"])
            context.new_path()
            for x, y in zip(x_pos, y_pos):
                if np.isnan(x) or np.isnan(y):
                    # Finish the current segment; the next point starts a new one.
                    context.stroke()
                    continue
                context.line_to(x, y)
            context.stroke()