Coverage for src/plotly_gtk/_chart.py: 9%
245 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-08 21:22 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-08 21:22 +0000
1"""Contains a private class to handle plotting for
2:class:`plotly_gtk.chart.PlotlyGTK`."""
4import gi
5import numpy as np
7from plotly_gtk.utils import * # pylint: disable=wildcard-import,unused-wildcard-import
9gi.require_version("Gtk", "4.0")
10from gi.repository import ( # pylint: disable=wrong-import-position,wrong-import-order
11 Gtk,
12 Pango,
13 PangoCairo,
14)
# Module-wide debug switch. When True, _draw_bg additionally fills the area
# inside layout["_margin"] with pink, and _draw_axis strokes axis lines in green
# instead of their configured "linecolor".
DEBUG: bool = False
class _PlotlyGtk(Gtk.DrawingArea):
    """Cairo/Pango renderer that draws a plotly figure dict into a Gtk.DrawingArea.

    Besides the public plotly keys, the figure's ``layout`` must already carry
    the private precomputed keys read here: ``_margin`` (``l``/``r``/``t``/``b``
    pixel margins) and, per axis, ``_range``, ``_domain``, ``_position``,
    ``_shift``, ``_tickvals``, ``_ticktext`` and ``_ticksobject``. Presumably the
    parent widget / ``plotly_gtk.utils`` fills these in -- not visible here.
    """

    def __init__(self, fig: dict):
        """Store the figure's ``data`` and ``layout`` and register the draw callback.

        Parameters
        ----------
        fig: dict
            A dictionary representing a plotly figure (``"data"`` and ``"layout"``).
        """
        super().__init__()
        self.data = fig["data"]
        self.layout = fig["layout"]

        self.set_draw_func(self._on_draw)

    def update(self, fig: "dict[str, plotly_types.Data | plotly_types.Layout]"):
        """Update the plot with a new figure and schedule a redraw.

        Parameters
        ----------
        fig: dict[str, plotly_types.Data | plotly_types.Layout]
            A dictionary representing a plotly figure; only its ``"data"`` and
            ``"layout"`` entries are used.
        """
        self.data = fig["data"]
        self.layout = fig["layout"]
        self.queue_draw()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _axis_name(ref):
        """Turn a short axis reference (``"x"``, ``"y2"``) into its layout key
        (``"xaxis"``, ``"yaxis2"``)."""
        return ref[0] + "axis" + ref[1:]

    def _axis_keys(self):
        """Return every layout key that names an axis (any key containing "axis")."""
        return [key for key in self.layout if "axis" in key]

    def _plot_area(self, width, height):
        """Return ``(left, top, plot_width, plot_height)`` of the region inside
        ``layout["_margin"]`` for a widget of ``width`` x ``height`` pixels."""
        margin = self.layout["_margin"]
        return (
            margin["l"],
            margin["t"],
            width - margin["l"] - margin["r"],
            height - margin["t"] - margin["b"],
        )

    @staticmethod
    def _as_numeric(values):
        """Convert a plain list/tuple to an ndarray so arithmetic with the axis
        bounds is element-wise; scalars and arrays are returned unchanged."""
        if isinstance(values, (list, tuple)):
            return np.asarray(values)
        return values

    # -------------------------------------------------------------- draw entry

    def _on_draw(self, area, context, x, y):  # pylint: disable=unused-argument
        """Draw callback: background, grid, traces, axis lines, then tick labels.

        Each layer paints over the previous ones, so the order below matters.
        """
        # NOTE(review): presumably recomputes layout["_margin"] for the current
        # size before anything below reads it -- confirm in the parent widget.
        self.get_parent().automargin()

        width = area.get_size(Gtk.Orientation.HORIZONTAL)
        height = area.get_size(Gtk.Orientation.VERTICAL)

        self._draw_bg(context, width, height)
        self._draw_grid(context, width, height)
        self._plot(context, width, height)
        self._draw_axes(context, width, height)
        self._draw_all_ticks(context, width, height)

    # -------------------------------------------------------------- background

    def _draw_bg(self, context, width, height):
        """Paint the paper background, then the plot background of every
        cartesian subplot (x/y axis pair) found in the data."""
        context.set_source_rgb(*parse_color(self.layout["paper_bgcolor"]))
        context.rectangle(0, 0, width, height)
        context.fill()

        if DEBUG:
            # Visualise the area inside the margins.
            left, top, plot_width, plot_height = self._plot_area(width, height)
            context.set_source_rgb(*parse_color("pink"))
            context.rectangle(left, top, plot_width, plot_height)
            context.fill()

        context.set_source_rgb(*parse_color(self.layout["plot_bgcolor"]))
        for xaxis, yaxis in get_cartesian_subplots(self.data):
            # _range is passed with ignore_log_* set, i.e. treated as already
            # being in axis (log) units.
            x_pos, y_pos = self._calc_pos(
                self.layout[xaxis]["_range"],
                self.layout[yaxis]["_range"],
                width,
                height,
                xaxis,
                yaxis,
                ignore_log_x=True,
                ignore_log_y=True,
            )
            context.rectangle(
                x_pos[0], y_pos[0], x_pos[-1] - x_pos[0], y_pos[-1] - y_pos[0]
            )
            context.fill()

    # -------------------------------------------------------------------- grid

    def _draw_grid(self, context, width, height):
        """Draw the gridlines of every axis in the layout."""
        for axis in self._axis_keys():
            self._draw_gridlines(context, width, height, axis)

    def _gridline_anchor(self, axis):
        """Return the layout key of the perpendicular axis whose range the
        gridlines of ``axis`` span, or ``None`` if none can be determined."""
        axis_dict = self.layout[axis]
        if "anchor" in axis_dict and axis_dict["anchor"] != "free":
            return self._axis_name(axis_dict["anchor"])
        subplots = [
            subplot for subplot in get_cartesian_subplots(self.data) if axis in subplot
        ]
        # Subplots are (xaxis, yaxis) pairs: an x axis pairs with the last item,
        # a y axis with the first.
        if axis.startswith("x"):
            anchors = sorted(subplot[-1] for subplot in subplots)
        else:
            anchors = sorted(subplot[0] for subplot in subplots)
        return anchors[-1] if anchors else None

    def _draw_gridlines(self, context, width, height, axis):
        """Recompute the ticks of ``axis`` and stroke one gridline per tick value
        across the range of the perpendicular axis it is paired with."""
        axis_dict = self.layout[axis]
        if "_range" not in axis_dict:
            return
        is_x = axis.startswith("x")

        _, _, plot_width, plot_height = self._plot_area(width, height)
        domain = axis_dict["domain"]
        ticks = axis_dict["_ticksobject"]
        # Pixel length of this axis. Presumably the ticks object uses it to pick
        # tick spacing, and calculate() refreshes _tickvals/_ticktext -- confirm
        # in plotly_gtk.utils.
        ticks.length = (plot_width if is_x else plot_height) * (domain[-1] - domain[0])
        ticks.calculate()

        anchor = self._gridline_anchor(axis)
        if anchor is None:
            # Unanchored axis that belongs to no subplot: nothing to span.
            return
        anchor_range = self.layout[anchor]["_range"]
        tickvals = axis_dict["_tickvals"]

        context.set_source_rgb(*parse_color(axis_dict["gridcolor"]))
        if is_x:
            x_pos, y_pos = self._calc_pos(
                tickvals, anchor_range, width, height, axis, anchor, ignore_log_y=True
            )
            for tick in x_pos:
                context.move_to(tick, y_pos[0])
                context.line_to(tick, y_pos[1])
                context.stroke()
        else:
            x_pos, y_pos = self._calc_pos(
                anchor_range, tickvals, width, height, anchor, axis, ignore_log_x=True
            )
            for tick in y_pos:
                context.move_to(x_pos[0], tick)
                context.line_to(x_pos[1], tick)
                context.stroke()

    # ------------------------------------------------------------- tick labels

    def _draw_all_ticks(self, context, width, height):
        """Draw the tick labels of every axis in the layout."""
        for axis in self._axis_keys():
            self._draw_ticks(context, width, height, axis)

    def _draw_ticks(
        self, context, width, height, axis
    ):  # pylint:disable=too-many-locals
        """Render the tick labels (``_ticktext`` at ``_tickvals``) of ``axis``.

        x-axis labels are centred under their tick. y-axis labels are vertically
        centred on their tick, and placed to its left unless ``side == "right"``.
        """
        axis_dict = self.layout[axis]
        if any(
            key not in axis_dict for key in ("_tickvals", "_ticktext", "_ticksobject")
        ):
            return
        if not axis_dict.get("showticklabels", True):
            return
        tickvals = axis_dict["_tickvals"]
        ticktext = axis_dict["_ticktext"]

        # Work on a copy: assigning into self.layout["font"] directly would
        # permanently overwrite the figure-wide font colour (and leak axis font
        # overrides into later axes if update_dict mutates in place).
        font_dict = {**self.layout["font"], "color": "#444"}
        if "font" in axis_dict:
            font_dict = update_dict(font_dict, axis_dict["font"])
        if "tickfont" in axis_dict:
            font_dict = update_dict(font_dict, axis_dict["tickfont"])

        context.set_source_rgb(*parse_color(font_dict["color"]))
        layout = PangoCairo.create_layout(context)
        layout.set_font_description(parse_font(font_dict))

        left, top, plot_width, plot_height = self._plot_area(width, height)
        anchored = "anchor" in axis_dict and axis_dict["anchor"] != "free"
        side = axis_dict.get("side")

        if axis.startswith("x"):
            if anchored:
                yaxis = self._axis_name(axis_dict["anchor"])
                # Labels sit at the bottom of the anchoring y axis' range.
                x_pos, y_pos = self._calc_pos(
                    tickvals,
                    self.layout[yaxis]["_range"][0],
                    width,
                    height,
                    axis,
                    yaxis,
                    ignore_log_y=True,
                )
            else:
                y_pos = top + (1 - axis_dict["_position"]) * plot_height
                x_pos, _ = self._calc_pos(tickvals, [], width, height, axis, None)

            for tick, text in zip(x_pos, ticktext):
                context.move_to(tick, y_pos)
                layout.set_markup(text)
                text_width, _ = layout.get_pixel_size()
                context.rel_move_to(-text_width / 2, 0)
                PangoCairo.show_layout(context, layout)
        else:
            if anchored:
                xaxis = self._axis_name(axis_dict["anchor"])
                x_range = self.layout[xaxis]["_range"]
                x = x_range[-1] if side == "right" else x_range[0]
                x_pos, y_pos = self._calc_pos(
                    x, tickvals, width, height, xaxis, axis, ignore_log_x=True
                )
                x_pos += axis_dict["_shift"]
            else:
                x_pos = left + axis_dict["_position"] * plot_width + axis_dict["_shift"]
                _, y_pos = self._calc_pos([], tickvals, width, height, None, axis)

            for tick, text in zip(y_pos, ticktext):
                context.move_to(x_pos, tick)
                layout.set_markup(text)
                text_width, text_height = layout.get_pixel_size()
                if side == "right":
                    context.rel_move_to(0, -text_height / 2)
                else:
                    context.rel_move_to(-text_width, -text_height / 2)
                PangoCairo.show_layout(context, layout)

    # ------------------------------------------------------------- axis lines

    def _draw_axes(self, context, width, height):
        """Draw the line of every axis in the layout."""
        for axis in self._axis_keys():
            self._draw_axis(context, width, height, axis)

    def _draw_axis(self, context, width, height, axis):
        """Stroke the axis line of ``axis`` (skipped when it has no ``linecolor``).

        x axes are horizontal at ``_position`` (fraction of the plot height,
        measured from the bottom) spanning ``_domain``. y axes are vertical at
        ``_position`` (fraction of the plot width) plus ``_shift`` pixels.
        """
        axis_dict = self.layout[axis]
        if "linecolor" not in axis_dict:
            return
        context.set_source_rgb(*parse_color(axis_dict["linecolor"]))
        if DEBUG:
            context.set_source_rgb(*parse_color("green"))

        axis_letter = axis[: axis.find("axis")]
        domain = axis_dict["_domain"]
        position = axis_dict["_position"]
        left, top, plot_width, plot_height = self._plot_area(width, height)

        if axis_letter == "x":
            y = top + (1 - position) * plot_height
            context.move_to(left + domain[0] * plot_width, y)
            context.line_to(left + domain[-1] * plot_width, y)
        elif axis_letter == "y":
            x = left + position * plot_width + axis_dict["_shift"]
            context.move_to(x, top + (1 - domain[0]) * plot_height)
            context.line_to(x, top + (1 - domain[-1]) * plot_height)
        context.stroke()

    # ---------------------------------------------------------- coordinate map

    def _calc_pos(
        self,
        x,
        y,
        width,
        height,
        xaxis=None,
        yaxis=None,
        ignore_log_x=False,
        ignore_log_y=False,
    ):  # pylint: disable=too-many-arguments,too-many-locals
        """Map data coordinates to widget pixel coordinates.

        Parameters
        ----------
        x, y:
            Data values (scalar, list/tuple or numpy array) on the x / y axis.
        width, height:
            Widget size in pixels.
        xaxis, yaxis:
            Layout key (e.g. ``"xaxis2"``) or axis dict. An unknown key,
            ``None`` or an empty dict skips that coordinate.
        ignore_log_x, ignore_log_y:
            Skip the ``log10`` transform on a ``type == "log"`` axis, for values
            that are already in log units (such as an axis ``_range``).

        Returns
        -------
        tuple
            ``(x_pos, y_pos)``. A skipped coordinate is returned as ``[]``.
        """
        if isinstance(xaxis, str):
            xaxis = self.layout[xaxis] if xaxis in self.layout else None
        if isinstance(yaxis, str):
            yaxis = self.layout[yaxis] if yaxis in self.layout else None
        # An empty dict (what _get_axes returns for a missing axis) has no
        # _domain/_range to map with, so treat it like a missing axis.
        xaxis = xaxis or None
        yaxis = yaxis or None

        x_pos = []
        y_pos = []

        if xaxis is not None:
            x = self._as_numeric(x)
            if xaxis.get("type") == "log" and not ignore_log_x:
                x = np.log10(x)
            left, _, plot_width, _ = self._plot_area(width, height)
            xaxis_start = xaxis["_domain"][0] * plot_width + left
            xaxis_end = xaxis["_domain"][-1] * plot_width + left
            x_min, x_max = xaxis["_range"][0], xaxis["_range"][1]
            x_pos = (x - x_min) / (x_max - x_min) * (
                xaxis_end - xaxis_start
            ) + xaxis_start

        if yaxis is not None:
            y = self._as_numeric(y)
            if yaxis.get("type") == "log" and not ignore_log_y:
                y = np.log10(y)
            _, _, _, plot_height = self._plot_area(width, height)
            # Pixel y grows downwards: domain 0 maps to the bottom margin edge.
            bottom = height - self.layout["_margin"]["b"]
            yaxis_start = bottom - yaxis["_domain"][0] * plot_height
            yaxis_end = bottom - yaxis["_domain"][-1] * plot_height
            y_min, y_max = yaxis["_range"][0], yaxis["_range"][1]
            y_pos = (y - y_min) / (y_max - y_min) * (
                yaxis_end - yaxis_start
            ) + yaxis_start

        return x_pos, y_pos

    # ------------------------------------------------------------------ traces

    def _set_trace_color(self, context, plot, index):
        """Set the cairo source colour for trace number ``index``.

        Uses the trace's ``marker.color`` when given. Otherwise uses the
        template colorway, cycling through it when there are more traces than
        colours.
        """
        if "marker" in plot and "color" in plot["marker"]:
            color = plot["marker"]["color"]
        else:
            colorway = self.layout["template"]["layout"]["colorway"]
            color = colorway[index % len(colorway)]
        context.set_source_rgb(*parse_color(color))

    def _plot(self, context, width, height):
        """Draw every trace in ``self.data``.

        Traces whose ``visible`` is falsy are skipped without consuming a colour
        index. Traces with ``_visible`` False are skipped but still consume one,
        presumably so hiding a trace does not shift the colours of the others.
        """
        index = 0
        for plot in self.data:
            if not plot["visible"]:
                continue

            if plot.get("_visible", True):
                # Isolate per-trace cairo state (source colour, line width) so it
                # does not leak into the axis lines and labels drawn afterwards.
                context.save()
                try:
                    self._set_trace_color(context, plot, index)
                    plot_type = plot["type"]
                    if plot_type in ("scatter", "scattergl"):
                        self._plot_scatter(context, width, height, plot, index)
                    elif plot_type == "histogram":
                        self._plot_histogram(context, width, height, plot, index)
                finally:
                    context.restore()
            index += 1

    def _plot_histogram(
        self, context, width, height, plot, index
    ):  # pylint: disable=too-many-arguments,unused-argument
        """Fill one bar per bin of a histogram trace.

        ``plot["x"]`` is read as bin edges (one more entry than ``plot["y"]``)
        and ``plot["y"]`` as bar heights measured from 0.
        """
        # Loop-invariant: resolve the trace's axes once.
        xaxis, yaxis = self._get_axes(plot)
        if not xaxis or not yaxis:
            return

        edges = plot["x"]
        for i, count in enumerate(plot["y"]):
            left_edge, right_edge = edges[i], edges[i + 1]
            x_pos, y_pos = self._calc_pos(
                [left_edge, left_edge, right_edge, right_edge],
                [0, count, count, 0],
                width,
                height,
                xaxis,
                yaxis,
            )
            context.new_path()
            for px, py in zip(x_pos, y_pos):
                context.line_to(px, py)
            context.close_path()
            context.fill()

    def _get_axes(self, plot):
        """Return the layout dicts of the x and y axes a trace is drawn on.

        ``plot["xaxis"]``/``plot["yaxis"]`` (default ``"x"``/``"y"``) are short
        references such as ``"x2"``. An axis missing from the layout yields ``{}``.
        """
        xaxis = self._axis_name(plot["xaxis"] if "xaxis" in plot else "x")
        yaxis = self._axis_name(plot["yaxis"] if "yaxis" in plot else "y")

        xaxis = self.layout[xaxis] if xaxis in self.layout else {}
        yaxis = self.layout[yaxis] if yaxis in self.layout else {}

        return xaxis, yaxis

    @staticmethod
    def _marker_radii(marker, count):
        """Return the pixel radius of each of ``count`` markers, following plotly.

        A scalar ``marker.size`` is a plain diameter (``sizeref``/``sizemode``
        are ignored). An array size is a "bubble" size, handled as in plotly.js:
        ``radius = (size / 2) / sizeref`` for ``sizemode == "diameter"`` and
        ``radius = sqrt((size / 2) / sizeref)`` for ``"area"``. Non-positive or
        non-numeric sizes give radius 0, and positive radii are at least
        ``sizemin``.
        """
        size = marker["size"]
        if not isinstance(size, (list, tuple, np.ndarray)):
            return np.full(count, size / 2)

        half = np.asarray(size, dtype=float) / 2 / (marker.get("sizeref") or 1)
        with np.errstate(invalid="ignore"):
            if marker.get("sizemode", "diameter") == "area":
                radius = np.sqrt(half)
            else:
                radius = half
        valid = np.isfinite(radius) & (radius > 0)
        return np.where(valid, np.maximum(radius, marker.get("sizemin", 0)), 0.0)

    def _plot_scatter(
        self, context, width, height, plot, index
    ):  # pylint: disable=too-many-locals,too-many-arguments,unused-argument
        """Draw a scatter trace as markers and/or lines according to ``mode``.

        Without an explicit ``mode``, traces with at most 20 points use
        ``"lines+markers"`` and longer ones ``"lines"``. A NaN coordinate skips
        that marker and breaks the line.
        """
        if "mode" in plot:
            mode = plot["mode"]
        elif len(plot["x"]) <= 20:
            mode = "lines+markers"
        else:
            mode = "lines"
        modes = mode.split("+")

        xaxis, yaxis = self._get_axes(plot)

        x_pos, y_pos = self._calc_pos(plot["x"], plot["y"], width, height, xaxis, yaxis)

        if "markers" in modes:
            radius = self._marker_radii(plot["marker"], len(plot["x"]))
            context.new_path()
            for x, y, r in zip(x_pos, y_pos, radius):
                if np.isnan(x) or np.isnan(y):
                    continue
                # Without a fresh sub-path, cairo's arc() joins this circle to the
                # previous one with a straight segment, and fill() would paint the
                # polygon those segments enclose.
                context.new_sub_path()
                context.arc(x, y, r, 0, 2 * np.pi)
            context.fill()
        if "lines" in modes:
            context.set_line_width(plot["line"]["width"])
            for x, y in zip(x_pos, y_pos):
                if np.isnan(x) or np.isnan(y):
                    # Finish the current segment; the next line_to starts anew.
                    context.stroke()
                    continue
                context.line_to(x, y)
            context.stroke()